reinterpret.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. """
  2. Find intermediate evaluation results in assert statements through builtin AST.
  3. """
  4. import ast
  5. import sys
  6. import _pytest._code
  7. import py
  8. from _pytest.assertion import util
  9. u = py.builtin._totext
  10. class AssertionError(util.BuiltinAssertionError):
  11. def __init__(self, *args):
  12. util.BuiltinAssertionError.__init__(self, *args)
  13. if args:
  14. # on Python2.6 we get len(args)==2 for: assert 0, (x,y)
  15. # on Python2.7 and above we always get len(args) == 1
  16. # with args[0] being the (x,y) tuple.
  17. if len(args) > 1:
  18. toprint = args
  19. else:
  20. toprint = args[0]
  21. try:
  22. self.msg = u(toprint)
  23. except Exception:
  24. self.msg = u(
  25. "<[broken __repr__] %s at %0xd>"
  26. % (toprint.__class__, id(toprint)))
  27. else:
  28. f = _pytest._code.Frame(sys._getframe(1))
  29. try:
  30. source = f.code.fullsource
  31. if source is not None:
  32. try:
  33. source = source.getstatement(f.lineno, assertion=True)
  34. except IndexError:
  35. source = None
  36. else:
  37. source = str(source.deindent()).strip()
  38. except py.error.ENOENT:
  39. source = None
  40. # this can also occur during reinterpretation, when the
  41. # co_filename is set to "<run>".
  42. if source:
  43. self.msg = reinterpret(source, f, should_fail=True)
  44. else:
  45. self.msg = "<could not determine information>"
  46. if not self.args:
  47. self.args = (self.msg,)
  48. if sys.version_info > (3, 0):
  49. AssertionError.__module__ = "builtins"
  50. if sys.platform.startswith("java"):
  51. # See http://bugs.jython.org/issue1497
  52. _exprs = ("BoolOp", "BinOp", "UnaryOp", "Lambda", "IfExp", "Dict",
  53. "ListComp", "GeneratorExp", "Yield", "Compare", "Call",
  54. "Repr", "Num", "Str", "Attribute", "Subscript", "Name",
  55. "List", "Tuple")
  56. _stmts = ("FunctionDef", "ClassDef", "Return", "Delete", "Assign",
  57. "AugAssign", "Print", "For", "While", "If", "With", "Raise",
  58. "TryExcept", "TryFinally", "Assert", "Import", "ImportFrom",
  59. "Exec", "Global", "Expr", "Pass", "Break", "Continue")
  60. _expr_nodes = set(getattr(ast, name) for name in _exprs)
  61. _stmt_nodes = set(getattr(ast, name) for name in _stmts)
  62. def _is_ast_expr(node):
  63. return node.__class__ in _expr_nodes
  64. def _is_ast_stmt(node):
  65. return node.__class__ in _stmt_nodes
  66. else:
  67. def _is_ast_expr(node):
  68. return isinstance(node, ast.expr)
  69. def _is_ast_stmt(node):
  70. return isinstance(node, ast.stmt)
  71. try:
  72. _Starred = ast.Starred
  73. except AttributeError:
  74. # Python 2. Define a dummy class so isinstance() will always be False.
  75. class _Starred(object): pass
  76. class Failure(Exception):
  77. """Error found while interpreting AST."""
  78. def __init__(self, explanation=""):
  79. self.cause = sys.exc_info()
  80. self.explanation = explanation
  81. def reinterpret(source, frame, should_fail=False):
  82. mod = ast.parse(source)
  83. visitor = DebugInterpreter(frame)
  84. try:
  85. visitor.visit(mod)
  86. except Failure:
  87. failure = sys.exc_info()[1]
  88. return getfailure(failure)
  89. if should_fail:
  90. return ("(assertion failed, but when it was re-run for "
  91. "printing intermediate values, it did not fail. Suggestions: "
  92. "compute assert expression before the assert or use --assert=plain)")
  93. def run(offending_line, frame=None):
  94. if frame is None:
  95. frame = _pytest._code.Frame(sys._getframe(1))
  96. return reinterpret(offending_line, frame)
  97. def getfailure(e):
  98. explanation = util.format_explanation(e.explanation)
  99. value = e.cause[1]
  100. if str(value):
  101. lines = explanation.split('\n')
  102. lines[0] += " << %s" % (value,)
  103. explanation = '\n'.join(lines)
  104. text = "%s: %s" % (e.cause[0].__name__, explanation)
  105. if text.startswith('AssertionError: assert '):
  106. text = text[16:]
  107. return text
  108. operator_map = {
  109. ast.BitOr : "|",
  110. ast.BitXor : "^",
  111. ast.BitAnd : "&",
  112. ast.LShift : "<<",
  113. ast.RShift : ">>",
  114. ast.Add : "+",
  115. ast.Sub : "-",
  116. ast.Mult : "*",
  117. ast.Div : "/",
  118. ast.FloorDiv : "//",
  119. ast.Mod : "%",
  120. ast.Eq : "==",
  121. ast.NotEq : "!=",
  122. ast.Lt : "<",
  123. ast.LtE : "<=",
  124. ast.Gt : ">",
  125. ast.GtE : ">=",
  126. ast.Pow : "**",
  127. ast.Is : "is",
  128. ast.IsNot : "is not",
  129. ast.In : "in",
  130. ast.NotIn : "not in"
  131. }
  132. unary_map = {
  133. ast.Not : "not %s",
  134. ast.Invert : "~%s",
  135. ast.USub : "-%s",
  136. ast.UAdd : "+%s"
  137. }
  138. class DebugInterpreter(ast.NodeVisitor):
  139. """Interpret AST nodes to gleam useful debugging information. """
  140. def __init__(self, frame):
  141. self.frame = frame
  142. def generic_visit(self, node):
  143. # Fallback when we don't have a special implementation.
  144. if _is_ast_expr(node):
  145. mod = ast.Expression(node)
  146. co = self._compile(mod)
  147. try:
  148. result = self.frame.eval(co)
  149. except Exception:
  150. raise Failure()
  151. explanation = self.frame.repr(result)
  152. return explanation, result
  153. elif _is_ast_stmt(node):
  154. mod = ast.Module([node])
  155. co = self._compile(mod, "exec")
  156. try:
  157. self.frame.exec_(co)
  158. except Exception:
  159. raise Failure()
  160. return None, None
  161. else:
  162. raise AssertionError("can't handle %s" %(node,))
  163. def _compile(self, source, mode="eval"):
  164. return compile(source, "<assertion interpretation>", mode)
  165. def visit_Expr(self, expr):
  166. return self.visit(expr.value)
  167. def visit_Module(self, mod):
  168. for stmt in mod.body:
  169. self.visit(stmt)
  170. def visit_Name(self, name):
  171. explanation, result = self.generic_visit(name)
  172. # See if the name is local.
  173. source = "%r in locals() is not globals()" % (name.id,)
  174. co = self._compile(source)
  175. try:
  176. local = self.frame.eval(co)
  177. except Exception:
  178. # have to assume it isn't
  179. local = None
  180. if local is None or not self.frame.is_true(local):
  181. return name.id, result
  182. return explanation, result
  183. def visit_Compare(self, comp):
  184. left = comp.left
  185. left_explanation, left_result = self.visit(left)
  186. for op, next_op in zip(comp.ops, comp.comparators):
  187. next_explanation, next_result = self.visit(next_op)
  188. op_symbol = operator_map[op.__class__]
  189. explanation = "%s %s %s" % (left_explanation, op_symbol,
  190. next_explanation)
  191. source = "__exprinfo_left %s __exprinfo_right" % (op_symbol,)
  192. co = self._compile(source)
  193. try:
  194. result = self.frame.eval(co, __exprinfo_left=left_result,
  195. __exprinfo_right=next_result)
  196. except Exception:
  197. raise Failure(explanation)
  198. try:
  199. if not self.frame.is_true(result):
  200. break
  201. except KeyboardInterrupt:
  202. raise
  203. except:
  204. break
  205. left_explanation, left_result = next_explanation, next_result
  206. if util._reprcompare is not None:
  207. res = util._reprcompare(op_symbol, left_result, next_result)
  208. if res:
  209. explanation = res
  210. return explanation, result
  211. def visit_BoolOp(self, boolop):
  212. is_or = isinstance(boolop.op, ast.Or)
  213. explanations = []
  214. for operand in boolop.values:
  215. explanation, result = self.visit(operand)
  216. explanations.append(explanation)
  217. if result == is_or:
  218. break
  219. name = is_or and " or " or " and "
  220. explanation = "(" + name.join(explanations) + ")"
  221. return explanation, result
  222. def visit_UnaryOp(self, unary):
  223. pattern = unary_map[unary.op.__class__]
  224. operand_explanation, operand_result = self.visit(unary.operand)
  225. explanation = pattern % (operand_explanation,)
  226. co = self._compile(pattern % ("__exprinfo_expr",))
  227. try:
  228. result = self.frame.eval(co, __exprinfo_expr=operand_result)
  229. except Exception:
  230. raise Failure(explanation)
  231. return explanation, result
  232. def visit_BinOp(self, binop):
  233. left_explanation, left_result = self.visit(binop.left)
  234. right_explanation, right_result = self.visit(binop.right)
  235. symbol = operator_map[binop.op.__class__]
  236. explanation = "(%s %s %s)" % (left_explanation, symbol,
  237. right_explanation)
  238. source = "__exprinfo_left %s __exprinfo_right" % (symbol,)
  239. co = self._compile(source)
  240. try:
  241. result = self.frame.eval(co, __exprinfo_left=left_result,
  242. __exprinfo_right=right_result)
  243. except Exception:
  244. raise Failure(explanation)
  245. return explanation, result
  246. def visit_Call(self, call):
  247. func_explanation, func = self.visit(call.func)
  248. arg_explanations = []
  249. ns = {"__exprinfo_func" : func}
  250. arguments = []
  251. for arg in call.args:
  252. arg_explanation, arg_result = self.visit(arg)
  253. if isinstance(arg, _Starred):
  254. arg_name = "__exprinfo_star"
  255. ns[arg_name] = arg_result
  256. arguments.append("*%s" % (arg_name,))
  257. arg_explanations.append("*%s" % (arg_explanation,))
  258. else:
  259. arg_name = "__exprinfo_%s" % (len(ns),)
  260. ns[arg_name] = arg_result
  261. arguments.append(arg_name)
  262. arg_explanations.append(arg_explanation)
  263. for keyword in call.keywords:
  264. arg_explanation, arg_result = self.visit(keyword.value)
  265. if keyword.arg:
  266. arg_name = "__exprinfo_%s" % (len(ns),)
  267. keyword_source = "%s=%%s" % (keyword.arg)
  268. arguments.append(keyword_source % (arg_name,))
  269. arg_explanations.append(keyword_source % (arg_explanation,))
  270. else:
  271. arg_name = "__exprinfo_kwds"
  272. arguments.append("**%s" % (arg_name,))
  273. arg_explanations.append("**%s" % (arg_explanation,))
  274. ns[arg_name] = arg_result
  275. if getattr(call, 'starargs', None):
  276. arg_explanation, arg_result = self.visit(call.starargs)
  277. arg_name = "__exprinfo_star"
  278. ns[arg_name] = arg_result
  279. arguments.append("*%s" % (arg_name,))
  280. arg_explanations.append("*%s" % (arg_explanation,))
  281. if getattr(call, 'kwargs', None):
  282. arg_explanation, arg_result = self.visit(call.kwargs)
  283. arg_name = "__exprinfo_kwds"
  284. ns[arg_name] = arg_result
  285. arguments.append("**%s" % (arg_name,))
  286. arg_explanations.append("**%s" % (arg_explanation,))
  287. args_explained = ", ".join(arg_explanations)
  288. explanation = "%s(%s)" % (func_explanation, args_explained)
  289. args = ", ".join(arguments)
  290. source = "__exprinfo_func(%s)" % (args,)
  291. co = self._compile(source)
  292. try:
  293. result = self.frame.eval(co, **ns)
  294. except Exception:
  295. raise Failure(explanation)
  296. pattern = "%s\n{%s = %s\n}"
  297. rep = self.frame.repr(result)
  298. explanation = pattern % (rep, rep, explanation)
  299. return explanation, result
  300. def _is_builtin_name(self, name):
  301. pattern = "%r not in globals() and %r not in locals()"
  302. source = pattern % (name.id, name.id)
  303. co = self._compile(source)
  304. try:
  305. return self.frame.eval(co)
  306. except Exception:
  307. return False
  308. def visit_Attribute(self, attr):
  309. if not isinstance(attr.ctx, ast.Load):
  310. return self.generic_visit(attr)
  311. source_explanation, source_result = self.visit(attr.value)
  312. explanation = "%s.%s" % (source_explanation, attr.attr)
  313. source = "__exprinfo_expr.%s" % (attr.attr,)
  314. co = self._compile(source)
  315. try:
  316. try:
  317. result = self.frame.eval(co, __exprinfo_expr=source_result)
  318. except AttributeError:
  319. # Maybe the attribute name needs to be mangled?
  320. if not attr.attr.startswith("__") or attr.attr.endswith("__"):
  321. raise
  322. source = "getattr(__exprinfo_expr.__class__, '__name__', '')"
  323. co = self._compile(source)
  324. class_name = self.frame.eval(co, __exprinfo_expr=source_result)
  325. mangled_attr = "_" + class_name + attr.attr
  326. source = "__exprinfo_expr.%s" % (mangled_attr,)
  327. co = self._compile(source)
  328. result = self.frame.eval(co, __exprinfo_expr=source_result)
  329. except Exception:
  330. raise Failure(explanation)
  331. explanation = "%s\n{%s = %s.%s\n}" % (self.frame.repr(result),
  332. self.frame.repr(result),
  333. source_explanation, attr.attr)
  334. # Check if the attr is from an instance.
  335. source = "%r in getattr(__exprinfo_expr, '__dict__', {})"
  336. source = source % (attr.attr,)
  337. co = self._compile(source)
  338. try:
  339. from_instance = self.frame.eval(co, __exprinfo_expr=source_result)
  340. except Exception:
  341. from_instance = None
  342. if from_instance is None or self.frame.is_true(from_instance):
  343. rep = self.frame.repr(result)
  344. pattern = "%s\n{%s = %s\n}"
  345. explanation = pattern % (rep, rep, explanation)
  346. return explanation, result
  347. def visit_Assert(self, assrt):
  348. test_explanation, test_result = self.visit(assrt.test)
  349. explanation = "assert %s" % (test_explanation,)
  350. if not self.frame.is_true(test_result):
  351. try:
  352. raise util.BuiltinAssertionError
  353. except Exception:
  354. raise Failure(explanation)
  355. return explanation, test_result
  356. def visit_Assign(self, assign):
  357. value_explanation, value_result = self.visit(assign.value)
  358. explanation = "... = %s" % (value_explanation,)
  359. name = ast.Name("__exprinfo_expr", ast.Load(),
  360. lineno=assign.value.lineno,
  361. col_offset=assign.value.col_offset)
  362. new_assign = ast.Assign(assign.targets, name, lineno=assign.lineno,
  363. col_offset=assign.col_offset)
  364. mod = ast.Module([new_assign])
  365. co = self._compile(mod, "exec")
  366. try:
  367. self.frame.exec_(co, __exprinfo_expr=value_result)
  368. except Exception:
  369. raise Failure(explanation)
  370. return explanation, value_result