# numbertheory.py
  1. import sys
  2. import numbers
  3. import itertools
  4. import unittest
  5. assert sys.version_info[:2] >= (3,0), "This is Python 3 code"
  6. def invert(a, b):
  7. "Multiplicative inverse of a mod b. a,b must be coprime."
  8. A = (a, 1, 0)
  9. B = (b, 0, 1)
  10. while B[0]:
  11. q = A[0] // B[0]
  12. A, B = B, tuple(Ai - q*Bi for Ai, Bi in zip(A, B))
  13. assert abs(A[0]) == 1
  14. return A[1]*A[0] % b
  15. def jacobi(n,m):
  16. """Compute the Jacobi symbol.
  17. The special case of this when m is prime is the Legendre symbol,
  18. which is 0 if n is congruent to 0 mod m; 1 if n is congruent to a
  19. non-zero square number mod m; -1 if n is not congruent to any
  20. square mod m.
  21. """
  22. assert m & 1
  23. acc = 1
  24. while True:
  25. n %= m
  26. if n == 0:
  27. return 0
  28. while not (n & 1):
  29. n >>= 1
  30. if (m & 7) not in {1,7}:
  31. acc *= -1
  32. if n == 1:
  33. return acc
  34. if (n & 3) == 3 and (m & 3) == 3:
  35. acc *= -1
  36. n, m = m, n
  37. class CyclicGroupRootFinder(object):
  38. """Class for finding rth roots in a cyclic group. r must be prime."""
  39. # Basic strategy:
  40. #
  41. # We write |G| = r^k u, with u coprime to r. This gives us a
  42. # nested sequence of subgroups G = G_0 > G_1 > ... > G_k, each
  43. # with index r in its predecessor. G_0 is the whole group, and the
  44. # innermost G_k has order u.
  45. #
  46. # Within G_k, you can take an rth root by raising an element to
  47. # the power of (r^{-1} mod u). If k=0 (so G = G_0 = G_k) then
  48. # that's all that's needed: every element has a unique rth root.
  49. # But if k>0, then things go differently.
  50. #
  51. # Define the 'rank' of an element g as the highest i such that
  52. # g \in G_i. Elements of rank 0 are the non-rth-powers: they don't
  53. # even _have_ an rth root. Elements of rank k are the easy ones to
  54. # take rth roots of, as above.
  55. #
  56. # In between, you can follow an inductive process, as long as you
  57. # know one element z of rank 0. Suppose we're trying to take the
  58. # rth root of some g with rank i. Repeatedly multiply g by z^{r^i}
  59. # until its rank increases; then take the root of that
  60. # (recursively), and divide off z^{r^{i-1}} once you're done.
  61. def __init__(self, r, order):
  62. self.order = order # order of G
  63. self.r = r
  64. self.k = next(k for k in itertools.count()
  65. if self.order % (r**(k+1)) != 0)
  66. self.u = self.order // (r**self.k)
  67. self.z = next(z for z in self.iter_elements()
  68. if self.index(z) == 0)
  69. self.zinv = self.inverse(self.z)
  70. self.root_power = invert(self.r, self.u) if self.u > 1 else 0
  71. self.roots_of_unity = {self.identity()}
  72. if self.k > 0:
  73. exponent = self.order // self.r
  74. for z in self.iter_elements():
  75. root_of_unity = self.pow(z, exponent)
  76. if root_of_unity not in self.roots_of_unity:
  77. self.roots_of_unity.add(root_of_unity)
  78. if len(self.roots_of_unity) == r:
  79. break
  80. def index(self, g):
  81. h = self.pow(g, self.u)
  82. for i in range(self.k+1):
  83. if h == self.identity():
  84. return self.k - i
  85. h = self.pow(h, self.r)
  86. assert False, ("Not a cyclic group! Raising {} to u r^k should give e."
  87. .format(g))
  88. def all_roots(self, g):
  89. try:
  90. r = self.root(g)
  91. except ValueError:
  92. return []
  93. return {r * rou for rou in self.roots_of_unity}
  94. def root(self, g):
  95. i = self.index(g)
  96. if i == 0 and self.k > 0:
  97. raise ValueError("{} has no {}th root".format(g, self.r))
  98. out = self.root_recurse(g, i)
  99. assert self.pow(out, self.r) == g
  100. return out
  101. def root_recurse(self, g, i):
  102. if i == self.k:
  103. return self.pow(g, self.root_power)
  104. z_in = self.pow(self.z, self.r**i)
  105. z_out = self.pow(self.zinv, self.r**(i-1))
  106. adjust = self.identity()
  107. while True:
  108. g = self.mul(g, z_in)
  109. adjust = self.mul(adjust, z_out)
  110. i2 = self.index(g)
  111. if i2 > i:
  112. return self.mul(self.root_recurse(g, i2), adjust)
  113. class AdditiveGroupRootFinder(CyclicGroupRootFinder):
  114. """Trivial test subclass for CyclicGroupRootFinder.
  115. Represents a cyclic group of any order additively, as the integers
  116. mod n under addition. This makes root-finding trivial without
  117. having to use the complicated algorithm above, and therefore it's
  118. a good way to test the complicated algorithm under conditions
  119. where the right answers are obvious."""
  120. def __init__(self, r, order):
  121. super().__init__(r, order)
  122. def mul(self, x, y):
  123. return (x + y) % self.order
  124. def pow(self, x, n):
  125. return (x * n) % self.order
  126. def inverse(self, x):
  127. return (-x) % self.order
  128. def identity(self):
  129. return 0
  130. def iter_elements(self):
  131. return range(self.order)
  132. class TestCyclicGroupRootFinder(unittest.TestCase):
  133. def testRootFinding(self):
  134. for order in 10, 11, 12, 18:
  135. grf = AdditiveGroupRootFinder(3, order)
  136. for i in range(order):
  137. try:
  138. r = grf.root(i)
  139. except ValueError:
  140. r = None
  141. if order % 3 == 0 and i % 3 != 0:
  142. self.assertEqual(r, None)
  143. else:
  144. self.assertEqual(r*3 % order, i)
  145. class RootModP(CyclicGroupRootFinder):
  146. """The live class that can take rth roots mod a prime."""
  147. def __init__(self, r, p):
  148. self.modulus = p
  149. super().__init__(r, p-1)
  150. def mul(self, x, y):
  151. return (x * y) % self.modulus
  152. def pow(self, x, n):
  153. return pow(x, n, self.modulus)
  154. def inverse(self, x):
  155. return invert(x, self.modulus)
  156. def identity(self):
  157. return 1
  158. def iter_elements(self):
  159. return range(1, self.modulus)
  160. def root(self, g):
  161. return 0 if g == 0 else super().root(g)
  162. class ModP(object):
  163. """Class that represents integers mod p as a field.
  164. All the usual arithmetic operations are supported directly,
  165. including division, so you can write formulas in a natural way
  166. without having to keep saying '% p' everywhere or call a
  167. cumbersome modular_inverse() function.
  168. """
  169. def __init__(self, p, n=0):
  170. self.p = p
  171. if isinstance(n, type(self)):
  172. self.check(n)
  173. n = n.n
  174. self.n = n % p
  175. def check(self, other):
  176. assert isinstance(other, type(self))
  177. assert isinstance(self, type(other))
  178. assert self.p == other.p
  179. def coerce_to(self, other):
  180. if not isinstance(other, type(self)):
  181. other = type(self)(self.p, other)
  182. else:
  183. self.check(other)
  184. return other
  185. def __int__(self):
  186. return self.n
  187. def __add__(self, rhs):
  188. rhs = self.coerce_to(rhs)
  189. return type(self)(self.p, (self.n + rhs.n) % self.p)
  190. def __neg__(self):
  191. return type(self)(self.p, -self.n % self.p)
  192. def __radd__(self, rhs):
  193. rhs = self.coerce_to(rhs)
  194. return type(self)(self.p, (self.n + rhs.n) % self.p)
  195. def __sub__(self, rhs):
  196. rhs = self.coerce_to(rhs)
  197. return type(self)(self.p, (self.n - rhs.n) % self.p)
  198. def __rsub__(self, rhs):
  199. rhs = self.coerce_to(rhs)
  200. return type(self)(self.p, (rhs.n - self.n) % self.p)
  201. def __mul__(self, rhs):
  202. rhs = self.coerce_to(rhs)
  203. return type(self)(self.p, (self.n * rhs.n) % self.p)
  204. def __rmul__(self, rhs):
  205. rhs = self.coerce_to(rhs)
  206. return type(self)(self.p, (self.n * rhs.n) % self.p)
  207. def __div__(self, rhs):
  208. rhs = self.coerce_to(rhs)
  209. return type(self)(self.p, (self.n * invert(rhs.n, self.p)) % self.p)
  210. def __rdiv__(self, rhs):
  211. rhs = self.coerce_to(rhs)
  212. return type(self)(self.p, (rhs.n * invert(self.n, self.p)) % self.p)
  213. def __truediv__(self, rhs): return self.__div__(rhs)
  214. def __rtruediv__(self, rhs): return self.__rdiv__(rhs)
  215. def __pow__(self, exponent):
  216. assert exponent >= 0
  217. n, b_to_n = 1, self
  218. total = type(self)(self.p, 1)
  219. while True:
  220. if exponent & n:
  221. exponent -= n
  222. total *= b_to_n
  223. n *= 2
  224. if n > exponent:
  225. break
  226. b_to_n *= b_to_n
  227. return total
  228. def __cmp__(self, rhs):
  229. rhs = self.coerce_to(rhs)
  230. return cmp(self.n, rhs.n)
  231. def __eq__(self, rhs):
  232. rhs = self.coerce_to(rhs)
  233. return self.n == rhs.n
  234. def __ne__(self, rhs):
  235. rhs = self.coerce_to(rhs)
  236. return self.n != rhs.n
  237. def __lt__(self, rhs):
  238. raise ValueError("Elements of a modular ring have no ordering")
  239. def __le__(self, rhs):
  240. raise ValueError("Elements of a modular ring have no ordering")
  241. def __gt__(self, rhs):
  242. raise ValueError("Elements of a modular ring have no ordering")
  243. def __ge__(self, rhs):
  244. raise ValueError("Elements of a modular ring have no ordering")
  245. def __str__(self):
  246. return "0x{:x}".format(self.n)
  247. def __repr__(self):
  248. return "{}(0x{:x},0x{:x})".format(type(self).__name__, self.p, self.n)
  249. def __hash__(self):
  250. return hash((type(self).__name__, self.p, self.n))
  251. class QuadraticFieldExtensionModP(object):
  252. """Class representing Z_p[sqrt(d)] for a given non-square d.
  253. """
  254. def __init__(self, p, d, n=0, m=0):
  255. self.p = p
  256. self.d = d
  257. if isinstance(n, ModP):
  258. assert self.p == n.p
  259. n = n.n
  260. if isinstance(m, ModP):
  261. assert self.p == m.p
  262. m = m.n
  263. if isinstance(n, type(self)):
  264. self.check(n)
  265. m += n.m
  266. n = n.n
  267. self.n = n % p
  268. self.m = m % p
  269. @classmethod
  270. def constructor(cls, p, d):
  271. return lambda *args: cls(p, d, *args)
  272. def check(self, other):
  273. assert isinstance(other, type(self))
  274. assert isinstance(self, type(other))
  275. assert self.p == other.p
  276. assert self.d == other.d
  277. def coerce_to(self, other):
  278. if not isinstance(other, type(self)):
  279. other = type(self)(self.p, self.d, other)
  280. else:
  281. self.check(other)
  282. return other
  283. def __int__(self):
  284. if self.m != 0:
  285. raise ValueError("Can't coerce a non-element of Z_{} to integer"
  286. .format(self.p))
  287. return int(self.n)
  288. def __add__(self, rhs):
  289. rhs = self.coerce_to(rhs)
  290. return type(self)(self.p, self.d,
  291. (self.n + rhs.n) % self.p,
  292. (self.m + rhs.m) % self.p)
  293. def __neg__(self):
  294. return type(self)(self.p, self.d,
  295. -self.n % self.p,
  296. -self.m % self.p)
  297. def __radd__(self, rhs):
  298. rhs = self.coerce_to(rhs)
  299. return type(self)(self.p, self.d,
  300. (self.n + rhs.n) % self.p,
  301. (self.m + rhs.m) % self.p)
  302. def __sub__(self, rhs):
  303. rhs = self.coerce_to(rhs)
  304. return type(self)(self.p, self.d,
  305. (self.n - rhs.n) % self.p,
  306. (self.m - rhs.m) % self.p)
  307. def __rsub__(self, rhs):
  308. rhs = self.coerce_to(rhs)
  309. return type(self)(self.p, self.d,
  310. (rhs.n - self.n) % self.p,
  311. (rhs.m - self.m) % self.p)
  312. def __mul__(self, rhs):
  313. rhs = self.coerce_to(rhs)
  314. n, m, N, M = self.n, self.m, rhs.n, rhs.m
  315. return type(self)(self.p, self.d,
  316. (n*N + self.d*m*M) % self.p,
  317. (n*M + m*N) % self.p)
  318. def __rmul__(self, rhs):
  319. return self.__mul__(rhs)
  320. def __div__(self, rhs):
  321. rhs = self.coerce_to(rhs)
  322. n, m, N, M = self.n, self.m, rhs.n, rhs.m
  323. # (n+m sqrt d)/(N+M sqrt d) = (n+m sqrt d)(N-M sqrt d)/(N^2-dM^2)
  324. denom = (N*N - self.d*M*M) % self.p
  325. if denom == 0:
  326. raise ValueError("division by zero")
  327. recipdenom = invert(denom, self.p)
  328. return type(self)(self.p, self.d,
  329. (n*N - self.d*m*M) * recipdenom % self.p,
  330. (m*N - n*M) * recipdenom % self.p)
  331. def __rdiv__(self, rhs):
  332. rhs = self.coerce_to(rhs)
  333. return rhs.__div__(self)
  334. def __truediv__(self, rhs): return self.__div__(rhs)
  335. def __rtruediv__(self, rhs): return self.__rdiv__(rhs)
  336. def __pow__(self, exponent):
  337. assert exponent >= 0
  338. n, b_to_n = 1, self
  339. total = type(self)(self.p, self.d, 1)
  340. while True:
  341. if exponent & n:
  342. exponent -= n
  343. total *= b_to_n
  344. n *= 2
  345. if n > exponent:
  346. break
  347. b_to_n *= b_to_n
  348. return total
  349. def __cmp__(self, rhs):
  350. rhs = self.coerce_to(rhs)
  351. return cmp((self.n, self.m), (rhs.n, rhs.m))
  352. def __eq__(self, rhs):
  353. rhs = self.coerce_to(rhs)
  354. return self.n == rhs.n and self.m == rhs.m
  355. def __ne__(self, rhs):
  356. rhs = self.coerce_to(rhs)
  357. return self.n != rhs.n or self.m != rhs.m
  358. def __lt__(self, rhs):
  359. raise ValueError("Elements of a modular ring have no ordering")
  360. def __le__(self, rhs):
  361. raise ValueError("Elements of a modular ring have no ordering")
  362. def __gt__(self, rhs):
  363. raise ValueError("Elements of a modular ring have no ordering")
  364. def __ge__(self, rhs):
  365. raise ValueError("Elements of a modular ring have no ordering")
  366. def __str__(self):
  367. if self.m == 0:
  368. return "0x{:x}".format(self.n)
  369. else:
  370. return "0x{:x}+0x{:x}*sqrt({:d})".format(self.n, self.m, self.d)
  371. def __repr__(self):
  372. return "{}(0x{:x},0x{:x},0x{:x},0x{:x})".format(
  373. type(self).__name__, self.p, self.d, self.n, self.m)
  374. def __hash__(self):
  375. return hash((type(self).__name__, self.p, self.d, self.n, self.m))
  376. class RootInQuadraticExtension(CyclicGroupRootFinder):
  377. """Take rth roots in the quadratic extension of Z_p."""
  378. def __init__(self, r, p, d):
  379. self.modulus = p
  380. self.constructor = QuadraticFieldExtensionModP.constructor(p, d)
  381. super().__init__(r, p*p-1)
  382. def mul(self, x, y):
  383. return x * y
  384. def pow(self, x, n):
  385. return x ** n
  386. def inverse(self, x):
  387. return 1/x
  388. def identity(self):
  389. return self.constructor(1, 0)
  390. def iter_elements(self):
  391. p = self.modulus
  392. for n_plus_m in range(1, 2*p-1):
  393. n_min = max(0, n_plus_m-(p-1))
  394. n_max = min(p-1, n_plus_m)
  395. for n in range(n_min, n_max + 1):
  396. m = n_plus_m - n
  397. assert(0 <= n < p)
  398. assert(0 <= m < p)
  399. assert(n != 0 or m != 0)
  400. yield self.constructor(n, m)
  401. def root(self, g):
  402. return 0 if g == 0 else super().root(g)
  403. class EquationSolverModP(object):
  404. """Class that can solve quadratics, cubics and quartics over Z_p.
  405. p must be a nontrivial prime (bigger than 3).
  406. """
  407. # This is a port to Z_p of reasonably standard algorithms for
  408. # solving quadratics, cubics and quartics over the reals.
  409. #
  410. # When you solve a cubic in R, you sometimes have to deal with
  411. # intermediate results that are complex numbers. In particular,
  412. # you have to solve a quadratic whose coefficients are in R but
  413. # its roots may be complex, and then having solved that quadratic,
  414. # you need to iterate over all three cube roots of the solution in
  415. # order to recover all the roots of your cubic. (Even if the cubic
  416. # ends up having three real roots, you can't calculate them
  417. # without going through those complex intermediate values.)
  418. #
  419. # So over Z_p, the same thing applies: we're going to need to be
  420. # able to solve any quadratic with coefficients in Z_p, even if
  421. # its discriminant turns out not to be a quadratic residue mod p,
  422. # and then we'll need to find _three_ cube roots of the result,
  423. # even if p == 2 (mod 3) so that numbers only have one cube root
  424. # each.
  425. #
  426. # Both of these problems can be solved at once if we work in the
  427. # finite field GF(p^2), i.e. make a quadratic field extension of
  428. # Z_p by adjoining a square root of some non-square d. The
  429. # multiplicative group of GF(p^2) is cyclic and has order p^2-1 =
  430. # (p-1)(p+1), with the mult group of Z_p forming the unique
  431. # subgroup of order (p-1) within it. So we've multiplied the group
  432. # order by p+1, which is even (since by assumption p > 3), and
  433. # therefore a square root is now guaranteed to exist for every
  434. # number in the Z_p subgroup. Moreover, no matter whether p itself
  435. # was congruent to 1 or 2 mod 3, p^2 is always congruent to 1,
  436. # which means that the mult group of GF(p^2) has order divisible
  437. # by 3. So there are guaranteed to be three distinct cube roots of
  438. # unity, and hence, three cube roots of any number that's a cube
  439. # at all.
  440. #
  441. # Quartics don't introduce any additional problems. To solve a
  442. # quartic, you factorise it into two quadratic factors, by solving
  443. # a cubic to find one of the coefficients. So if you can already
  444. # solve cubics, then you're more or less done. The only wrinkle is
  445. # that the two quadratic factors will have coefficients in GF(p^2)
  446. # but not necessarily in Z_p. But that doesn't stop us at least
  447. # _trying_ to solve them by taking square roots in GF(p^2) - and
  448. # if the discriminant of one of those quadratics has is not a
  449. # square even in GF(p^2), then its solutions will only exist if
  450. # you escalate further to GF(p^4), in which case the answer is
  451. # simply that there aren't any solutions in Z_p to that quadratic.
  452. def __init__(self, p):
  453. self.p = p
  454. self.nonsquare_mod_p = d = RootModP(2, p).z
  455. self.constructor = QuadraticFieldExtensionModP.constructor(p, d)
  456. self.sqrt = RootInQuadraticExtension(2, p, d)
  457. self.cbrt = RootInQuadraticExtension(3, p, d)
  458. def solve_quadratic(self, a, b, c):
  459. "Solve ax^2 + bx + c = 0."
  460. a, b, c = map(self.constructor, (a, b, c))
  461. assert a != 0
  462. return self.solve_monic_quadratic(b/a, c/a)
  463. def solve_monic_quadratic(self, b, c):
  464. "Solve x^2 + bx + c = 0."
  465. b, c = map(self.constructor, (b, c))
  466. s = b/2
  467. return [y - s for y in self.solve_depressed_quadratic(c - s*s)]
  468. def solve_depressed_quadratic(self, c):
  469. "Solve x^2 + c = 0."
  470. return self.sqrt.all_roots(-c)
  471. def solve_cubic(self, a, b, c, d):
  472. "Solve ax^3 + bx^2 + cx + d = 0."
  473. a, b, c, d = map(self.constructor, (a, b, c, d))
  474. assert a != 0
  475. return self.solve_monic_cubic(b/a, c/a, d/a)
  476. def solve_monic_cubic(self, b, c, d):
  477. "Solve x^3 + bx^2 + cx + d = 0."
  478. b, c, d = map(self.constructor, (b, c, d))
  479. s = b/3
  480. return [y - s for y in self.solve_depressed_cubic(
  481. c - 3*s*s, 2*s*s*s - c*s + d)]
  482. def solve_depressed_cubic(self, c, d):
  483. "Solve x^3 + cx + d = 0."
  484. c, d = map(self.constructor, (c, d))
  485. solutions = set()
  486. # To solve x^3 + cx + d = 0, set p = -c/3, then
  487. # substitute x = z + p/z to get z^6 + d z^3 + p^3 = 0.
  488. # Solve that quadratic for z^3, then take cube roots.
  489. p = -c/3
  490. for z3 in self.solve_monic_quadratic(d, p**3):
  491. # As I understand the theory, we _should_ only need to
  492. # take cube roots of one root of that quadratic: the other
  493. # one should give the same set of answers after you map
  494. # each one through z |-> z+p/z. But speed isn't at a
  495. # premium here, so I'll do this the way that must work.
  496. for z in self.cbrt.all_roots(z3):
  497. solutions.add(z + p/z)
  498. return solutions
  499. def solve_quartic(self, a, b, c, d, e):
  500. "Solve ax^4 + bx^3 + cx^2 + dx + e = 0."
  501. a, b, c, d, e = map(self.constructor, (a, b, c, d, e))
  502. assert a != 0
  503. return self.solve_monic_quartic(b/a, c/a, d/a, e/a)
  504. def solve_monic_quartic(self, b, c, d, e):
  505. "Solve x^4 + bx^3 + cx^2 + dx + e = 0."
  506. b, c, d, e = map(self.constructor, (b, c, d, e))
  507. s = b/4
  508. return [y - s for y in self.solve_depressed_quartic(
  509. c - 6*s*s, d - 2*c*s + 8*s*s*s, e - d*s + c*s*s - 3*s*s*s*s)]
  510. def solve_depressed_quartic(self, c, d, e):
  511. "Solve x^4 + cx^2 + dx + e = 0."
  512. c, d, e = map(self.constructor, (c, d, e))
  513. solutions = set()
  514. # To solve an equation of this form, we search for a value y
  515. # such that subtracting the original polynomial from (x^2+y)^2
  516. # yields a quadratic of the special form (ux+v)^2.
  517. #
  518. # Then our equation is rewritten as (x^2+y)^2 - (ux+v)^2 = 0
  519. # i.e. ((x^2+y) + (ux+v)) ((x^2+y) - (ux+v)) = 0
  520. # i.e. the product of two quadratics, each of which we then solve.
  521. #
  522. # To find y, we write down the discriminant of the quadratic
  523. # (x^2+y)^2 - (x^4 + cx^2 + dx + e) and set it to 0, which
  524. # gives a cubic in y. Maxima gives the coefficients as
  525. # (-8)y^3 + (4c)y^2 + (8e)y + (d^2-4ce).
  526. #
  527. # As above, we _should_ only need one value of y. But I go
  528. # through them all just in case, because I don't care about
  529. # speed, and because checking the assertions inside this loop
  530. # for every value is extra reassurance that I've done all of
  531. # this right.
  532. for y in self.solve_cubic(-8, 4*c, 8*e, d*d-4*c*e):
  533. # Subtract the original equation from (x^2+y)^2 to get the
  534. # coefficients of our quadratic residual.
  535. A, B, C = 2*y-c, -d, y*y-e
  536. # Expect that to have zero discriminant, i.e. a repeated root.
  537. assert B*B - 4*A*C == 0
  538. # If (Ax^2+Bx+C) == (ux+v)^2 then we have u^2=A, 2uv=B, v^2=C.
  539. # So we can either recover u as sqrt(A) or v as sqrt(C), and
  540. # whichever we did, find the other from B by division. But
  541. # either of the end coefficients might be zero, so we have
  542. # to be prepared to try either option.
  543. try:
  544. if A != 0:
  545. u = self.sqrt.root(A)
  546. v = B/(2*u)
  547. elif C != 0:
  548. v = self.sqrt.root(C)
  549. u = B/(2*v)
  550. else:
  551. # One last possibility is that all three coefficients
  552. # of our residual quadratic are 0, in which case,
  553. # obviously, u=v=0 as well.
  554. u = v = 0
  555. except ValueError:
  556. # If Ax^2+Bx+C looked like a perfect square going by
  557. # its discriminant, but actually taking the square
  558. # root of A or C threw an exception, that means that
  559. # it's the square of a polynomial whose coefficients
  560. # live in a yet-higher field extension of Z_p. In that
  561. # case we're not going to end up with roots of the
  562. # original quartic in Z_p if we start from here!
  563. continue
  564. # So now our quartic is factorised into the form
  565. # (x^2 - ux - v + y) (x^2 + ux + v + y).
  566. for x in self.solve_monic_quadratic(-u, y-v):
  567. solutions.add(x)
  568. for x in self.solve_monic_quadratic(u, y+v):
  569. solutions.add(x)
  570. return solutions
  571. class EquationSolverTest(unittest.TestCase):
  572. def testQuadratic(self):
  573. E = EquationSolverModP(11)
  574. solns = E.solve_quadratic(3, 2, 6)
  575. self.assertEqual(sorted(map(str, solns)), ["0x1", "0x2"])
  576. def testCubic(self):
  577. E = EquationSolverModP(11)
  578. solns = E.solve_cubic(7, 2, 0, 2)
  579. self.assertEqual(sorted(map(str, solns)), ["0x1", "0x2", "0x3"])
  580. def testQuartic(self):
  581. E = EquationSolverModP(11)
  582. solns = E.solve_quartic(9, 9, 7, 1, 7)
  583. self.assertEqual(sorted(map(str, solns)), ["0x1", "0x2", "0x3", "0x4"])
  584. if __name__ == "__main__":
  585. import sys
  586. if sys.argv[1:] == ["--test"]:
  587. sys.argv[1:2] = []
  588. unittest.main()