make1305.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
#!/usr/bin/env python3
"""Generate word-size-specific C code for multiword bignum arithmetic.

The script writes C source to standard output. The output is a
preprocessor chain '#if BIGNUM_INT_BITS == 16' / '#elif ... == 32' /
'#elif ... == 64' / '#else #error'. Each branch contains the functions
bigval_add(), bigval_mul_mod_p() and bigval_final_reduce(), fully
unrolled for that BignumInt width.

NOTE(review): judging by the file name and by gen_mul() folding the bits
above 2^130 back in with factors of 5 and 25, the modulus p is presumably
the Poly1305 prime 2^130 - 5. Confirm against the C code that includes
this output.
"""
import sys
import string
from collections import namedtuple

# The script is written for Python 3 only; refuse to run under Python 2.
assert sys.version_info[:2] >= (3,0), "This is Python 3 code"
  6. class Multiprecision(object):
  7. def __init__(self, target, minval, maxval, words):
  8. self.target = target
  9. self.minval = minval
  10. self.maxval = maxval
  11. self.words = words
  12. assert 0 <= self.minval
  13. assert self.minval <= self.maxval
  14. assert self.target.nwords(self.maxval) == len(words)
  15. def getword(self, n):
  16. return self.words[n] if n < len(self.words) else "0"
  17. def __add__(self, rhs):
  18. newmin = self.minval + rhs.minval
  19. newmax = self.maxval + rhs.maxval
  20. nwords = self.target.nwords(newmax)
  21. words = []
  22. addfn = self.target.add
  23. for i in range(nwords):
  24. words.append(addfn(self.getword(i), rhs.getword(i)))
  25. addfn = self.target.adc
  26. return Multiprecision(self.target, newmin, newmax, words)
  27. def __mul__(self, rhs):
  28. newmin = self.minval * rhs.minval
  29. newmax = self.maxval * rhs.maxval
  30. nwords = self.target.nwords(newmax)
  31. words = []
  32. # There are basically two strategies we could take for
  33. # multiplying two multiprecision integers. One is to enumerate
  34. # the space of pairs of word indices in lexicographic order,
  35. # essentially computing a*b[i] for each i and adding them
  36. # together; the other is to enumerate in diagonal order,
  37. # computing everything together that belongs at a particular
  38. # output word index.
  39. #
  40. # For the moment, I've gone for the former.
  41. sprev = []
  42. for i, sword in enumerate(self.words):
  43. rprev = None
  44. sthis = sprev[:i]
  45. for j, rword in enumerate(rhs.words):
  46. prevwords = []
  47. if i+j < len(sprev):
  48. prevwords.append(sprev[i+j])
  49. if rprev is not None:
  50. prevwords.append(rprev)
  51. vhi, vlo = self.target.muladd(sword, rword, *prevwords)
  52. sthis.append(vlo)
  53. rprev = vhi
  54. sthis.append(rprev)
  55. sprev = sthis
  56. # Remove unneeded words from the top of the output, if we can
  57. # prove by range analysis that they'll always be zero.
  58. sprev = sprev[:self.target.nwords(newmax)]
  59. return Multiprecision(self.target, newmin, newmax, sprev)
  60. def extract_bits(self, start, bits=None):
  61. if bits is None:
  62. bits = (self.maxval >> start).bit_length()
  63. # Overly thorough range analysis: if min and max have the same
  64. # *quotient* by 2^bits, then the result of reducing anything
  65. # in the range [min,max] mod 2^bits has to fall within the
  66. # obvious range. But if they have different quotients, then
  67. # you can wrap round the modulus and so any value mod 2^bits
  68. # is possible.
  69. newmin = self.minval >> start
  70. newmax = self.maxval >> start
  71. if (newmin >> bits) != (newmax >> bits):
  72. newmin = 0
  73. newmax = (1 << bits) - 1
  74. nwords = self.target.nwords(newmax)
  75. words = []
  76. for i in range(nwords):
  77. srcpos = i * self.target.bits + start
  78. maxbits = min(self.target.bits, start + bits - srcpos)
  79. wordindex = srcpos // self.target.bits
  80. if srcpos % self.target.bits == 0:
  81. word = self.getword(srcpos // self.target.bits)
  82. elif (wordindex+1 >= len(self.words) or
  83. srcpos % self.target.bits + maxbits < self.target.bits):
  84. word = self.target.new_value(
  85. "(%%s) >> %d" % (srcpos % self.target.bits),
  86. self.getword(srcpos // self.target.bits))
  87. else:
  88. word = self.target.new_value(
  89. "((%%s) >> %d) | ((%%s) << %d)" % (
  90. srcpos % self.target.bits,
  91. self.target.bits - (srcpos % self.target.bits)),
  92. self.getword(srcpos // self.target.bits),
  93. self.getword(srcpos // self.target.bits + 1))
  94. if maxbits < self.target.bits and maxbits < bits:
  95. word = self.target.new_value(
  96. "(%%s) & ((((BignumInt)1) << %d)-1)" % maxbits,
  97. word)
  98. words.append(word)
  99. return Multiprecision(self.target, newmin, newmax, words)
  100. # Each Statement has a list of variables it reads, and a list of ones
  101. # it writes. 'forms' is a list of multiple actual C statements it
  102. # could be generated as, depending on which of its output variables is
  103. # actually used (e.g. no point calling BignumADC if the generated
  104. # carry in a particular case is unused, or BignumMUL if nobody needs
  105. # the top half). It is indexed by a bitmap whose bits correspond to
  106. # the entries in wvars, with wvars[0] the MSB and wvars[-1] the LSB.
  107. Statement = namedtuple("Statement", "rvars wvars forms")
  108. class CodegenTarget(object):
  109. def __init__(self, bits):
  110. self.bits = bits
  111. self.valindex = 0
  112. self.stmts = []
  113. self.generators = {}
  114. self.bv_words = (130 + self.bits - 1) // self.bits
  115. self.carry_index = 0
  116. def nwords(self, maxval):
  117. return (maxval.bit_length() + self.bits - 1) // self.bits
  118. def stmt(self, stmt, needed=False):
  119. index = len(self.stmts)
  120. self.stmts.append([needed, stmt])
  121. for val in stmt.wvars:
  122. self.generators[val] = index
  123. def new_value(self, formatstr=None, *deps):
  124. name = "v%d" % self.valindex
  125. self.valindex += 1
  126. if formatstr is not None:
  127. self.stmt(Statement(
  128. rvars=deps, wvars=[name],
  129. forms=[None, name + " = " + formatstr % deps]))
  130. return name
  131. def bigval_input(self, name, bits):
  132. words = (bits + self.bits - 1) // self.bits
  133. # Expect not to require an entire extra word
  134. assert words == self.bv_words
  135. return Multiprecision(self, 0, (1<<bits)-1, [
  136. self.new_value("%s->w[%d]" % (name, i)) for i in range(words)])
  137. def const(self, value):
  138. # We only support constants small enough to both fit in a
  139. # BignumInt (of any size supported) _and_ be expressible in C
  140. # with no weird integer literal syntax like a trailing LL.
  141. #
  142. # Supporting larger constants would be possible - you could
  143. # break 'value' up into word-sized pieces on the Python side,
  144. # and generate a legal C expression for each piece by
  145. # splitting it further into pieces within the
  146. # standards-guaranteed 'unsigned long' limit of 32 bits and
  147. # then casting those to BignumInt before combining them with
  148. # shifts. But it would be a lot of effort, and since the
  149. # application for this code doesn't even need it, there's no
  150. # point in bothering.
  151. assert value < 2**16
  152. return Multiprecision(self, value, value, ["%d" % value])
  153. def current_carry(self):
  154. return "carry%d" % self.carry_index
  155. def add(self, a1, a2):
  156. ret = self.new_value()
  157. adcform = "BignumADC(%s, carry, %s, %s, 0)" % (ret, a1, a2)
  158. plainform = "%s = %s + %s" % (ret, a1, a2)
  159. self.carry_index += 1
  160. carryout = self.current_carry()
  161. self.stmt(Statement(
  162. rvars=[a1,a2], wvars=[ret,carryout],
  163. forms=[None, adcform, plainform, adcform]))
  164. return ret
  165. def adc(self, a1, a2):
  166. ret = self.new_value()
  167. adcform = "BignumADC(%s, carry, %s, %s, carry)" % (ret, a1, a2)
  168. plainform = "%s = %s + %s + carry" % (ret, a1, a2)
  169. carryin = self.current_carry()
  170. self.carry_index += 1
  171. carryout = self.current_carry()
  172. self.stmt(Statement(
  173. rvars=[a1,a2,carryin], wvars=[ret,carryout],
  174. forms=[None, adcform, plainform, adcform]))
  175. return ret
  176. def muladd(self, m1, m2, *addends):
  177. rlo = self.new_value()
  178. rhi = self.new_value()
  179. wideform = "BignumMUL%s(%s)" % (
  180. { 0:"", 1:"ADD", 2:"ADD2" }[len(addends)],
  181. ", ".join([rhi, rlo, m1, m2] + list(addends)))
  182. narrowform = " + ".join(["%s = %s * %s" % (rlo, m1, m2)] +
  183. list(addends))
  184. self.stmt(Statement(
  185. rvars=[m1,m2]+list(addends), wvars=[rhi,rlo],
  186. forms=[None, narrowform, wideform, wideform]))
  187. return rhi, rlo
  188. def write_bigval(self, name, val):
  189. for i in range(self.bv_words):
  190. word = val.getword(i)
  191. self.stmt(Statement(
  192. rvars=[word], wvars=[],
  193. forms=["%s->w[%d] = %s" % (name, i, word)]),
  194. needed=True)
  195. def compute_needed(self):
  196. used_vars = set()
  197. self.queue = [stmt for (needed,stmt) in self.stmts if needed]
  198. while len(self.queue) > 0:
  199. stmt = self.queue.pop(0)
  200. deps = []
  201. for var in stmt.rvars:
  202. if var[0] in string.digits:
  203. continue # constant
  204. deps.append(self.generators[var])
  205. used_vars.add(var)
  206. for index in deps:
  207. if not self.stmts[index][0]:
  208. self.stmts[index][0] = True
  209. self.queue.append(self.stmts[index][1])
  210. forms = []
  211. for i, (needed, stmt) in enumerate(self.stmts):
  212. if needed:
  213. formindex = 0
  214. for (j, var) in enumerate(stmt.wvars):
  215. formindex *= 2
  216. if var in used_vars:
  217. formindex += 1
  218. forms.append(stmt.forms[formindex])
  219. # Now we must check whether this form of the statement
  220. # also writes some variables we _don't_ actually need
  221. # (e.g. if you only wanted the top half from a mul, or
  222. # only the carry from an adc, you'd be forced to
  223. # generate the other output too). Easiest way to do
  224. # this is to look for an identical statement form
  225. # later in the array.
  226. maxindex = max(i for i in range(len(stmt.forms))
  227. if stmt.forms[i] == stmt.forms[formindex])
  228. extra_vars = maxindex & ~formindex
  229. bitpos = 0
  230. while extra_vars != 0:
  231. if extra_vars & (1 << bitpos):
  232. extra_vars &= ~(1 << bitpos)
  233. var = stmt.wvars[-1-bitpos]
  234. used_vars.add(var)
  235. # Also, write out a cast-to-void for each
  236. # subsequently unused value, to prevent gcc
  237. # warnings when the output code is compiled.
  238. forms.append("(void)" + var)
  239. bitpos += 1
  240. used_carry = any(v.startswith("carry") for v in used_vars)
  241. used_vars = [v for v in used_vars if v.startswith("v")]
  242. used_vars.sort(key=lambda v: int(v[1:]))
  243. return used_carry, used_vars, forms
  244. def text(self):
  245. used_carry, values, forms = self.compute_needed()
  246. ret = ""
  247. while len(values) > 0:
  248. prefix, sep, suffix = " BignumInt ", ", ", ";"
  249. currline = values.pop(0)
  250. while (len(values) > 0 and
  251. len(prefix+currline+sep+values[0]+suffix) < 79):
  252. currline += sep + values.pop(0)
  253. ret += prefix + currline + suffix + "\n"
  254. if used_carry:
  255. ret += " BignumCarry carry;\n"
  256. if ret != "":
  257. ret += "\n"
  258. for stmtform in forms:
  259. ret += " %s;\n" % stmtform
  260. return ret
  261. def gen_add(target):
  262. # This is an addition _without_ reduction mod p, so that it can be
  263. # used both during accumulation of the polynomial and for adding
  264. # on the encrypted nonce at the end (which is mod 2^128, not mod
  265. # p).
  266. #
  267. # Because one of the inputs will have come from our
  268. # not-completely-reducing multiplication function, we expect up to
  269. # 3 extra bits of input.
  270. a = target.bigval_input("a", 133)
  271. b = target.bigval_input("b", 133)
  272. ret = a + b
  273. target.write_bigval("r", ret)
  274. return """\
  275. static void bigval_add(bigval *r, const bigval *a, const bigval *b)
  276. {
  277. %s}
  278. \n""" % target.text()
  279. def gen_mul(target):
  280. # The inputs are not 100% reduced mod p. Specifically, we can get
  281. # a full 130-bit number from the pow5==0 pass, and then a 130-bit
  282. # number times 5 from the pow5==1 pass, plus a possible carry. The
  283. # total of that can be easily bounded above by 2^130 * 8, so we
  284. # need to assume we're multiplying two 133-bit numbers.
  285. a = target.bigval_input("a", 133)
  286. b = target.bigval_input("b", 133)
  287. ab = a * b
  288. ab0 = ab.extract_bits(0, 130)
  289. ab1 = ab.extract_bits(130, 130)
  290. ab2 = ab.extract_bits(260)
  291. ab1_5 = target.const(5) * ab1
  292. ab2_25 = target.const(25) * ab2
  293. ret = ab0 + ab1_5 + ab2_25
  294. target.write_bigval("r", ret)
  295. return """\
  296. static void bigval_mul_mod_p(bigval *r, const bigval *a, const bigval *b)
  297. {
  298. %s}
  299. \n""" % target.text()
  300. def gen_final_reduce(target):
  301. # Given our input number n, n >> 130 is usually precisely the
  302. # multiple of p that needs to be subtracted from n to reduce it to
  303. # strictly less than p, but it might be too low by 1 (but not more
  304. # than 1, given the range of our input is nowhere near the square
  305. # of the modulus). So we add another 5, which will push a carry
  306. # into the 130th bit if and only if that has happened, and then
  307. # use that to decide whether to subtract one more copy of p.
  308. a = target.bigval_input("n", 133)
  309. q = a.extract_bits(130)
  310. adjusted = a.extract_bits(0, 130) + target.const(5) * q
  311. final_subtract = (adjusted + target.const(5)).extract_bits(130)
  312. adjusted2 = adjusted + target.const(5) * final_subtract
  313. ret = adjusted2.extract_bits(0, 130)
  314. target.write_bigval("n", ret)
  315. return """\
  316. static void bigval_final_reduce(bigval *n)
  317. {
  318. %s}
  319. \n""" % target.text()
  320. pp_keyword = "#if"
  321. for bits in [16, 32, 64]:
  322. sys.stdout.write("%s BIGNUM_INT_BITS == %d\n\n" % (pp_keyword, bits))
  323. pp_keyword = "#elif"
  324. sys.stdout.write(gen_add(CodegenTarget(bits)))
  325. sys.stdout.write(gen_mul(CodegenTarget(bits)))
  326. sys.stdout.write(gen_final_reduce(CodegenTarget(bits)))
  327. sys.stdout.write("""#else
  328. #error Add another bit count to contrib/make1305.py and rerun it
  329. #endif
  330. """)