sshrsa.c 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. /*
  2. * RSA implementation just sufficient for ssh client-side
  3. * initialisation step
  4. *
  5. * Rewritten for more speed by Joris van Rantwijk, Jun 1999.
  6. */
  7. #include <stdio.h>
  8. #include <stdlib.h>
  9. #include <string.h>
  10. #include "ssh.h"
  /*
   * Bignum layout used throughout this file: element [0] holds the
   * number of value words n, and elements [1..n] hold the value as
   * 16-bit words, least significant word first.
   */
  typedef unsigned short *Bignum;

  /* The value zero as a Bignum: a word count of 0 and no value words. */
  static unsigned short Zero[] = { 0 };
  #if defined TESTMODE || defined RSADEBUG
  /*
   * Debug tracing, compiled only into TESTMODE or RSADEBUG builds.
   * `level` is the current indentation (enter()/leave() move it by 4);
   * nothing is printed once it reaches DLVL.
   */
  #ifndef DLVL
  #define DLVL 10000
  #endif
  #define debug(x) bndebug(#x,x)
  static int level = 0;
  /*
   * Print `name`, padding, then the value words of b from b[b[0]] down
   * to b[1], each as four hex digits, indented by `level`.
   */
  static void bndebug(const char *name, Bignum b) {
      int i;
      /*
       * strlen() returns size_t; without the cast the whole expression is
       * evaluated unsigned, so a long name wraps around instead of going
       * negative and the clamp below only works via an implementation-
       * defined conversion back to int.
       */
      int w = 50 - level - (int) strlen(name) - 5 * b[0];
      if (level >= DLVL)
          return;
      if (w < 0)
          w = 0;
      dprintf("%*s%s%*s", level, "", name, w, "");
      for (i = b[0]; i > 0; i--)
          dprintf(" %04x", b[i]);
      dprintf("\n");
  }
  /*
   * x is a parenthesised printf-style argument list, e.g.
   * dmsg(("n=%d\n", n)). The indent and the message both go through
   * dprintf so they end up on the same output.
   */
  #define dmsg(x) do {if(level<DLVL){dprintf("%*s",level,"");dprintf x;}} while(0)
  #define enter(x) do { dmsg(x); level += 4; } while(0)
  #define leave(x) do { level -= 4; dmsg(x); } while(0)
  #else
  #define debug(x)
  #define dmsg(x)
  #define enter(x)
  #define leave(x)
  #endif
  39. static Bignum newbn(int length) {
  40. Bignum b = malloc((length+1)*sizeof(unsigned short));
  41. if (!b)
  42. abort(); /* FIXME */
  43. b[0] = length;
  44. return b;
  45. }
  46. static void freebn(Bignum b) {
  47. free(b);
  48. }
  /*
   * Schoolbook multiplication c = a * b on big-endian word arrays
   * (element 0 is the most significant 16-bit word).
   *   a, b : len words each
   *   c    : receives the full 2*len-word product
   * c is written while a and b are still being read, so it should not
   * alias either of them.
   */
  static void bigmul(unsigned short *a, unsigned short *b, unsigned short *c,
                     int len)
  {
      int row, col;

      /* c[len..2*len-1] is accumulated into before it is ever assigned. */
      for (col = 0; col < len; col++)
          c[len + col] = 0;

      /* One partial product per word of a, least significant first. */
      for (row = len; row-- > 0; ) {
          unsigned long digit = a[row];
          unsigned long carry = 0;

          for (col = len; col-- > 0; ) {
              unsigned short *cell = &c[row + col + 1];
              carry += digit * (unsigned long) b[col] + (unsigned long) *cell;
              *cell = (unsigned short) carry;
              carry >>= 16;
          }
          /* The final carry becomes the next, more significant word. */
          c[row] = (unsigned short) carry;
      }
  }
  /*
   * Reduce a modulo m in place: a = a % m.
   *   a : 2*len big-endian words (element 0 most significant)
   *   m : len big-endian words; the top bit of m[0] MUST be set
   *       (normalised divisor), which keeps every quotient-digit
   *       estimate at most one too large.
   * On return the first len words of a are zero and the remainder is in
   * a[len .. 2*len-1]. The quotient is not kept.
   *
   * This is word-at-a-time long division (Knuth, TAOCP vol. 2, 4.3.1,
   * Algorithm D). All double-word arithmetic is unsigned: the previous
   * (long) casts overflowed a 32-bit signed long, which is undefined.
   */
  static void bigmod(unsigned short *a, unsigned short *m, int len)
  {
      unsigned short m0, m1;
      unsigned int h;
      int i, k;

      /* Special case for len == 1: the whole of a fits in 32 bits. */
      if (len == 1) {
          a[1] = (unsigned short)
              ((((unsigned long) a[0] << 16) + a[1]) % m[0]);
          a[0] = 0;
          return;
      }

      m0 = m[0];
      m1 = m[1];

      for (i = 0; i <= len; i++) {
          unsigned long t;
          unsigned int q, r, c;

          /* h is the word just above the current window; fold it in. */
          if (i == 0) {
              h = 0;
          } else {
              h = a[i-1];
              a[i-1] = 0;
          }

          if (h >= m0) {
              /*
               * h:a[i] / m0 would be >= 0x10000 and not fit in one
               * quotient digit. The running remainder is below m, so the
               * true digit is at most 0xFFFF; use that, as Algorithm D
               * does. If it is one too big, the add-back below fixes it.
               */
              q = 0xFFFF;
          } else {
              /* Estimate q = h:a[i] / m0 */
              t = ((unsigned long) h << 16) + a[i];
              q = (unsigned int) (t / m0);
              r = (unsigned int) (t % m0);
              /* Refine the estimate using h:a[i]:a[i+1] / m0:m1 */
              t = (unsigned long) m1 * q;
              if (t > ((unsigned long) r << 16) + a[i+1]) {
                  q--;
                  t -= m1;
                  /* After masking, r >= m0 exactly when r + m0 did not
                   * overflow 16 bits; only then is a second test valid. */
                  r = (r + m0) & 0xffff;
                  if (r >= (unsigned long) m0 &&
                      t > ((unsigned long) r << 16) + a[i+1])
                      q--;
              }
          }

          /* Subtract q * m from a[i .. i+len-1]; c collects the borrow. */
          c = 0;
          for (k = len - 1; k >= 0; k--) {
              t = (unsigned long) q * m[k];
              t += c;
              c = (unsigned int) (t >> 16);
              if ((unsigned short) t > a[i+k]) c++;
              a[i+k] -= (unsigned short) t;
          }

          /* Borrowed more than h supplied: q was one too big, add m back. */
          if (c != h) {
              t = 0;
              for (k = len - 1; k >= 0; k--) {
                  t += m[k];
                  t += a[i+k];
                  a[i+k] = (unsigned short) t;
                  t = t >> 16;
              }
          }
      }
  }
  137. /*
  138. * Compute (base ^ exp) % mod.
  139. * The base MUST be smaller than the modulus.
  140. * The most significant word of mod MUST be non-zero.
  141. * We assume that the result array is the same size as the mod array.
  142. */
  143. static void modpow(Bignum base, Bignum exp, Bignum mod, Bignum result)
  144. {
  145. unsigned short *a, *b, *n, *m;
  146. int mshift;
  147. int mlen, i, j;
  148. /* Allocate m of size mlen, copy mod to m */
  149. /* We use big endian internally */
  150. mlen = mod[0];
  151. m = malloc(mlen * sizeof(unsigned short));
  152. for (j = 0; j < mlen; j++) m[j] = mod[mod[0] - j];
  153. /* Shift m left to make msb bit set */
  154. for (mshift = 0; mshift < 15; mshift++)
  155. if ((m[0] << mshift) & 0x8000) break;
  156. if (mshift) {
  157. for (i = 0; i < mlen - 1; i++)
  158. m[i] = (m[i] << mshift) | (m[i+1] >> (16-mshift));
  159. m[mlen-1] = m[mlen-1] << mshift;
  160. }
  161. /* Allocate n of size mlen, copy base to n */
  162. n = malloc(mlen * sizeof(unsigned short));
  163. i = mlen - base[0];
  164. for (j = 0; j < i; j++) n[j] = 0;
  165. for (j = 0; j < base[0]; j++) n[i+j] = base[base[0] - j];
  166. /* Allocate a and b of size 2*mlen. Set a = 1 */
  167. a = malloc(2 * mlen * sizeof(unsigned short));
  168. b = malloc(2 * mlen * sizeof(unsigned short));
  169. for (i = 0; i < 2*mlen; i++) a[i] = 0;
  170. a[2*mlen-1] = 1;
  171. /* Skip leading zero bits of exp. */
  172. i = 0; j = 15;
  173. while (i < exp[0] && (exp[exp[0] - i] & (1 << j)) == 0) {
  174. j--;
  175. if (j < 0) { i++; j = 15; }
  176. }
  177. /* Main computation */
  178. while (i < exp[0]) {
  179. while (j >= 0) {
  180. bigmul(a + mlen, a + mlen, b, mlen);
  181. bigmod(b, m, mlen);
  182. if ((exp[exp[0] - i] & (1 << j)) != 0) {
  183. bigmul(b + mlen, n, a, mlen);
  184. bigmod(a, m, mlen);
  185. } else {
  186. unsigned short *t;
  187. t = a; a = b; b = t;
  188. }
  189. j--;
  190. }
  191. i++; j = 15;
  192. }
  193. /* Fixup result in case the modulus was shifted */
  194. if (mshift) {
  195. for (i = mlen - 1; i < 2*mlen - 1; i++)
  196. a[i] = (a[i] << mshift) | (a[i+1] >> (16-mshift));
  197. a[2*mlen-1] = a[2*mlen-1] << mshift;
  198. bigmod(a, m, mlen);
  199. for (i = 2*mlen - 1; i >= mlen; i--)
  200. a[i] = (a[i] >> mshift) | (a[i-1] << (16-mshift));
  201. }
  202. /* Copy result to buffer */
  203. for (i = 0; i < mlen; i++)
  204. result[result[0] - i] = a[i+mlen];
  205. /* Free temporary arrays */
  206. for (i = 0; i < 2*mlen; i++) a[i] = 0; free(a);
  207. for (i = 0; i < 2*mlen; i++) b[i] = 0; free(b);
  208. for (i = 0; i < mlen; i++) m[i] = 0; free(m);
  209. for (i = 0; i < mlen; i++) n[i] = 0; free(n);
  210. }
  211. int makekey(unsigned char *data, struct RSAKey *result,
  212. unsigned char **keystr) {
  213. unsigned char *p = data;
  214. Bignum bn[2];
  215. int i, j;
  216. int w, b;
  217. result->bits = 0;
  218. for (i=0; i<4; i++)
  219. result->bits = (result->bits << 8) + *p++;
  220. for (j=0; j<2; j++) {
  221. w = 0;
  222. for (i=0; i<2; i++)
  223. w = (w << 8) + *p++;
  224. result->bytes = b = (w+7)/8; /* bits -> bytes */
  225. w = (w+15)/16; /* bits -> words */
  226. bn[j] = newbn(w);
  227. if (keystr) *keystr = p; /* point at key string, second time */
  228. for (i=1; i<=w; i++)
  229. bn[j][i] = 0;
  230. for (i=b; i-- ;) {
  231. unsigned char byte = *p++;
  232. if (i & 1)
  233. bn[j][1+i/2] |= byte<<8;
  234. else
  235. bn[j][1+i/2] |= byte;
  236. }
  237. debug(bn[j]);
  238. }
  239. result->exponent = bn[0];
  240. result->modulus = bn[1];
  241. return p - data;
  242. }
  243. void rsaencrypt(unsigned char *data, int length, struct RSAKey *key) {
  244. Bignum b1, b2;
  245. int w, i;
  246. unsigned char *p;
  247. debug(key->exponent);
  248. memmove(data+key->bytes-length, data, length);
  249. data[0] = 0;
  250. data[1] = 2;
  251. for (i = 2; i < key->bytes-length-1; i++) {
  252. do {
  253. data[i] = random_byte();
  254. } while (data[i] == 0);
  255. }
  256. data[key->bytes-length-1] = 0;
  257. w = (key->bytes+1)/2;
  258. b1 = newbn(w);
  259. b2 = newbn(w);
  260. p = data;
  261. for (i=1; i<=w; i++)
  262. b1[i] = 0;
  263. for (i=key->bytes; i-- ;) {
  264. unsigned char byte = *p++;
  265. if (i & 1)
  266. b1[1+i/2] |= byte<<8;
  267. else
  268. b1[1+i/2] |= byte;
  269. }
  270. debug(b1);
  271. modpow(b1, key->exponent, key->modulus, b2);
  272. debug(b2);
  273. p = data;
  274. for (i=key->bytes; i-- ;) {
  275. unsigned char b;
  276. if (i & 1)
  277. b = b2[1+i/2] >> 8;
  278. else
  279. b = b2[1+i/2] & 0xFF;
  280. *p++ = b;
  281. }
  282. freebn(b1);
  283. freebn(b2);
  284. }
  285. int rsastr_len(struct RSAKey *key) {
  286. Bignum md, ex;
  287. md = key->modulus;
  288. ex = key->exponent;
  289. return 4 * (ex[0]+md[0]) + 10;
  290. }
  291. void rsastr_fmt(char *str, struct RSAKey *key) {
  292. Bignum md, ex;
  293. int len = 0, i;
  294. md = key->modulus;
  295. ex = key->exponent;
  296. for (i=1; i<=ex[0]; i++) {
  297. sprintf(str+len, "%04x", ex[i]);
  298. len += strlen(str+len);
  299. }
  300. str[len++] = '/';
  301. for (i=1; i<=md[0]; i++) {
  302. sprintf(str+len, "%04x", md[i]);
  303. len += strlen(str+len);
  304. }
  305. str[len] = '\0';
  306. }
  #ifdef TESTMODE
  #ifndef NODDY
  #define p1 10007
  #define p2 10069
  #define p3 10177
  #else
  #define p1 3
  #define p2 7
  #define p3 13
  #endif
  /*
   * Self-test. With N = p1*p2*p3 for distinct primes and
   * phi = (p1-1)*(p2-1)*(p3-1), Euler's theorem gives 2^phi == 1 (mod N).
   * N and phi are built with plain word arithmetic, and modpow() is
   * checked against the expected result of 1.
   */

  /* Return a new minimal-length Bignum holding f1*f2*f3 (each < 65536). */
  static Bignum bn_product3(unsigned long f1, unsigned long f2,
                            unsigned long f3) {
      unsigned short w[4] = { 1, 0, 0, 0 };   /* least significant first */
      unsigned long f[3];
      int i, k, len;
      Bignum r;

      f[0] = f1; f[1] = f2; f[2] = f3;
      for (k = 0; k < 3; k++) {
          unsigned long carry = 0;
          for (i = 0; i < 4; i++) {
              carry += (unsigned long) w[i] * f[k];
              w[i] = (unsigned short) carry;
              carry >>= 16;
          }
      }
      /* modpow() requires the top word of a modulus to be non-zero. */
      for (len = 4; len > 1 && w[len-1] == 0; len--)
          ;
      r = newbn(len);
      for (i = 0; i < len; i++)
          r[i+1] = w[i];
      return r;
  }

  int main(void) {
      Bignum mod = bn_product3(p1, p2, p3);
      Bignum b = bn_product3(p1 - 1, p2 - 1, p3 - 1);
      Bignum Two = bn_product3(2, 1, 1);
      Bignum a = newbn(mod[0]);
      int i, ok;

      debug(mod);
      debug(b);
      modpow(Two, b, mod, a);
      debug(a);

      ok = (a[1] == 1);
      for (i = 2; i <= a[0]; i++)
          if (a[i] != 0)
              ok = 0;
      printf("2^phi mod N %s 1: %s\n", ok ? "==" : "!=", ok ? "PASS" : "FAIL");

      freebn(a);
      freebn(Two);
      freebn(b);
      freebn(mod);
      return ok ? 0 : 1;
  }
  #endif