sshrsa.c 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. /*
  2. * RSA implementation just sufficient for ssh client-side
  3. * initialisation step
  4. *
  5. * Rewritten for more speed by Joris van Rantwijk, Jun 1999.
  6. */
  7. #include <stdio.h>
  8. #include <stdlib.h>
  9. #include <string.h>
  10. #include "ssh.h"
  /*
   * A Bignum points at an array of 16-bit words. Element [0] is the
   * number of value words that follow; elements [1]..[b[0]] hold the
   * value itself, least significant word first.
   */
  typedef unsigned short *Bignum;

  /* A Bignum with no value words at all. */
  static unsigned short Zero[] = { 0 };
  #if defined TESTMODE || defined RSADEBUG
  #ifndef DLVL
  #define DLVL 10000
  #endif
  /* debug(x) prints the Bignum expression x together with its source text. */
  #define debug(x) bndebug(#x,x)
  /* Current indentation (columns) of debug output; enter()/leave() adjust it. */
  static int level = 0;
  /*
   * Print name, then the words of b as 4-digit hex (most significant
   * first), indented by level and padded so the hex ends near column 50.
   * Nothing is printed once level reaches DLVL.
   * NOTE(review): dprintf is called printf-style here but is not defined
   * in this file (presumably a project debug routine) - confirm.
   */
  static void bndebug(const char *name, Bignum b) {
      int i;
      /* strlen() returns size_t; cast so the padding arithmetic stays
         signed and can go negative instead of wrapping around. */
      int w = 50 - level - (int) strlen(name) - 5 * b[0];
      if (level >= DLVL)
          return;
      if (w < 0) w = 0;
      dprintf("%*s%s%*s", level, "", name, w, "");
      for (i = b[0]; i > 0; i--)
          dprintf(" %04x", b[i]);
      dprintf("\n");
  }
  #define dmsg(x) do {if(level<DLVL){dprintf("%*s",level,"");printf x;}} while(0)
  #define enter(x) do { dmsg(x); level += 4; } while(0)
  #define leave(x) do { level -= 4; dmsg(x); } while(0)
  #else
  /* Debugging disabled: each macro becomes a no-op expression, so a use
     such as "if (c) debug(x);" still forms a proper statement. */
  #define debug(x) ((void) 0)
  #define dmsg(x) ((void) 0)
  #define enter(x) ((void) 0)
  #define leave(x) ((void) 0)
  #endif
  39. static Bignum newbn(int length) {
  40. Bignum b = malloc((length+1)*sizeof(unsigned short));
  41. if (!b)
  42. abort(); /* FIXME */
  43. b[0] = length;
  44. return b;
  45. }
  46. static void freebn(Bignum b) {
  47. free(b);
  48. }
  /*
   * Schoolbook multiplication: c = a * b.
   *
   * a and b each hold len 16-bit words, most significant word first; they
   * may be the same array (modpow squares this way). c receives the
   * 2*len-word product, also most significant word first, and must not
   * overlap a or b.
   */
  static void bigmul(unsigned short *a, unsigned short *b, unsigned short *c,
                     int len)
  {
      int row, col;

      /* Only the low half of c needs clearing: each high-half word c[row]
         is written (as a final carry) before anything reads it. */
      for (col = 0; col < len; col++)
          c[len + col] = 0;

      for (row = len; row-- > 0; ) {
          unsigned long digit = a[row];
          unsigned long acc = 0;

          /* acc never exceeds 0xFFFF * 0xFFFF + 0xFFFF + 0xFFFF < 2^32 */
          for (col = len; col-- > 0; ) {
              acc += digit * (unsigned long) b[col] + c[row + col + 1];
              c[row + col + 1] = (unsigned short) acc;
              acc >>= 16;
          }
          c[row] = (unsigned short) acc;
      }
  }
  /*
   * Compute a = a % m by long division, one quotient word per step
   * (Knuth's Algorithm D).
   *
   * a holds 2*len 16-bit words and m holds len words, both most
   * significant word first. On return, the remainder occupies the last
   * len words of a and the first len words are zero.
   * The most significant word of m MUST have its high bit set; that is
   * what keeps each quotient estimate within one of the true digit.
   *
   * All intermediate products are unsigned: q * m[k] can reach about
   * 2^32, which overflows a signed 32-bit long.
   */
  static void bigmod(unsigned short *a, unsigned short *m, int len)
  {
      unsigned short m0, m1;
      unsigned int h;
      int i, k;

      /* Special case for len == 1: a single 32-by-16-bit division. */
      if (len == 1) {
          a[1] = (unsigned short)
              ((((unsigned long) a[0] << 16) + a[1]) % m[0]);
          a[0] = 0;
          return;
      }

      m0 = m[0];
      m1 = m[1];

      for (i = 0; i <= len; i++) {
          unsigned long t;
          unsigned int q, r, c;

          /* h is the word just above the current window a[i..i+len-1]. */
          if (i == 0) {
              h = 0;
          } else {
              h = a[i-1];
              a[i-1] = 0;
          }

          if (h >= m0) {
              /*
               * h can only equal m0 here, because the previous step left a
               * remainder below m. Then h:a[i] / m0 is 0x10000 or more,
               * which is not a valid quotient word. In this case the true
               * quotient word is 0xFFFE or 0xFFFF, so 0xFFFF is at most
               * one too large and the add-back below corrects it.
               */
              q = 0xFFFF;
          } else {
              /* Find q = h:a[i] / m0; since h < m0 this fits in a word. */
              t = ((unsigned long) h << 16) + a[i];
              q = (unsigned int) (t / m0);
              r = (unsigned int) (t % m0);

              /* Refine our estimate of q by looking at
                 h:a[i]:a[i+1] / m0:m1 */
              t = (unsigned long) m1 * q;
              if (t > ((unsigned long) r << 16) + a[i+1]) {
                  q--;
                  t -= m1;
                  /* If r + m0 no longer fits in a word, the test cannot
                     succeed again, and r >= m0 is false after masking. */
                  r = (r + m0) & 0xffff;
                  if (r >= m0 && t > ((unsigned long) r << 16) + a[i+1])
                      q--;
              }
          }

          /* Subtract q * m from a[i...]; c collects what is borrowed
             out of the top of the window. */
          c = 0;
          for (k = len - 1; k >= 0; k--) {
              t = (unsigned long) q * m[k];
              t += c;
              c = (unsigned int) (t >> 16);
              if ((unsigned short) t > a[i+k]) c++;
              a[i+k] -= (unsigned short) t;
          }

          /* Borrowing more than h supplied means q was one too large:
             add m back once. */
          if (c != h) {
              t = 0;
              for (k = len - 1; k >= 0; k--) {
                  t += m[k];
                  t += a[i+k];
                  a[i+k] = (unsigned short) t;
                  t = t >> 16;
              }
          }
      }
  }
  136. /*
  137. * Compute (base ^ exp) % mod.
  138. * The base MUST be smaller than the modulus.
  139. * The most significant word of mod MUST be non-zero.
  140. * We assume that the result array is the same size as the mod array.
  141. */
  142. static void modpow(Bignum base, Bignum exp, Bignum mod, Bignum result)
  143. {
  144. unsigned short *a, *b, *n, *m;
  145. int mshift;
  146. int mlen, i, j;
  147. /* Allocate m of size mlen, copy mod to m */
  148. /* We use big endian internally */
  149. mlen = mod[0];
  150. m = malloc(mlen * sizeof(unsigned short));
  151. for (j = 0; j < mlen; j++) m[j] = mod[mod[0] - j];
  152. /* Shift m left to make msb bit set */
  153. for (mshift = 0; mshift < 15; mshift++)
  154. if ((m[0] << mshift) & 0x8000) break;
  155. if (mshift) {
  156. for (i = 0; i < mlen - 1; i++)
  157. m[i] = (m[i] << mshift) | (m[i+1] >> (16-mshift));
  158. m[mlen-1] = m[mlen-1] << mshift;
  159. }
  160. /* Allocate n of size mlen, copy base to n */
  161. n = malloc(mlen * sizeof(unsigned short));
  162. i = mlen - base[0];
  163. for (j = 0; j < i; j++) n[j] = 0;
  164. for (j = 0; j < base[0]; j++) n[i+j] = base[base[0] - j];
  165. /* Allocate a and b of size 2*mlen. Set a = 1 */
  166. a = malloc(2 * mlen * sizeof(unsigned short));
  167. b = malloc(2 * mlen * sizeof(unsigned short));
  168. for (i = 0; i < 2*mlen; i++) a[i] = 0;
  169. a[2*mlen-1] = 1;
  170. /* Skip leading zero bits of exp. */
  171. i = 0; j = 15;
  172. while (i < exp[0] && (exp[exp[0] - i] & (1 << j)) == 0) {
  173. j--;
  174. if (j < 0) { i++; j = 15; }
  175. }
  176. /* Main computation */
  177. while (i < exp[0]) {
  178. while (j >= 0) {
  179. bigmul(a + mlen, a + mlen, b, mlen);
  180. bigmod(b, m, mlen);
  181. if ((exp[exp[0] - i] & (1 << j)) != 0) {
  182. bigmul(b + mlen, n, a, mlen);
  183. bigmod(a, m, mlen);
  184. } else {
  185. unsigned short *t;
  186. t = a; a = b; b = t;
  187. }
  188. j--;
  189. }
  190. i++; j = 15;
  191. }
  192. /* Fixup result in case the modulus was shifted */
  193. if (mshift) {
  194. for (i = mlen - 1; i < 2*mlen - 1; i++)
  195. a[i] = (a[i] << mshift) | (a[i+1] >> (16-mshift));
  196. a[2*mlen-1] = a[2*mlen-1] << mshift;
  197. bigmod(a, m, mlen);
  198. for (i = 2*mlen - 1; i >= mlen; i--)
  199. a[i] = (a[i] >> mshift) | (a[i-1] << (16-mshift));
  200. }
  201. /* Copy result to buffer */
  202. for (i = 0; i < mlen; i++)
  203. result[result[0] - i] = a[i+mlen];
  204. /* Free temporary arrays */
  205. for (i = 0; i < 2*mlen; i++) a[i] = 0; free(a);
  206. for (i = 0; i < 2*mlen; i++) b[i] = 0; free(b);
  207. for (i = 0; i < mlen; i++) m[i] = 0; free(m);
  208. for (i = 0; i < mlen; i++) n[i] = 0; free(n);
  209. }
  210. int makekey(unsigned char *data, struct RSAKey *result,
  211. unsigned char **keystr) {
  212. unsigned char *p = data;
  213. Bignum bn[2];
  214. int i, j;
  215. int w, b;
  216. result->bits = 0;
  217. for (i=0; i<4; i++)
  218. result->bits = (result->bits << 8) + *p++;
  219. for (j=0; j<2; j++) {
  220. w = 0;
  221. for (i=0; i<2; i++)
  222. w = (w << 8) + *p++;
  223. result->bytes = b = (w+7)/8; /* bits -> bytes */
  224. w = (w+15)/16; /* bits -> words */
  225. bn[j] = newbn(w);
  226. if (keystr) *keystr = p; /* point at key string, second time */
  227. for (i=1; i<=w; i++)
  228. bn[j][i] = 0;
  229. for (i=0; i<b; i++) {
  230. unsigned char byte = *p++;
  231. if ((b-i) & 1)
  232. bn[j][w-i/2] |= byte;
  233. else
  234. bn[j][w-i/2] |= byte<<8;
  235. }
  236. debug(bn[j]);
  237. }
  238. result->exponent = bn[0];
  239. result->modulus = bn[1];
  240. return p - data;
  241. }
  242. void rsaencrypt(unsigned char *data, int length, struct RSAKey *key) {
  243. Bignum b1, b2;
  244. int w, i;
  245. unsigned char *p;
  246. debug(key->exponent);
  247. memmove(data+key->bytes-length, data, length);
  248. data[0] = 0;
  249. data[1] = 2;
  250. for (i = 2; i < key->bytes-length-1; i++) {
  251. do {
  252. data[i] = random_byte();
  253. } while (data[i] == 0);
  254. }
  255. data[key->bytes-length-1] = 0;
  256. w = (key->bytes+1)/2;
  257. b1 = newbn(w);
  258. b2 = newbn(w);
  259. p = data;
  260. for (i=1; i<=w; i++)
  261. b1[i] = 0;
  262. for (i=0; i<key->bytes; i++) {
  263. unsigned char byte = *p++;
  264. if ((key->bytes-i) & 1)
  265. b1[w-i/2] |= byte;
  266. else
  267. b1[w-i/2] |= byte<<8;
  268. }
  269. debug(b1);
  270. modpow(b1, key->exponent, key->modulus, b2);
  271. debug(b2);
  272. p = data;
  273. for (i=0; i<key->bytes; i++) {
  274. unsigned char b;
  275. if (i & 1)
  276. b = b2[w-i/2] & 0xFF;
  277. else
  278. b = b2[w-i/2] >> 8;
  279. *p++ = b;
  280. }
  281. freebn(b1);
  282. freebn(b2);
  283. }
  284. int rsastr_len(struct RSAKey *key) {
  285. Bignum md, ex;
  286. md = key->modulus;
  287. ex = key->exponent;
  288. return 4 * (ex[0]+md[0]) + 10;
  289. }
  290. void rsastr_fmt(char *str, struct RSAKey *key) {
  291. Bignum md, ex;
  292. int len = 0, i;
  293. md = key->modulus;
  294. ex = key->exponent;
  295. for (i=1; i<=ex[0]; i++) {
  296. sprintf(str+len, "%04x", ex[i]);
  297. len += strlen(str+len);
  298. }
  299. str[len++] = '/';
  300. for (i=1; i<=md[0]; i++) {
  301. sprintf(str+len, "%04x", md[i]);
  302. len += strlen(str+len);
  303. }
  304. str[len] = '\0';
  305. }
  #ifdef TESTMODE
  #ifndef NODDY
  #define p1 10007
  #define p2 10069
  #define p3 10177
  #else
  #define p1 3
  #define p2 7
  #define p3 13
  #endif

  /* x = x * k for a small multiplier k. Any carry out of the top word is
     dropped, so x must already have enough words for the product. */
  static void test_mulsmall(Bignum x, unsigned short k) {
      unsigned long carry = 0;
      int i;
      for (i = 1; i <= x[0]; i++) {
          carry += (unsigned long) x[i] * k;
          x[i] = (unsigned short) carry;
          carry >>= 16;
      }
  }

  /* Drop high-order zero words so the most significant word is non-zero,
     as modpow requires of its modulus. */
  static void test_trim(Bignum x) {
      while (x[0] > 1 && x[x[0]] == 0)
          x[0]--;
  }

  /*
   * Self-test for modpow. With n = p1*p2*p3 (distinct odd primes) and
   * phi = (p1-1)*(p2-1)*(p3-1), Euler's theorem gives 2^phi mod n == 1.
   * Exits with status 0 on success and 1 on failure.
   */
  int main(void) {
      /* Three words are enough for both products (each is below 2^48). */
      unsigned short n[4] = { 3, 1, 0, 0 };
      unsigned short phi[4] = { 3, 1, 0, 0 };
      unsigned short two[2] = { 1, 2 };
      Bignum r;
      int i, ok;

      test_mulsmall(n, p1);
      test_mulsmall(n, p2);
      test_mulsmall(n, p3);
      test_trim(n);
      debug(n);

      test_mulsmall(phi, p1 - 1);
      test_mulsmall(phi, p2 - 1);
      test_mulsmall(phi, p3 - 1);
      debug(phi);

      r = newbn(n[0]);
      modpow(two, phi, n, r);
      debug(r);

      ok = (r[1] == 1);
      for (i = 2; i <= r[0]; i++)
          if (r[i] != 0)
              ok = 0;
      printf("2^phi mod n = 1: %s\n", ok ? "PASS" : "FAIL");
      freebn(r);
      return ok ? 0 : 1;
  }
  #endif