123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375 |
- #!/usr/bin/env python3
- import sys
- import string
- from collections import namedtuple
- assert sys.version_info[:2] >= (3,0), "This is Python 3 code"
class Multiprecision(object):
    """A symbolic multi-word unsigned integer, with range analysis.

    An instance does not hold a number.  It holds a list of *word
    expressions* (strings, least significant word first) that will
    compute the number in the generated code, together with the
    inclusive range [minval, maxval] the number is known to lie in.
    The range decides how many words each result needs.

    target -- the code generator that supplies the word size ('bits'),
              the word count for a given maximum ('nwords') and the
              primitive word operations ('add', 'adc', 'muladd',
              'new_value').
    minval -- smallest possible value (must be >= 0).
    maxval -- largest possible value (must be >= minval).
    words  -- word expressions; there must be exactly
              target.nwords(maxval) of them.
    """

    def __init__(self, target, minval, maxval, words):
        self.target = target
        self.minval = minval
        self.maxval = maxval
        self.words = words
        assert 0 <= self.minval
        assert self.minval <= self.maxval
        assert self.target.nwords(self.maxval) == len(words)

    def getword(self, n):
        """Return word expression n, or the constant "0" past the top."""
        return self.words[n] if n < len(self.words) else "0"

    def __add__(self, rhs):
        """Return the sum of two values as a new Multiprecision.

        The lowest word is produced with the target's plain add and
        every higher word with its add-with-carry.
        """
        newmin = self.minval + rhs.minval
        newmax = self.maxval + rhs.maxval
        words = []
        addfn = self.target.add
        for i in range(self.target.nwords(newmax)):
            words.append(addfn(self.getword(i), rhs.getword(i)))
            addfn = self.target.adc
        return Multiprecision(self.target, newmin, newmax, words)

    def __mul__(self, rhs):
        """Return the product of two values as a new Multiprecision."""
        newmin = self.minval * rhs.minval
        newmax = self.maxval * rhs.maxval

        # There are basically two strategies we could take for
        # multiplying two multiprecision integers. One is to enumerate
        # the space of pairs of word indices in lexicographic order,
        # essentially computing a*b[i] for each i and adding them
        # together; the other is to enumerate in diagonal order,
        # computing everything together that belongs at a particular
        # output word index.
        #
        # For the moment, I've gone for the former.

        sprev = []
        for i, sword in enumerate(self.words):
            # rprev is the high half carried up from the previous
            # column of this row; sthis is the running total so far.
            rprev = None
            sthis = sprev[:i]
            for j, rword in enumerate(rhs.words):
                prevwords = []
                if i+j < len(sprev):
                    prevwords.append(sprev[i+j])
                if rprev is not None:
                    prevwords.append(rprev)
                vhi, vlo = self.target.muladd(sword, rword, *prevwords)
                sthis.append(vlo)
                rprev = vhi
            sthis.append(rprev)
            sprev = sthis

        # Remove unneeded words from the top of the output, if we can
        # prove by range analysis that they'll always be zero.
        sprev = sprev[:self.target.nwords(newmax)]

        return Multiprecision(self.target, newmin, newmax, sprev)

    def extract_bits(self, start, bits=None):
        """Return the bit field [start, start+bits) as a new value.

        start -- index of the lowest bit to extract.
        bits  -- width of the field; if omitted, everything from
                 'start' up to the top of the value's known range.
        """
        if bits is None:
            bits = (self.maxval >> start).bit_length()

        # Overly thorough range analysis: if min and max have the same
        # *quotient* by 2^bits, then the result of reducing anything
        # in the range [min,max] mod 2^bits has to fall within the
        # obvious range, namely the two endpoints reduced mod 2^bits.
        # But if they have different quotients, then you can wrap
        # round the modulus and so any value mod 2^bits is possible.
        fieldmask = (1 << bits) - 1
        newmin = self.minval >> start
        newmax = self.maxval >> start
        if (newmin >> bits) != (newmax >> bits):
            newmin = 0
            newmax = fieldmask
        else:
            newmin &= fieldmask
            newmax &= fieldmask

        # True if the source can have bits set above the field, in
        # which case a partially filled output word must be masked.
        source_overhangs = (self.maxval >> (start + bits)) != 0

        wordbits = self.target.bits
        words = []
        for i in range(self.target.nwords(newmax)):
            srcpos = i * wordbits + start
            maxbits = min(wordbits, start + bits - srcpos)
            wordindex, shift = divmod(srcpos, wordbits)
            if shift == 0:
                # Word-aligned: the source word can be used directly.
                word = self.getword(wordindex)
            elif (wordindex+1 >= len(self.words) or
                  shift + maxbits < wordbits):
                # Everything we want lives in a single source word.
                word = self.target.new_value(
                    "(%%s) >> %d" % shift,
                    self.getword(wordindex))
            else:
                # The output word straddles two source words.
                word = self.target.new_value(
                    "((%%s) >> %d) | ((%%s) << %d)" % (
                        shift, wordbits - shift),
                    self.getword(wordindex),
                    self.getword(wordindex + 1))
            if maxbits < wordbits and (maxbits < bits or source_overhangs):
                word = self.target.new_value(
                    "(%%s) & ((((BignumInt)1) << %d)-1)" % maxbits,
                    word)
            words.append(word)

        return Multiprecision(self.target, newmin, newmax, words)
# A Statement records the variables it reads ('rvars') and the ones it
# writes ('wvars'). 'forms' holds the alternative C statements it can
# be emitted as; which one is picked depends on which of the written
# variables are actually used afterwards (e.g. there is no point
# calling BignumADC if the carry it generates is never consumed, or
# BignumMUL if nobody needs the top half). 'forms' is indexed by a
# bitmap with one bit per entry of wvars, where wvars[0] is the most
# significant bit and wvars[-1] the least significant.
Statement = namedtuple("Statement", ["rvars", "wvars", "forms"])
class CodegenTarget(object):
    """Collects the statements of one generated C function body.

    Word operations requested through add/adc/muladd/new_value are
    recorded as Statement tuples rather than emitted immediately.
    text() then works backwards from the statements flagged as needed,
    drops everything that does not feed them, and chooses for each
    surviving statement the form matching the outputs really used.

    Attributes:
    bits        -- size in bits of one word.
    valindex    -- counter used to name the next value "v<N>".
    stmts       -- list of [needed, Statement] pairs in emission order.
    generators  -- maps a variable name to the index in 'stmts' of the
                   statement that writes it.
    bv_words    -- number of words needed to hold 130 bits.
    carry_index -- counter used to name carries "carry<N>".
    """

    def __init__(self, bits):
        self.bits = bits
        self.valindex = 0
        self.stmts = []
        self.generators = {}
        self.bv_words = (130 + self.bits - 1) // self.bits
        self.carry_index = 0

    def nwords(self, maxval):
        """Return how many words are needed to hold values up to maxval."""
        return (maxval.bit_length() + self.bits - 1) // self.bits

    def stmt(self, stmt, needed=False):
        """Record a Statement and note it as the writer of its wvars.

        needed -- True for statements that must be emitted regardless
                  of whether anything reads their outputs.
        """
        index = len(self.stmts)
        self.stmts.append([needed, stmt])
        for val in stmt.wvars:
            self.generators[val] = index

    def new_value(self, formatstr=None, *deps):
        """Allocate a fresh value name "v<N>" and return it.

        If formatstr is given, also record the assignment
        "v<N> = <formatstr % deps>", reading the variables in deps.
        With no formatstr only the name is reserved; the caller is
        expected to record the statement that writes it.
        """
        name = "v%d" % self.valindex
        self.valindex += 1
        if formatstr is not None:
            self.stmt(Statement(
                rvars=deps, wvars=[name],
                forms=[None, name + " = " + formatstr % deps]))
        return name

    def bigval_input(self, name, bits):
        """Return a Multiprecision reading the words of bigval 'name'.

        bits is the assumed maximum size of the input; it must round
        up to exactly bv_words words.
        """
        words = (bits + self.bits - 1) // self.bits
        # Expect not to require an entire extra word
        assert words == self.bv_words
        return Multiprecision(self, 0, (1<<bits)-1, [
            self.new_value("%s->w[%d]" % (name, i)) for i in range(words)])

    def const(self, value):
        """Return a Multiprecision for a small integer constant."""
        # We only support constants small enough to both fit in a
        # BignumInt (of any size supported) _and_ be expressible in C
        # with no weird integer literal syntax like a trailing LL.
        #
        # Supporting larger constants would be possible - you could
        # break 'value' up into word-sized pieces on the Python side,
        # and generate a legal C expression for each piece by
        # splitting it further into pieces within the
        # standards-guaranteed 'unsigned long' limit of 32 bits and
        # then casting those to BignumInt before combining them with
        # shifts. But it would be a lot of effort, and since the
        # application for this code doesn't even need it, there's no
        # point in bothering.
        assert value < 2**16
        # Zero needs no words at all (nwords(0) == 0), so it must be
        # given an empty word list to satisfy Multiprecision's check
        # that the word count matches the range; reading any word of
        # the result still yields the literal "0".
        words = ["%d" % value] if value != 0 else []
        return Multiprecision(self, value, value, words)

    def current_carry(self):
        """Return the tracking name of the most recently produced carry."""
        return "carry%d" % self.carry_index

    def add(self, a1, a2):
        """Record a1 + a2 with no carry in; return the result's name.

        The emitted C always refers to a single variable called
        'carry'; the numbered carry<N> names exist only so that
        dependency tracking can tell successive carries apart.
        """
        ret = self.new_value()
        adcform = "BignumADC(%s, carry, %s, %s, 0)" % (ret, a1, a2)
        plainform = "%s = %s + %s" % (ret, a1, a2)
        self.carry_index += 1
        carryout = self.current_carry()
        # forms index: 1 = only carry used, 2 = only sum, 3 = both.
        self.stmt(Statement(
            rvars=[a1,a2], wvars=[ret,carryout],
            forms=[None, adcform, plainform, adcform]))
        return ret

    def adc(self, a1, a2):
        """Record a1 + a2 + (current carry); return the result's name."""
        ret = self.new_value()
        adcform = "BignumADC(%s, carry, %s, %s, carry)" % (ret, a1, a2)
        plainform = "%s = %s + %s + carry" % (ret, a1, a2)
        carryin = self.current_carry()
        self.carry_index += 1
        carryout = self.current_carry()
        # forms index: 1 = only carry used, 2 = only sum, 3 = both.
        self.stmt(Statement(
            rvars=[a1,a2,carryin], wvars=[ret,carryout],
            forms=[None, adcform, plainform, adcform]))
        return ret

    def muladd(self, m1, m2, *addends):
        """Record m1 * m2 plus up to two addends; return (high, low).

        Raises KeyError if more than two addends are passed.
        """
        rlo = self.new_value()
        rhi = self.new_value()
        wideform = "BignumMUL%s(%s)" % (
            { 0:"", 1:"ADD", 2:"ADD2" }[len(addends)],
            ", ".join([rhi, rlo, m1, m2] + list(addends)))
        narrowform = " + ".join(["%s = %s * %s" % (rlo, m1, m2)] +
                                list(addends))
        # forms index: 1 = only low half used, 2 = only high, 3 = both.
        self.stmt(Statement(
            rvars=[m1,m2]+list(addends), wvars=[rhi,rlo],
            forms=[None, narrowform, wideform, wideform]))
        return rhi, rlo

    def write_bigval(self, name, val):
        """Record stores of every word of 'val' into bigval 'name'.

        These stores are the outputs of the function, so they are
        flagged as needed.
        """
        for i in range(self.bv_words):
            word = val.getword(i)
            self.stmt(Statement(
                rvars=[word], wvars=[],
                forms=["%s->w[%d] = %s" % (name, i, word)]),
                needed=True)

    def compute_needed(self):
        """Work out which statements to emit and in which form.

        Returns (used_carry, used_vars, forms): whether any carry is
        used, the "v<N>" names to declare in numeric order, and the
        C statement texts in emission order.  Marks every statement
        that turns out to be required as needed in self.stmts.
        """
        used_vars = set()

        # Breadth-first walk from the needed statements back through
        # the statements that generate whatever they read.  The
        # worklist is local and read through a cursor, so nothing is
        # left behind on the instance and nothing is popped from the
        # front of a list.
        worklist = [stmt for (needed, stmt) in self.stmts if needed]
        cursor = 0
        while cursor < len(worklist):
            stmt = worklist[cursor]
            cursor += 1
            for var in stmt.rvars:
                if var[0] in string.digits:
                    continue # constant
                entry = self.stmts[self.generators[var]]
                used_vars.add(var)
                if not entry[0]:
                    entry[0] = True
                    worklist.append(entry[1])

        forms = []
        for needed, stmt in self.stmts:
            if not needed:
                continue
            # Build the bitmap of used outputs, wvars[0] as the MSB.
            formindex = 0
            for var in stmt.wvars:
                formindex *= 2
                if var in used_vars:
                    formindex += 1
            chosen = stmt.forms[formindex]
            forms.append(chosen)

            # Now we must check whether this form of the statement
            # also writes some variables we _don't_ actually need
            # (e.g. if you only wanted the top half from a mul, or
            # only the carry from an adc, you'd be forced to
            # generate the other output too). Easiest way to do
            # this is to look for an identical statement form
            # later in the array.
            maxindex = max(k for k, form in enumerate(stmt.forms)
                           if form == chosen)
            extra_vars = maxindex & ~formindex
            for bitpos in range(extra_vars.bit_length()):
                if extra_vars & (1 << bitpos):
                    var = stmt.wvars[-1-bitpos]
                    used_vars.add(var)
                    # Also, write out a cast-to-void for each
                    # subsequently unused value, to prevent gcc
                    # warnings when the output code is compiled.
                    forms.append("(void)" + var)

        used_carry = any(v.startswith("carry") for v in used_vars)
        used_vars = sorted((v for v in used_vars if v.startswith("v")),
                           key=lambda v: int(v[1:]))
        return used_carry, used_vars, forms

    def text(self):
        """Return the function body: declarations, then statements."""
        used_carry, values, forms = self.compute_needed()

        # Pack the value names greedily into declaration lines kept
        # under 79 columns.
        prefix, sep, suffix = " BignumInt ", ", ", ";"
        pieces = []
        currline = None
        for value in values:
            if (currline is not None and
                len(prefix+currline+sep+value+suffix) < 79):
                currline += sep + value
                continue
            if currline is not None:
                pieces.append(prefix + currline + suffix + "\n")
            currline = value
        if currline is not None:
            pieces.append(prefix + currline + suffix + "\n")

        if used_carry:
            pieces.append(" BignumCarry carry;\n")
        if pieces:
            # Blank line between the declarations and the code.
            pieces.append("\n")
        pieces.extend(" %s;\n" % stmtform for stmtform in forms)
        return "".join(pieces)
def gen_add(target):
    """Return the C source of bigval_add, generated through 'target'.

    The sum is deliberately _not_ reduced mod p, so the same function
    serves both for accumulating the polynomial and for adding the
    encrypted nonce at the end (which is mod 2^128, not mod p).
    """
    # One of the operands will have come out of the multiplication
    # function, which does not reduce completely, so allow for up to
    # 3 bits beyond the nominal 130.
    lhs = target.bigval_input("a", 133)
    rhs = target.bigval_input("b", 133)
    target.write_bigval("r", lhs + rhs)

    signature = (
        "static void bigval_add(bigval *r, const bigval *a, const bigval *b)")
    return "%s\n{\n%s}\n\n" % (signature, target.text())
def gen_mul(target):
    """Return the C source of bigval_mul_mod_p, generated via 'target'.

    The product is split into three 130-bit-aligned pieces, and the
    upper two are folded back into the bottom one after multiplying
    them by 5 and 25 respectively.
    """
    # The operands are not fully reduced mod p. The pow5==0 pass can
    # deliver a full 130-bit number, and the pow5==1 pass a 130-bit
    # number times 5, plus a possible carry. That total is easily
    # bounded above by 2^130 * 8, hence two 133-bit operands.
    lhs = target.bigval_input("a", 133)
    rhs = target.bigval_input("b", 133)
    product = lhs * rhs

    low = product.extract_bits(0, 130)
    middle = product.extract_bits(130, 130)
    high = product.extract_bits(260)

    middle_times_5 = target.const(5) * middle
    high_times_25 = target.const(25) * high
    target.write_bigval("r", low + middle_times_5 + high_times_25)

    signature = ("static void bigval_mul_mod_p"
                 "(bigval *r, const bigval *a, const bigval *b)")
    return "%s\n{\n%s}\n\n" % (signature, target.text())
def gen_final_reduce(target):
    """Return the C source of bigval_final_reduce, generated via 'target'."""
    # For an input n, n >> 130 is usually exactly the multiple of p
    # that must be subtracted to bring n strictly below p, but it can
    # be too low by 1 (never by more, since the input range is nowhere
    # near the square of the modulus). Adding a further 5 pushes a
    # carry into bit 130 if and only if that has happened, and that
    # carry then decides whether one more copy of p is subtracted.
    n = target.bigval_input("n", 133)
    quotient = n.extract_bits(130)
    low_part = n.extract_bits(0, 130)
    adjusted = low_part + target.const(5) * quotient

    probe = adjusted + target.const(5)
    final_subtract = probe.extract_bits(130)

    adjusted2 = adjusted + target.const(5) * final_subtract
    target.write_bigval("n", adjusted2.extract_bits(0, 130))

    signature = "static void bigval_final_reduce(bigval *n)"
    return "%s\n{\n%s}\n\n" % (signature, target.text())
# Emit one preprocessor branch per supported word size, each holding
# the three generated functions, then a catch-all branch that makes
# the C compilation fail for any other word size.
pp_keyword = "#if"
for bits in (16, 32, 64):
    sys.stdout.write("%s BIGNUM_INT_BITS == %d\n\n" % (pp_keyword, bits))
    # Only the first branch opens the conditional.
    pp_keyword = "#elif"
    # Each function gets a fresh target so that value numbering
    # restarts for every function body.
    for generate in (gen_add, gen_mul, gen_final_reduce):
        sys.stdout.write(generate(CodegenTarget(bits)))
sys.stdout.write(
    "#else\n"
    "#error Add another bit count to contrib/make1305.py and rerun it\n"
    "#endif\n")
|