testcrypt.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442
  1. import sys
  2. import os
  3. import numbers
  4. import subprocess
  5. import re
  6. import string
  7. import struct
  8. from binascii import hexlify
  9. assert sys.version_info[:2] >= (3,0), "This is Python 3 code"
  10. # Expect to be run from the 'test' subdirectory, one level down from
  11. # the main source
  12. putty_srcdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  13. def coerce_to_bytes(arg):
  14. return arg.encode("UTF-8") if isinstance(arg, str) else arg
  15. class ChildProcessFailure(Exception):
  16. pass
  17. class ChildProcess(object):
  18. def __init__(self):
  19. self.sp = None
  20. self.debug = None
  21. self.exitstatus = None
  22. self.exception = None
  23. dbg = os.environ.get("PUTTY_TESTCRYPT_DEBUG")
  24. if dbg is not None:
  25. if dbg == "stderr":
  26. self.debug = sys.stderr
  27. else:
  28. sys.stderr.write("Unknown value '{}' for PUTTY_TESTCRYPT_DEBUG"
  29. " (try 'stderr'\n")
  30. def start(self):
  31. assert self.sp is None
  32. override_command = os.environ.get("PUTTY_TESTCRYPT")
  33. if override_command is None:
  34. cmd = [os.path.join(putty_srcdir, "testcrypt")]
  35. shell = False
  36. else:
  37. cmd = override_command
  38. shell = True
  39. self.sp = subprocess.Popen(
  40. cmd, shell=shell, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
  41. def write_line(self, line):
  42. if self.exception is not None:
  43. # Re-raise our fatal-error exception, if it previously
  44. # occurred in a context where it couldn't be propagated (a
  45. # __del__ method).
  46. raise self.exception
  47. if self.debug is not None:
  48. self.debug.write("send: {}\n".format(line))
  49. self.sp.stdin.write(line + b"\n")
  50. self.sp.stdin.flush()
  51. def read_line(self):
  52. line = self.sp.stdout.readline()
  53. if len(line) == 0:
  54. self.exception = ChildProcessFailure("received EOF from testcrypt")
  55. raise self.exception
  56. line = line.rstrip(b"\r\n")
  57. if self.debug is not None:
  58. self.debug.write("recv: {}\n".format(line))
  59. return line
  60. def already_terminated(self):
  61. return self.sp is None and self.exitstatus is not None
  62. def funcall(self, cmd, args):
  63. if self.sp is None:
  64. assert self.exitstatus is None
  65. self.start()
  66. self.write_line(coerce_to_bytes(cmd) + b" " + b" ".join(
  67. coerce_to_bytes(arg) for arg in args))
  68. argcount = int(self.read_line())
  69. return [self.read_line() for arg in range(argcount)]
  70. def wait_for_exit(self):
  71. if self.sp is not None:
  72. self.sp.stdin.close()
  73. self.exitstatus = self.sp.wait()
  74. self.sp = None
  75. def check_return_status(self):
  76. self.wait_for_exit()
  77. if self.exitstatus is not None and self.exitstatus != 0:
  78. raise ChildProcessFailure("testcrypt returned exit status {}"
  79. .format(self.exitstatus))
  80. childprocess = ChildProcess()
  81. method_prefixes = {
  82. 'val_wpoint': ['ecc_weierstrass_'],
  83. 'val_mpoint': ['ecc_montgomery_'],
  84. 'val_epoint': ['ecc_edwards_'],
  85. 'val_hash': ['ssh_hash_'],
  86. 'val_mac': ['ssh2_mac_'],
  87. 'val_key': ['ssh_key_'],
  88. 'val_cipher': ['ssh_cipher_'],
  89. 'val_dh': ['dh_'],
  90. 'val_ecdh': ['ssh_ecdhkex_'],
  91. 'val_rsakex': ['ssh_rsakex_'],
  92. 'val_prng': ['prng_'],
  93. 'val_pcs': ['pcs_'],
  94. 'val_pockle': ['pockle_'],
  95. 'val_ntruencodeschedule': ['ntru_encode_schedule_', 'ntru_'],
  96. }
  97. method_lists = {t: [] for t in method_prefixes}
  98. checked_enum_values = {}
  99. class Value(object):
  100. def __init__(self, typename, ident):
  101. self._typename = typename
  102. self._ident = ident
  103. for methodname, function in method_lists.get(self._typename, []):
  104. setattr(self, methodname,
  105. (lambda f: lambda *args: f(self, *args))(function))
  106. def _consumed(self):
  107. self._ident = None
  108. def __repr__(self):
  109. return "Value({!r}, {!r})".format(self._typename, self._ident)
  110. def __del__(self):
  111. if self._ident is not None and not childprocess.already_terminated():
  112. try:
  113. childprocess.funcall("free", [self._ident])
  114. except ChildProcessFailure:
  115. # If we see this exception now, we can't do anything
  116. # about it, because exceptions don't propagate out of
  117. # __del__ methods. Squelch it to prevent the annoying
  118. # runtime warning from Python, and the
  119. # 'self.exception' mechanism in the ChildProcess class
  120. # will raise it again at the next opportunity.
  121. #
  122. # (This covers both the case where testcrypt crashes
  123. # _during_ one of these free operations, and the
  124. # silencing of cascade failures when we try to send a
  125. # "free" command to testcrypt after it had already
  126. # crashed for some other reason.)
  127. pass
  128. def __long__(self):
  129. if self._typename != "val_mpint":
  130. raise TypeError("testcrypt values of types other than mpint"
  131. " cannot be converted to integer")
  132. hexval = childprocess.funcall("mp_dump", [self._ident])[0]
  133. return 0 if len(hexval) == 0 else int(hexval, 16)
  134. def __int__(self):
  135. return int(self.__long__())
  136. def marshal_string(val):
  137. val = coerce_to_bytes(val)
  138. assert isinstance(val, bytes), "Bad type for val_string input"
  139. return "".join(
  140. chr(b) if (0x20 <= b < 0x7F and b != 0x25)
  141. else "%{:02x}".format(b)
  142. for b in val)
  143. def make_argword(arg, argtype, fnname, argindex, argname, to_preserve):
  144. typename, consumed = argtype
  145. if typename.startswith("opt_"):
  146. if arg is None:
  147. return "NULL"
  148. typename = typename[4:]
  149. if typename == "val_string":
  150. retwords = childprocess.funcall("newstring", [marshal_string(arg)])
  151. arg = make_retvals([typename], retwords, unpack_strings=False)[0]
  152. to_preserve.append(arg)
  153. if typename == "val_mpint" and isinstance(arg, numbers.Integral):
  154. retwords = childprocess.funcall("mp_literal", ["0x{:x}".format(arg)])
  155. arg = make_retvals([typename], retwords)[0]
  156. to_preserve.append(arg)
  157. if isinstance(arg, Value):
  158. if arg._typename != typename:
  159. raise TypeError(
  160. "{}() argument #{:d} ({}) should be {} ({} given)".format(
  161. fnname, argindex, argname, typename, arg._typename))
  162. ident = arg._ident
  163. if consumed:
  164. arg._consumed()
  165. return ident
  166. if typename == "uint" and isinstance(arg, numbers.Integral):
  167. return "0x{:x}".format(arg)
  168. if typename == "boolean":
  169. return "true" if arg else "false"
  170. if typename in {
  171. "hashalg", "macalg", "keyalg", "cipheralg",
  172. "dh_group", "ecdh_alg", "rsaorder", "primegenpolicy",
  173. "argon2flavour", "fptype", "httpdigesthash"}:
  174. arg = coerce_to_bytes(arg)
  175. if isinstance(arg, bytes) and b" " not in arg:
  176. dictkey = (typename, arg)
  177. if dictkey not in checked_enum_values:
  178. retwords = childprocess.funcall("checkenum", [typename, arg])
  179. assert len(retwords) == 1
  180. checked_enum_values[dictkey] = (retwords[0] == b"ok")
  181. if checked_enum_values[dictkey]:
  182. return arg
  183. if typename == "mpint_list":
  184. sublist = [make_argword(len(arg), ("uint", False),
  185. fnname, argindex, argname, to_preserve)]
  186. for val in arg:
  187. sublist.append(make_argword(val, ("val_mpint", False),
  188. fnname, argindex, argname, to_preserve))
  189. return b" ".join(coerce_to_bytes(sub) for sub in sublist)
  190. if typename == "int16_list":
  191. sublist = [make_argword(len(arg), ("uint", False),
  192. fnname, argindex, argname, to_preserve)]
  193. for val in arg:
  194. sublist.append(make_argword(val & 0xFFFF, ("uint", False),
  195. fnname, argindex, argname, to_preserve))
  196. return b" ".join(coerce_to_bytes(sub) for sub in sublist)
  197. raise TypeError(
  198. "Can't convert {}() argument #{:d} ({}) to {} (value was {!r})".format(
  199. fnname, argindex, argname, typename, arg))
  200. def unpack_string(identifier):
  201. retwords = childprocess.funcall("getstring", [identifier])
  202. childprocess.funcall("free", [identifier])
  203. return re.sub(b"%[0-9A-F][0-9A-F]",
  204. lambda m: bytes([int(m.group(0)[1:], 16)]),
  205. retwords[0])
  206. def unpack_mp(identifier):
  207. retwords = childprocess.funcall("mp_dump", [identifier])
  208. childprocess.funcall("free", [identifier])
  209. return int(retwords[0], 16)
  210. def make_retval(rettype, word, unpack_strings):
  211. if rettype.startswith("opt_"):
  212. if word == b"NULL":
  213. return None
  214. rettype = rettype[4:]
  215. if rettype == "val_string" and unpack_strings:
  216. return unpack_string(word)
  217. if rettype == "val_keycomponents":
  218. kc = {}
  219. retwords = childprocess.funcall("key_components_count", [word])
  220. for i in range(int(retwords[0], 0)):
  221. args = [word, "{:d}".format(i)]
  222. retwords = childprocess.funcall("key_components_nth_name", args)
  223. kc_key = unpack_string(retwords[0])
  224. retwords = childprocess.funcall("key_components_nth_str", args)
  225. if retwords[0] != b"NULL":
  226. kc_value = unpack_string(retwords[0]).decode("ASCII")
  227. else:
  228. retwords = childprocess.funcall("key_components_nth_mp", args)
  229. kc_value = unpack_mp(retwords[0])
  230. kc[kc_key.decode("ASCII")] = kc_value
  231. childprocess.funcall("free", [word])
  232. return kc
  233. if rettype.startswith("val_"):
  234. return Value(rettype, word)
  235. elif rettype == "int" or rettype == "uint":
  236. return int(word, 0)
  237. elif rettype == "boolean":
  238. assert word == b"true" or word == b"false"
  239. return word == b"true"
  240. elif rettype in {"pocklestatus", "mr_result"}:
  241. return word.decode("ASCII")
  242. elif rettype == "int16_list":
  243. return list(map(int, word.split(b',')))
  244. raise TypeError("Can't deal with return value {!r} of type {!r}"
  245. .format(word, rettype))
  246. def make_retvals(rettypes, retwords, unpack_strings=True):
  247. assert len(rettypes) == len(retwords) # FIXME: better exception
  248. return [make_retval(rettype, word, unpack_strings)
  249. for rettype, word in zip(rettypes, retwords)]
  250. class Function(object):
  251. def __init__(self, fnname, rettypes, retnames, argtypes, argnames):
  252. self.fnname = fnname
  253. self.rettypes = rettypes
  254. self.retnames = retnames
  255. self.argtypes = argtypes
  256. self.argnames = argnames
  257. def __repr__(self):
  258. return "<Function {}({}) -> ({})>".format(
  259. self.fnname,
  260. ", ".join(("consumed " if c else "")+t+" "+n
  261. for (t,c),n in zip(self.argtypes, self.argnames)),
  262. ", ".join((t+" "+n if n is not None else t)
  263. for t,n in zip(self.rettypes, self.retnames)),
  264. )
  265. def __call__(self, *args):
  266. if len(args) != len(self.argtypes):
  267. raise TypeError(
  268. "{}() takes exactly {} arguments ({} given)".format(
  269. self.fnname, len(self.argtypes), len(args)))
  270. to_preserve = []
  271. retwords = childprocess.funcall(
  272. self.fnname, [make_argword(args[i], self.argtypes[i],
  273. self.fnname, i, self.argnames[i],
  274. to_preserve)
  275. for i in range(len(args))])
  276. retvals = make_retvals(self.rettypes, retwords)
  277. if len(retvals) == 0:
  278. return None
  279. if len(retvals) == 1:
  280. return retvals[0]
  281. return tuple(retvals)
  282. def _lex_testcrypt_header(header):
  283. pat = re.compile(
  284. # Skip any combination of whitespace and comments
  285. '(?:{})*'.format('|'.join((
  286. '[ \t\n]', # whitespace
  287. '/\\*(?:.|\n)*?\\*/', # C90-style /* ... */ comment, ended eagerly
  288. '//[^\n]*\n', # C99-style comment to end-of-line
  289. ))) +
  290. # And then match a token
  291. '({})'.format('|'.join((
  292. # Punctuation
  293. r'\(',
  294. r'\)',
  295. ',',
  296. # Identifier
  297. '[A-Za-z_][A-Za-z0-9_]*',
  298. # End of string
  299. '$',
  300. )))
  301. )
  302. pos = 0
  303. end = len(header)
  304. while pos < end:
  305. m = pat.match(header, pos)
  306. assert m is not None, (
  307. "Failed to lex testcrypt-func.h at byte position {:d}".format(pos))
  308. pos = m.end()
  309. tok = m.group(1)
  310. if len(tok) == 0:
  311. assert pos == end, (
  312. "Empty token should only be returned at end of string")
  313. yield tok, m.start(1)
  314. def _parse_testcrypt_header(tokens):
  315. def is_id(tok):
  316. return tok[0] in string.ascii_letters+"_"
  317. def expect(what, why, eof_ok=False):
  318. tok, pos = next(tokens)
  319. if tok == '' and eof_ok:
  320. return None
  321. if hasattr(what, '__call__'):
  322. description = lambda: ""
  323. ok = what(tok)
  324. elif isinstance(what, set):
  325. description = lambda: " or ".join("'"+x+"' " for x in sorted(what))
  326. ok = tok in what
  327. else:
  328. description = lambda: "'"+what+"' "
  329. ok = tok == what
  330. if not ok:
  331. sys.exit("testcrypt-func.h:{:d}: expected {}{}".format(
  332. pos, description(), why))
  333. return tok
  334. while True:
  335. tok = expect({"FUNC", "FUNC_WRAPPED"},
  336. "at start of function specification", eof_ok=True)
  337. if tok is None:
  338. break
  339. expect("(", "after FUNC")
  340. rettype = expect(is_id, "return type")
  341. expect(",", "after return type")
  342. funcname = expect(is_id, "function name")
  343. expect(",", "after function name")
  344. args = []
  345. firstargkind = expect({"ARG", "VOID"}, "at start of argument list")
  346. if firstargkind == "VOID":
  347. expect(")", "after VOID")
  348. else:
  349. while True:
  350. # Every time we come back to the top of this loop, we've
  351. # just seen 'ARG'
  352. expect("(", "after ARG")
  353. argtype = expect(is_id, "argument type")
  354. expect(",", "after argument type")
  355. argname = expect(is_id, "argument name")
  356. args.append((argtype, argname))
  357. expect(")", "at end of ARG")
  358. punct = expect({",", ")"}, "after argument")
  359. if punct == ")":
  360. break
  361. expect("ARG", "to begin next argument")
  362. yield funcname, rettype, args
  363. def _setup(scope):
  364. valprefix = "val_"
  365. outprefix = "out_"
  366. optprefix = "opt_"
  367. consprefix = "consumed_"
  368. def trim_argtype(arg):
  369. if arg.startswith(optprefix):
  370. return optprefix + trim_argtype(arg[len(optprefix):])
  371. if (arg.startswith(valprefix) and
  372. "_" in arg[len(valprefix):]):
  373. # Strip suffixes like val_string_asciz
  374. arg = arg[:arg.index("_", len(valprefix))]
  375. return arg
  376. with open(os.path.join(putty_srcdir, "test", "testcrypt-func.h")) as f:
  377. header = f.read()
  378. tokens = _lex_testcrypt_header(header)
  379. for function, rettype, arglist in _parse_testcrypt_header(tokens):
  380. rettypes = []
  381. retnames = []
  382. if rettype != "void":
  383. rettypes.append(trim_argtype(rettype))
  384. retnames.append(None)
  385. argtypes = []
  386. argnames = []
  387. argsconsumed = []
  388. for arg, argname in arglist:
  389. if arg.startswith(outprefix):
  390. rettypes.append(trim_argtype(arg[len(outprefix):]))
  391. retnames.append(argname)
  392. else:
  393. consumed = False
  394. if arg.startswith(consprefix):
  395. arg = arg[len(consprefix):]
  396. consumed = True
  397. arg = trim_argtype(arg)
  398. argtypes.append((arg, consumed))
  399. argnames.append(argname)
  400. func = Function(function, rettypes, retnames,
  401. argtypes, argnames)
  402. scope[function] = func
  403. if len(argtypes) > 0:
  404. t = argtypes[0][0]
  405. if t in method_prefixes:
  406. for prefix in method_prefixes[t]:
  407. if function.startswith(prefix):
  408. methodname = function[len(prefix):]
  409. method_lists[t].append((methodname, func))
  410. break
# Populate this module's namespace with a Function wrapper for every
# entry in testcrypt-func.h (also filling method_lists), then remove the
# helper so it does not linger in the module namespace.
_setup(globals())
del _setup