make1305.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. #!/usr/bin/env python
  2. import sys
  3. import string
  4. from collections import namedtuple
  5. class Multiprecision(object):
  6. def __init__(self, target, minval, maxval, words):
  7. self.target = target
  8. self.minval = minval
  9. self.maxval = maxval
  10. self.words = words
  11. assert 0 <= self.minval
  12. assert self.minval <= self.maxval
  13. assert self.target.nwords(self.maxval) == len(words)
  14. def getword(self, n):
  15. return self.words[n] if n < len(self.words) else "0"
  16. def __add__(self, rhs):
  17. newmin = self.minval + rhs.minval
  18. newmax = self.maxval + rhs.maxval
  19. nwords = self.target.nwords(newmax)
  20. words = []
  21. addfn = self.target.add
  22. for i in range(nwords):
  23. words.append(addfn(self.getword(i), rhs.getword(i)))
  24. addfn = self.target.adc
  25. return Multiprecision(self.target, newmin, newmax, words)
  26. def __mul__(self, rhs):
  27. newmin = self.minval * rhs.minval
  28. newmax = self.maxval * rhs.maxval
  29. nwords = self.target.nwords(newmax)
  30. words = []
  31. # There are basically two strategies we could take for
  32. # multiplying two multiprecision integers. One is to enumerate
  33. # the space of pairs of word indices in lexicographic order,
  34. # essentially computing a*b[i] for each i and adding them
  35. # together; the other is to enumerate in diagonal order,
  36. # computing everything together that belongs at a particular
  37. # output word index.
  38. #
  39. # For the moment, I've gone for the former.
  40. sprev = []
  41. for i, sword in enumerate(self.words):
  42. rprev = None
  43. sthis = sprev[:i]
  44. for j, rword in enumerate(rhs.words):
  45. prevwords = []
  46. if i+j < len(sprev):
  47. prevwords.append(sprev[i+j])
  48. if rprev is not None:
  49. prevwords.append(rprev)
  50. vhi, vlo = self.target.muladd(sword, rword, *prevwords)
  51. sthis.append(vlo)
  52. rprev = vhi
  53. sthis.append(rprev)
  54. sprev = sthis
  55. # Remove unneeded words from the top of the output, if we can
  56. # prove by range analysis that they'll always be zero.
  57. sprev = sprev[:self.target.nwords(newmax)]
  58. return Multiprecision(self.target, newmin, newmax, sprev)
  59. def extract_bits(self, start, bits=None):
  60. if bits is None:
  61. bits = (self.maxval >> start).bit_length()
  62. # Overly thorough range analysis: if min and max have the same
  63. # *quotient* by 2^bits, then the result of reducing anything
  64. # in the range [min,max] mod 2^bits has to fall within the
  65. # obvious range. But if they have different quotients, then
  66. # you can wrap round the modulus and so any value mod 2^bits
  67. # is possible.
  68. newmin = self.minval >> start
  69. newmax = self.maxval >> start
  70. if (newmin >> bits) != (newmax >> bits):
  71. newmin = 0
  72. newmax = (1 << bits) - 1
  73. nwords = self.target.nwords(newmax)
  74. words = []
  75. for i in range(nwords):
  76. srcpos = i * self.target.bits + start
  77. maxbits = min(self.target.bits, start + bits - srcpos)
  78. wordindex = srcpos / self.target.bits
  79. if srcpos % self.target.bits == 0:
  80. word = self.getword(srcpos / self.target.bits)
  81. elif (wordindex+1 >= len(self.words) or
  82. srcpos % self.target.bits + maxbits < self.target.bits):
  83. word = self.target.new_value(
  84. "(%%s) >> %d" % (srcpos % self.target.bits),
  85. self.getword(srcpos / self.target.bits))
  86. else:
  87. word = self.target.new_value(
  88. "((%%s) >> %d) | ((%%s) << %d)" % (
  89. srcpos % self.target.bits,
  90. self.target.bits - (srcpos % self.target.bits)),
  91. self.getword(srcpos / self.target.bits),
  92. self.getword(srcpos / self.target.bits + 1))
  93. if maxbits < self.target.bits and maxbits < bits:
  94. word = self.target.new_value(
  95. "(%%s) & ((((BignumInt)1) << %d)-1)" % maxbits,
  96. word)
  97. words.append(word)
  98. return Multiprecision(self.target, newmin, newmax, words)
  99. # Each Statement has a list of variables it reads, and a list of ones
  100. # it writes. 'forms' is a list of multiple actual C statements it
  101. # could be generated as, depending on which of its output variables is
  102. # actually used (e.g. no point calling BignumADC if the generated
  103. # carry in a particular case is unused, or BignumMUL if nobody needs
  104. # the top half). It is indexed by a bitmap whose bits correspond to
  105. # the entries in wvars, with wvars[0] the MSB and wvars[-1] the LSB.
  106. Statement = namedtuple("Statement", "rvars wvars forms")
  107. class CodegenTarget(object):
  108. def __init__(self, bits):
  109. self.bits = bits
  110. self.valindex = 0
  111. self.stmts = []
  112. self.generators = {}
  113. self.bv_words = (130 + self.bits - 1) / self.bits
  114. self.carry_index = 0
  115. def nwords(self, maxval):
  116. return (maxval.bit_length() + self.bits - 1) / self.bits
  117. def stmt(self, stmt, needed=False):
  118. index = len(self.stmts)
  119. self.stmts.append([needed, stmt])
  120. for val in stmt.wvars:
  121. self.generators[val] = index
  122. def new_value(self, formatstr=None, *deps):
  123. name = "v%d" % self.valindex
  124. self.valindex += 1
  125. if formatstr is not None:
  126. self.stmt(Statement(
  127. rvars=deps, wvars=[name],
  128. forms=[None, name + " = " + formatstr % deps]))
  129. return name
  130. def bigval_input(self, name, bits):
  131. words = (bits + self.bits - 1) / self.bits
  132. # Expect not to require an entire extra word
  133. assert words == self.bv_words
  134. return Multiprecision(self, 0, (1<<bits)-1, [
  135. self.new_value("%s->w[%d]" % (name, i)) for i in range(words)])
  136. def const(self, value):
  137. # We only support constants small enough to both fit in a
  138. # BignumInt (of any size supported) _and_ be expressible in C
  139. # with no weird integer literal syntax like a trailing LL.
  140. #
  141. # Supporting larger constants would be possible - you could
  142. # break 'value' up into word-sized pieces on the Python side,
  143. # and generate a legal C expression for each piece by
  144. # splitting it further into pieces within the
  145. # standards-guaranteed 'unsigned long' limit of 32 bits and
  146. # then casting those to BignumInt before combining them with
  147. # shifts. But it would be a lot of effort, and since the
  148. # application for this code doesn't even need it, there's no
  149. # point in bothering.
  150. assert value < 2**16
  151. return Multiprecision(self, value, value, ["%d" % value])
  152. def current_carry(self):
  153. return "carry%d" % self.carry_index
  154. def add(self, a1, a2):
  155. ret = self.new_value()
  156. adcform = "BignumADC(%s, carry, %s, %s, 0)" % (ret, a1, a2)
  157. plainform = "%s = %s + %s" % (ret, a1, a2)
  158. self.carry_index += 1
  159. carryout = self.current_carry()
  160. self.stmt(Statement(
  161. rvars=[a1,a2], wvars=[ret,carryout],
  162. forms=[None, adcform, plainform, adcform]))
  163. return ret
  164. def adc(self, a1, a2):
  165. ret = self.new_value()
  166. adcform = "BignumADC(%s, carry, %s, %s, carry)" % (ret, a1, a2)
  167. plainform = "%s = %s + %s + carry" % (ret, a1, a2)
  168. carryin = self.current_carry()
  169. self.carry_index += 1
  170. carryout = self.current_carry()
  171. self.stmt(Statement(
  172. rvars=[a1,a2,carryin], wvars=[ret,carryout],
  173. forms=[None, adcform, plainform, adcform]))
  174. return ret
  175. def muladd(self, m1, m2, *addends):
  176. rlo = self.new_value()
  177. rhi = self.new_value()
  178. wideform = "BignumMUL%s(%s)" % (
  179. { 0:"", 1:"ADD", 2:"ADD2" }[len(addends)],
  180. ", ".join([rhi, rlo, m1, m2] + list(addends)))
  181. narrowform = " + ".join(["%s = %s * %s" % (rlo, m1, m2)] +
  182. list(addends))
  183. self.stmt(Statement(
  184. rvars=[m1,m2]+list(addends), wvars=[rhi,rlo],
  185. forms=[None, narrowform, wideform, wideform]))
  186. return rhi, rlo
  187. def write_bigval(self, name, val):
  188. for i in range(self.bv_words):
  189. word = val.getword(i)
  190. self.stmt(Statement(
  191. rvars=[word], wvars=[],
  192. forms=["%s->w[%d] = %s" % (name, i, word)]),
  193. needed=True)
  194. def compute_needed(self):
  195. used_vars = set()
  196. self.queue = [stmt for (needed,stmt) in self.stmts if needed]
  197. while len(self.queue) > 0:
  198. stmt = self.queue.pop(0)
  199. deps = []
  200. for var in stmt.rvars:
  201. if var[0] in string.digits:
  202. continue # constant
  203. deps.append(self.generators[var])
  204. used_vars.add(var)
  205. for index in deps:
  206. if not self.stmts[index][0]:
  207. self.stmts[index][0] = True
  208. self.queue.append(self.stmts[index][1])
  209. forms = []
  210. for i, (needed, stmt) in enumerate(self.stmts):
  211. if needed:
  212. formindex = 0
  213. for (j, var) in enumerate(stmt.wvars):
  214. formindex *= 2
  215. if var in used_vars:
  216. formindex += 1
  217. forms.append(stmt.forms[formindex])
  218. # Now we must check whether this form of the statement
  219. # also writes some variables we _don't_ actually need
  220. # (e.g. if you only wanted the top half from a mul, or
  221. # only the carry from an adc, you'd be forced to
  222. # generate the other output too). Easiest way to do
  223. # this is to look for an identical statement form
  224. # later in the array.
  225. maxindex = max(i for i in range(len(stmt.forms))
  226. if stmt.forms[i] == stmt.forms[formindex])
  227. extra_vars = maxindex & ~formindex
  228. bitpos = 0
  229. while extra_vars != 0:
  230. if extra_vars & (1 << bitpos):
  231. extra_vars &= ~(1 << bitpos)
  232. var = stmt.wvars[-1-bitpos]
  233. used_vars.add(var)
  234. # Also, write out a cast-to-void for each
  235. # subsequently unused value, to prevent gcc
  236. # warnings when the output code is compiled.
  237. forms.append("(void)" + var)
  238. bitpos += 1
  239. used_carry = any(v.startswith("carry") for v in used_vars)
  240. used_vars = [v for v in used_vars if v.startswith("v")]
  241. used_vars.sort(key=lambda v: int(v[1:]))
  242. return used_carry, used_vars, forms
  243. def text(self):
  244. used_carry, values, forms = self.compute_needed()
  245. ret = ""
  246. while len(values) > 0:
  247. prefix, sep, suffix = " BignumInt ", ", ", ";"
  248. currline = values.pop(0)
  249. while (len(values) > 0 and
  250. len(prefix+currline+sep+values[0]+suffix) < 79):
  251. currline += sep + values.pop(0)
  252. ret += prefix + currline + suffix + "\n"
  253. if used_carry:
  254. ret += " BignumCarry carry;\n"
  255. if ret != "":
  256. ret += "\n"
  257. for stmtform in forms:
  258. ret += " %s;\n" % stmtform
  259. return ret
  260. def gen_add(target):
  261. # This is an addition _without_ reduction mod p, so that it can be
  262. # used both during accumulation of the polynomial and for adding
  263. # on the encrypted nonce at the end (which is mod 2^128, not mod
  264. # p).
  265. #
  266. # Because one of the inputs will have come from our
  267. # not-completely-reducing multiplication function, we expect up to
  268. # 3 extra bits of input.
  269. a = target.bigval_input("a", 133)
  270. b = target.bigval_input("b", 133)
  271. ret = a + b
  272. target.write_bigval("r", ret)
  273. return """\
  274. static void bigval_add(bigval *r, const bigval *a, const bigval *b)
  275. {
  276. %s}
  277. \n""" % target.text()
  278. def gen_mul(target):
  279. # The inputs are not 100% reduced mod p. Specifically, we can get
  280. # a full 130-bit number from the pow5==0 pass, and then a 130-bit
  281. # number times 5 from the pow5==1 pass, plus a possible carry. The
  282. # total of that can be easily bounded above by 2^130 * 8, so we
  283. # need to assume we're multiplying two 133-bit numbers.
  284. a = target.bigval_input("a", 133)
  285. b = target.bigval_input("b", 133)
  286. ab = a * b
  287. ab0 = ab.extract_bits(0, 130)
  288. ab1 = ab.extract_bits(130, 130)
  289. ab2 = ab.extract_bits(260)
  290. ab1_5 = target.const(5) * ab1
  291. ab2_25 = target.const(25) * ab2
  292. ret = ab0 + ab1_5 + ab2_25
  293. target.write_bigval("r", ret)
  294. return """\
  295. static void bigval_mul_mod_p(bigval *r, const bigval *a, const bigval *b)
  296. {
  297. %s}
  298. \n""" % target.text()
  299. def gen_final_reduce(target):
  300. # We take our input number n, and compute k = n + 5*(n >> 130).
  301. # Then k >> 130 is precisely the multiple of p that needs to be
  302. # subtracted from n to reduce it to strictly less than p.
  303. a = target.bigval_input("n", 133)
  304. a1 = a.extract_bits(130, 130)
  305. k = a + target.const(5) * a1
  306. q = k.extract_bits(130)
  307. adjusted = a + target.const(5) * q
  308. ret = adjusted.extract_bits(0, 130)
  309. target.write_bigval("n", ret)
  310. return """\
  311. static void bigval_final_reduce(bigval *n)
  312. {
  313. %s}
  314. \n""" % target.text()
  315. pp_keyword = "#if"
  316. for bits in [16, 32, 64]:
  317. sys.stdout.write("%s BIGNUM_INT_BITS == %d\n\n" % (pp_keyword, bits))
  318. pp_keyword = "#elif"
  319. sys.stdout.write(gen_add(CodegenTarget(bits)))
  320. sys.stdout.write(gen_mul(CodegenTarget(bits)))
  321. sys.stdout.write(gen_final_reduce(CodegenTarget(bits)))
  322. sys.stdout.write("""#else
  323. #error Add another bit count to contrib/make1305.py and rerun it
  324. #endif
  325. """)