primecandidate.c 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. /*
  2. * primecandidate.c: implementation of the PrimeCandidateSource
  3. * abstraction declared in sshkeygen.h.
  4. */
#include <assert.h>
#include <limits.h>
#include "ssh.h"
#include "mpint.h"
#include "mpunsafe.h"
#include "sshkeygen.h"
  10. struct avoid {
  11. unsigned mod, res;
  12. };
  13. struct PrimeCandidateSource {
  14. unsigned bits;
  15. bool ready, try_sophie_germain;
  16. bool one_shot, thrown_away_my_shot;
  17. /* We'll start by making up a random number strictly less than this ... */
  18. mp_int *limit;
  19. /* ... then we'll multiply by 'factor', and add 'addend'. */
  20. mp_int *factor, *addend;
  21. /* Then we'll try to add a small multiple of 'factor' to it to
  22. * avoid it being a multiple of any small prime. Also, for RSA, we
  23. * may need to avoid it being _this_ multiple of _this_: */
  24. unsigned avoid_residue, avoid_modulus;
  25. /* Once we're actually running, this will be the complete list of
  26. * (modulus, residue) pairs we want to avoid. */
  27. struct avoid *avoids;
  28. size_t navoids, avoidsize;
  29. /* List of known primes that our number will be congruent to 1 modulo */
  30. mp_int **kps;
  31. size_t nkps, kpsize;
  32. };
  33. PrimeCandidateSource *pcs_new_with_firstbits(unsigned bits,
  34. unsigned first, unsigned nfirst)
  35. {
  36. PrimeCandidateSource *s = snew(PrimeCandidateSource);
  37. assert(first >> (nfirst-1) == 1);
  38. s->bits = bits;
  39. s->ready = false;
  40. s->try_sophie_germain = false;
  41. s->one_shot = false;
  42. s->thrown_away_my_shot = false;
  43. s->kps = NULL;
  44. s->nkps = s->kpsize = 0;
  45. s->avoids = NULL;
  46. s->navoids = s->avoidsize = 0;
  47. /* Make the number that's the lower limit of our range */
  48. mp_int *firstmp = mp_from_integer(first);
  49. mp_int *base = mp_lshift_fixed(firstmp, bits - nfirst);
  50. mp_free(firstmp);
  51. /* Set the low bit of that, because all (nontrivial) primes are odd */
  52. mp_set_bit(base, 0, 1);
  53. /* That's our addend. Now initialise factor to 2, to ensure we
  54. * only generate odd numbers */
  55. s->factor = mp_from_integer(2);
  56. s->addend = base;
  57. /* And that means the limit of our random numbers must be one
  58. * factor of two _less_ than the position of the low bit of
  59. * 'first', because we'll be multiplying the random number by
  60. * 2 immediately afterwards. */
  61. s->limit = mp_power_2(bits - nfirst - 1);
  62. /* avoid_modulus == 0 signals that there's no extra residue to avoid */
  63. s->avoid_residue = 1;
  64. s->avoid_modulus = 0;
  65. return s;
  66. }
  67. PrimeCandidateSource *pcs_new(unsigned bits)
  68. {
  69. return pcs_new_with_firstbits(bits, 1, 1);
  70. }
  71. void pcs_free(PrimeCandidateSource *s)
  72. {
  73. mp_free(s->limit);
  74. mp_free(s->factor);
  75. mp_free(s->addend);
  76. for (size_t i = 0; i < s->nkps; i++)
  77. mp_free(s->kps[i]);
  78. sfree(s->avoids);
  79. sfree(s->kps);
  80. sfree(s);
  81. }
  82. void pcs_try_sophie_germain(PrimeCandidateSource *s)
  83. {
  84. s->try_sophie_germain = true;
  85. }
  86. void pcs_set_oneshot(PrimeCandidateSource *s)
  87. {
  88. s->one_shot = true;
  89. }
/*
 * Add the constraint "output == res (mod mod)" to the source by
 * folding it into the existing (factor, addend, limit) triple.
 *
 * pcs_require_residue() calls this with res already reduced into
 * [0, mod). Asserts that the new constraint is compatible with the
 * existing one, i.e. addend == res modulo gcd(mod, factor). On return,
 * s->factor, s->addend and s->limit have been replaced by newly
 * computed values (the old ones are freed).
 */
static void pcs_require_residue_inner(PrimeCandidateSource *s,
                                      mp_int *mod, mp_int *res)
{
    /*
     * We already have a factor and addend. Ensure this one doesn't
     * contradict it.
     */
    mp_int *gcd = mp_gcd(mod, s->factor);
    mp_int *test1 = mp_mod(s->addend, gcd);
    mp_int *test2 = mp_mod(res, gcd);
    assert(mp_cmp_eq(test1, test2));
    mp_free(test1);
    mp_free(test2);

    /*
     * Reduce our input factor and addend, which are constraints on
     * the ultimate output number, so that they're constraints on the
     * initial cofactor we're going to make up.
     *
     * If we're generating x and we want to ensure ax+b == r (mod m),
     * how does that work? We've already checked that b == r modulo g
     * = gcd(a,m), i.e. r-b is a multiple of g, and so are a and m. So
     * let's write a=gA, m=gM, (r-b)=gR, and then we can start by
     * dividing that off:
     *
     *     ax  == r-b (mod m )
     *  => gAx == gR  (mod gM)
     *  =>  Ax ==  R  (mod  M)
     *
     * Now the moduli A,M are coprime, which makes things easier.
     *
     * We're going to need to generate the x in this equation by
     * generating a new smaller value y, multiplying it by M, and
     * adding some constant K. So we have x = My + K, and we need to
     * work out what K will satisfy the above equation. In other
     * words, we need A(My+K) == R (mod M), and the AMy term vanishes,
     * so we just need AK == R (mod M). So our congruence is solved by
     * setting K to be R * A^{-1} mod M.
     */
    mp_int *A = mp_div(s->factor, gcd);
    mp_int *M = mp_div(mod, gcd);
    mp_int *Rpre = mp_modsub(res, s->addend, mod);
    mp_int *R = mp_div(Rpre, gcd);
    mp_int *Ainv = mp_invert(A, M);
    mp_int *K = mp_modmul(R, Ainv, M);

    mp_free(gcd);
    mp_free(Rpre);
    mp_free(Ainv);
    mp_free(A);
    mp_free(R);

    /*
     * So we know we have to transform our existing (factor, addend)
     * pair into (factor * M, addend + factor * K), since
     * factor*(My+K) + addend == (factor*M)*y + (addend + factor*K).
     * Now we just need to work out what the limit should be on the
     * random value we're generating.
     *
     * If we need My+K < old_limit, then y < (old_limit-K)/M. But the
     * RHS is a fraction, so in integers, we need y < ceil of it,
     * which we compute as floor((old_limit + M - 1 - K) / M).
     */
    /* K itself (y = 0) must be below the old limit, else no y works. */
    assert(!mp_cmp_hs(K, s->limit));
    mp_int *dividend = mp_add(s->limit, M);
    mp_sub_integer_into(dividend, dividend, 1);
    mp_sub_into(dividend, dividend, K);
    mp_free(s->limit);
    s->limit = mp_div(dividend, M);
    mp_free(dividend);

    /*
     * Now just update the real factor and addend, and we're done.
     */

    mp_int *addend_old = s->addend;
    mp_int *tmp = mp_mul(s->factor, K); /* use the _old_ value of factor */
    s->addend = mp_add(s->addend, tmp);
    mp_free(tmp);
    mp_free(addend_old);

    mp_int *factor_old = s->factor;
    s->factor = mp_mul(s->factor, M);
    mp_free(factor_old);

    mp_free(M);
    mp_free(K);
    /* NOTE(review): presumably trims the mp_int allocations down to
     * their actual value sizes; semantics of mp_unsafe_shrink live in
     * mpunsafe.h, not shown here. */
    s->factor = mp_unsafe_shrink(s->factor);
    s->addend = mp_unsafe_shrink(s->addend);
    s->limit = mp_unsafe_shrink(s->limit);
}
  172. void pcs_require_residue(PrimeCandidateSource *s,
  173. mp_int *mod, mp_int *res_orig)
  174. {
  175. /*
  176. * Reduce the input residue to its least non-negative value, in
  177. * case it was given as a larger equivalent value.
  178. */
  179. mp_int *res_reduced = mp_mod(res_orig, mod);
  180. pcs_require_residue_inner(s, mod, res_reduced);
  181. mp_free(res_reduced);
  182. }
  183. void pcs_require_residue_1(PrimeCandidateSource *s, mp_int *mod)
  184. {
  185. mp_int *res = mp_from_integer(1);
  186. pcs_require_residue(s, mod, res);
  187. mp_free(res);
  188. }
  189. void pcs_require_residue_1_mod_prime(PrimeCandidateSource *s, mp_int *mod)
  190. {
  191. pcs_require_residue_1(s, mod);
  192. sgrowarray(s->kps, s->kpsize, s->nkps);
  193. s->kps[s->nkps++] = mp_copy(mod);
  194. }
  195. void pcs_avoid_residue_small(PrimeCandidateSource *s,
  196. unsigned mod, unsigned res)
  197. {
  198. assert(!s->avoid_modulus); /* can't cope with more than one */
  199. s->avoid_modulus = mod;
  200. s->avoid_residue = res % mod; /* reduce, just in case */
  201. }
  202. static int avoid_cmp(const void *av, const void *bv)
  203. {
  204. const struct avoid *a = (const struct avoid *)av;
  205. const struct avoid *b = (const struct avoid *)bv;
  206. return a->mod < b->mod ? -1 : a->mod > b->mod ? +1 : 0;
  207. }
  208. static uint64_t invert(uint64_t a, uint64_t m)
  209. {
  210. int64_t v0 = a, i0 = 1;
  211. int64_t v1 = m, i1 = 0;
  212. while (v0) {
  213. int64_t tmp, q = v1 / v0;
  214. tmp = v0; v0 = v1 - q*v0; v1 = tmp;
  215. tmp = i0; i0 = i1 - q*i0; i1 = tmp;
  216. }
  217. assert(v1 == 1 || v1 == -1);
  218. return i1 * v1;
  219. }
  220. void pcs_ready(PrimeCandidateSource *s)
  221. {
  222. /*
  223. * List all the small (modulus, residue) pairs we want to avoid.
  224. */
  225. init_smallprimes();
  226. #define ADD_AVOID(newmod, newres) do { \
  227. sgrowarray(s->avoids, s->avoidsize, s->navoids); \
  228. s->avoids[s->navoids].mod = (newmod); \
  229. s->avoids[s->navoids].res = (newres); \
  230. s->navoids++; \
  231. } while (0)
  232. unsigned limit = (mp_hs_integer(s->addend, 65536) ? 65536 :
  233. mp_get_integer(s->addend));
  234. /*
  235. * Don't be divisible by any small prime, or at least, any prime
  236. * smaller than our output number might actually manage to be. (If
  237. * asked to generate a really small prime, it would be
  238. * embarrassing to rule out legitimate answers on the grounds that
  239. * they were divisible by themselves.)
  240. */
  241. for (size_t i = 0; i < NSMALLPRIMES && smallprimes[i] < limit; i++)
  242. ADD_AVOID(smallprimes[i], 0);
  243. if (s->try_sophie_germain) {
  244. /*
  245. * If we're aiming to generate a Sophie Germain prime (i.e. p
  246. * such that 2p+1 is also prime), then we also want to ensure
  247. * 2p+1 is not congruent to 0 mod any small prime, because if
  248. * it is, we'll waste a lot of time generating a p for which
  249. * 2p+1 can't possibly work. So we have to avoid an extra
  250. * residue mod each odd q.
  251. *
  252. * We can simplify: 2p+1 == 0 (mod q)
  253. * => 2p == -1 (mod q)
  254. * => p == -2^{-1} (mod q)
  255. *
  256. * There's no need to do Euclid's algorithm to compute those
  257. * inverses, because for any odd q, the modular inverse of -2
  258. * mod q is just (q-1)/2. (Proof: multiplying it by -2 gives
  259. * 1-q, which is congruent to 1 mod q.)
  260. */
  261. for (size_t i = 0; i < NSMALLPRIMES && smallprimes[i] < limit; i++)
  262. if (smallprimes[i] != 2)
  263. ADD_AVOID(smallprimes[i], (smallprimes[i] - 1) / 2);
  264. }
  265. /*
  266. * Finally, if there's a particular modulus and residue we've been
  267. * told to avoid, put it on the list.
  268. */
  269. if (s->avoid_modulus)
  270. ADD_AVOID(s->avoid_modulus, s->avoid_residue);
  271. #undef ADD_AVOID
  272. /*
  273. * Sort our to-avoid list by modulus. Partly this is so that we'll
  274. * check the smaller moduli first during the live runs, which lets
  275. * us spot most failing cases earlier rather than later. Also, it
  276. * brings equal moduli together, so that we can reuse the residue
  277. * we computed from a previous one.
  278. */
  279. qsort(s->avoids, s->navoids, sizeof(*s->avoids), avoid_cmp);
  280. /*
  281. * Next, adjust each of these moduli to take account of our factor
  282. * and addend. If we want factor*x+addend to avoid being congruent
  283. * to 'res' modulo 'mod', then x itself must avoid being congruent
  284. * to (res - addend) * factor^{-1}.
  285. *
  286. * If factor == 0 modulo mod, then the answer will have a fixed
  287. * residue anyway, so we can discard it from our list to test.
  288. */
  289. int64_t factor_m = 0, addend_m = 0, last_mod = 0;
  290. size_t out = 0;
  291. for (size_t i = 0; i < s->navoids; i++) {
  292. int64_t mod = s->avoids[i].mod, res = s->avoids[i].res;
  293. if (mod != last_mod) {
  294. last_mod = mod;
  295. addend_m = mp_mod_known_integer(s->addend, mod);
  296. factor_m = mp_mod_known_integer(s->factor, mod);
  297. }
  298. if (factor_m == 0) {
  299. assert(res != addend_m);
  300. continue;
  301. }
  302. res = (res - addend_m) * invert(factor_m, mod);
  303. res %= mod;
  304. if (res < 0)
  305. res += mod;
  306. s->avoids[out].mod = mod;
  307. s->avoids[out].res = res;
  308. out++;
  309. }
  310. s->navoids = out;
  311. s->ready = true;
  312. }
  313. mp_int *pcs_generate(PrimeCandidateSource *s)
  314. {
  315. assert(s->ready);
  316. if (s->one_shot) {
  317. if (s->thrown_away_my_shot)
  318. return NULL;
  319. s->thrown_away_my_shot = true;
  320. }
  321. while (true) {
  322. mp_int *x = mp_random_upto(s->limit);
  323. int64_t x_res = 0, last_mod = 0;
  324. bool ok = true;
  325. for (size_t i = 0; i < s->navoids; i++) {
  326. int64_t mod = s->avoids[i].mod, avoid_res = s->avoids[i].res;
  327. if (mod != last_mod) {
  328. last_mod = mod;
  329. x_res = mp_mod_known_integer(x, mod);
  330. }
  331. if (x_res == avoid_res) {
  332. ok = false;
  333. break;
  334. }
  335. }
  336. if (!ok) {
  337. mp_free(x);
  338. if (s->one_shot)
  339. return NULL;
  340. continue; /* try a new x */
  341. }
  342. /*
  343. * We've found a viable x. Make the final output value.
  344. */
  345. mp_int *toret = mp_new(s->bits);
  346. mp_mul_into(toret, x, s->factor);
  347. mp_add_into(toret, toret, s->addend);
  348. mp_free(x);
  349. return toret;
  350. }
  351. }
  352. void pcs_inspect(PrimeCandidateSource *pcs, mp_int **limit_out,
  353. mp_int **factor_out, mp_int **addend_out)
  354. {
  355. *limit_out = mp_copy(pcs->limit);
  356. *factor_out = mp_copy(pcs->factor);
  357. *addend_out = mp_copy(pcs->addend);
  358. }
  359. unsigned pcs_get_bits(PrimeCandidateSource *pcs)
  360. {
  361. return pcs->bits;
  362. }
  363. unsigned pcs_get_bits_remaining(PrimeCandidateSource *pcs)
  364. {
  365. return mp_get_nbits(pcs->limit);
  366. }
  367. mp_int *pcs_get_upper_bound(PrimeCandidateSource *pcs)
  368. {
  369. /* Compute (limit-1) * factor + addend */
  370. mp_int *tmp = mp_mul(pcs->limit, pcs->factor);
  371. mp_int *bound = mp_add(tmp, pcs->addend);
  372. mp_free(tmp);
  373. mp_sub_into(bound, bound, pcs->factor);
  374. return bound;
  375. }
  376. mp_int **pcs_get_known_prime_factors(PrimeCandidateSource *pcs, size_t *nout)
  377. {
  378. *nout = pcs->nkps;
  379. return pcs->kps;
  380. }