pprint.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. """Utilities for pretty-printing IR in a human-readable form."""
  2. from __future__ import annotations
  3. from collections import defaultdict
  4. from typing import Any, Final, Sequence, Union
  5. from mypyc.common import short_name
  6. from mypyc.ir.func_ir import FuncIR, all_values_full
  7. from mypyc.ir.module_ir import ModuleIRs
  8. from mypyc.ir.ops import (
  9. ERR_NEVER,
  10. Assign,
  11. AssignMulti,
  12. BasicBlock,
  13. Box,
  14. Branch,
  15. Call,
  16. CallC,
  17. Cast,
  18. ComparisonOp,
  19. ControlOp,
  20. DecRef,
  21. Extend,
  22. Float,
  23. FloatComparisonOp,
  24. FloatNeg,
  25. FloatOp,
  26. GetAttr,
  27. GetElementPtr,
  28. Goto,
  29. IncRef,
  30. InitStatic,
  31. Integer,
  32. IntOp,
  33. KeepAlive,
  34. LoadAddress,
  35. LoadErrorValue,
  36. LoadGlobal,
  37. LoadLiteral,
  38. LoadMem,
  39. LoadStatic,
  40. MethodCall,
  41. Op,
  42. OpVisitor,
  43. RaiseStandardError,
  44. Register,
  45. Return,
  46. SetAttr,
  47. SetMem,
  48. Truncate,
  49. TupleGet,
  50. TupleSet,
  51. Unbox,
  52. Unreachable,
  53. Value,
  54. )
  55. from mypyc.ir.rtypes import RType, is_bool_rprimitive, is_int_rprimitive
# Anything an error message can be attached to when pretty-printing a
# function: a whole basic block or a single op (see format_func/format_blocks).
ErrorSource = Union[BasicBlock, Op]
  57. class IRPrettyPrintVisitor(OpVisitor[str]):
  58. """Internal visitor that pretty-prints ops."""
  59. def __init__(self, names: dict[Value, str]) -> None:
  60. # This should contain a name for all values that are shown as
  61. # registers in the output. This is not just for Register
  62. # instances -- all Ops that produce values need (generated) names.
  63. self.names = names
  64. def visit_goto(self, op: Goto) -> str:
  65. return self.format("goto %l", op.label)
  66. branch_op_names: Final = {Branch.BOOL: ("%r", "bool"), Branch.IS_ERROR: ("is_error(%r)", "")}
  67. def visit_branch(self, op: Branch) -> str:
  68. fmt, typ = self.branch_op_names[op.op]
  69. if op.negated:
  70. fmt = f"not {fmt}"
  71. cond = self.format(fmt, op.value)
  72. tb = ""
  73. if op.traceback_entry:
  74. tb = " (error at %s:%d)" % op.traceback_entry
  75. fmt = f"if {cond} goto %l{tb} else goto %l"
  76. if typ:
  77. fmt += f" :: {typ}"
  78. return self.format(fmt, op.true, op.false)
  79. def visit_return(self, op: Return) -> str:
  80. return self.format("return %r", op.value)
  81. def visit_unreachable(self, op: Unreachable) -> str:
  82. return "unreachable"
  83. def visit_assign(self, op: Assign) -> str:
  84. return self.format("%r = %r", op.dest, op.src)
  85. def visit_assign_multi(self, op: AssignMulti) -> str:
  86. return self.format("%r = [%s]", op.dest, ", ".join(self.format("%r", v) for v in op.src))
  87. def visit_load_error_value(self, op: LoadErrorValue) -> str:
  88. return self.format("%r = <error> :: %s", op, op.type)
  89. def visit_load_literal(self, op: LoadLiteral) -> str:
  90. prefix = ""
  91. # For values that have a potential unboxed representation, make
  92. # it explicit that this is a Python object.
  93. if isinstance(op.value, int):
  94. prefix = "object "
  95. rvalue = repr(op.value)
  96. if isinstance(op.value, frozenset):
  97. # We need to generate a string representation that won't vary
  98. # run-to-run because sets are unordered, otherwise we may get
  99. # spurious irbuild test failures.
  100. #
  101. # Sorting by the item's string representation is a bit of a
  102. # hack, but it's stable and won't cause TypeErrors.
  103. formatted_items = [repr(i) for i in sorted(op.value, key=str)]
  104. rvalue = "frozenset({" + ", ".join(formatted_items) + "})"
  105. return self.format("%r = %s%s", op, prefix, rvalue)
  106. def visit_get_attr(self, op: GetAttr) -> str:
  107. return self.format("%r = %s%r.%s", op, self.borrow_prefix(op), op.obj, op.attr)
  108. def borrow_prefix(self, op: Op) -> str:
  109. if op.is_borrowed:
  110. return "borrow "
  111. return ""
  112. def visit_set_attr(self, op: SetAttr) -> str:
  113. if op.is_init:
  114. assert op.error_kind == ERR_NEVER
  115. if op.error_kind == ERR_NEVER:
  116. # Initialization and direct struct access can never fail
  117. return self.format("%r.%s = %r", op.obj, op.attr, op.src)
  118. else:
  119. return self.format("%r.%s = %r; %r = is_error", op.obj, op.attr, op.src, op)
  120. def visit_load_static(self, op: LoadStatic) -> str:
  121. ann = f" ({repr(op.ann)})" if op.ann else ""
  122. name = op.identifier
  123. if op.module_name is not None:
  124. name = f"{op.module_name}.{name}"
  125. return self.format("%r = %s :: %s%s", op, name, op.namespace, ann)
  126. def visit_init_static(self, op: InitStatic) -> str:
  127. name = op.identifier
  128. if op.module_name is not None:
  129. name = f"{op.module_name}.{name}"
  130. return self.format("%s = %r :: %s", name, op.value, op.namespace)
  131. def visit_tuple_get(self, op: TupleGet) -> str:
  132. return self.format("%r = %r[%d]", op, op.src, op.index)
  133. def visit_tuple_set(self, op: TupleSet) -> str:
  134. item_str = ", ".join(self.format("%r", item) for item in op.items)
  135. return self.format("%r = (%s)", op, item_str)
  136. def visit_inc_ref(self, op: IncRef) -> str:
  137. s = self.format("inc_ref %r", op.src)
  138. # TODO: Remove bool check (it's unboxed)
  139. if is_bool_rprimitive(op.src.type) or is_int_rprimitive(op.src.type):
  140. s += f" :: {short_name(op.src.type.name)}"
  141. return s
  142. def visit_dec_ref(self, op: DecRef) -> str:
  143. s = self.format("%sdec_ref %r", "x" if op.is_xdec else "", op.src)
  144. # TODO: Remove bool check (it's unboxed)
  145. if is_bool_rprimitive(op.src.type) or is_int_rprimitive(op.src.type):
  146. s += f" :: {short_name(op.src.type.name)}"
  147. return s
  148. def visit_call(self, op: Call) -> str:
  149. args = ", ".join(self.format("%r", arg) for arg in op.args)
  150. # TODO: Display long name?
  151. short_name = op.fn.shortname
  152. s = f"{short_name}({args})"
  153. if not op.is_void:
  154. s = self.format("%r = ", op) + s
  155. return s
  156. def visit_method_call(self, op: MethodCall) -> str:
  157. args = ", ".join(self.format("%r", arg) for arg in op.args)
  158. s = self.format("%r.%s(%s)", op.obj, op.method, args)
  159. if not op.is_void:
  160. s = self.format("%r = ", op) + s
  161. return s
  162. def visit_cast(self, op: Cast) -> str:
  163. return self.format("%r = %scast(%s, %r)", op, self.borrow_prefix(op), op.type, op.src)
  164. def visit_box(self, op: Box) -> str:
  165. return self.format("%r = box(%s, %r)", op, op.src.type, op.src)
  166. def visit_unbox(self, op: Unbox) -> str:
  167. return self.format("%r = unbox(%s, %r)", op, op.type, op.src)
  168. def visit_raise_standard_error(self, op: RaiseStandardError) -> str:
  169. if op.value is not None:
  170. if isinstance(op.value, str):
  171. return self.format("%r = raise %s(%s)", op, op.class_name, repr(op.value))
  172. elif isinstance(op.value, Value):
  173. return self.format("%r = raise %s(%r)", op, op.class_name, op.value)
  174. else:
  175. assert False, "value type must be either str or Value"
  176. else:
  177. return self.format("%r = raise %s", op, op.class_name)
  178. def visit_call_c(self, op: CallC) -> str:
  179. args_str = ", ".join(self.format("%r", arg) for arg in op.args)
  180. if op.is_void:
  181. return self.format("%s(%s)", op.function_name, args_str)
  182. else:
  183. return self.format("%r = %s(%s)", op, op.function_name, args_str)
  184. def visit_truncate(self, op: Truncate) -> str:
  185. return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type)
  186. def visit_extend(self, op: Extend) -> str:
  187. if op.signed:
  188. extra = " signed"
  189. else:
  190. extra = ""
  191. return self.format("%r = extend%s %r: %t to %t", op, extra, op.src, op.src_type, op.type)
  192. def visit_load_global(self, op: LoadGlobal) -> str:
  193. ann = f" ({repr(op.ann)})" if op.ann else ""
  194. return self.format("%r = load_global %s :: static%s", op, op.identifier, ann)
  195. def visit_int_op(self, op: IntOp) -> str:
  196. return self.format("%r = %r %s %r", op, op.lhs, IntOp.op_str[op.op], op.rhs)
  197. def visit_comparison_op(self, op: ComparisonOp) -> str:
  198. if op.op in (ComparisonOp.SLT, ComparisonOp.SGT, ComparisonOp.SLE, ComparisonOp.SGE):
  199. sign_format = " :: signed"
  200. elif op.op in (ComparisonOp.ULT, ComparisonOp.UGT, ComparisonOp.ULE, ComparisonOp.UGE):
  201. sign_format = " :: unsigned"
  202. else:
  203. sign_format = ""
  204. return self.format(
  205. "%r = %r %s %r%s", op, op.lhs, ComparisonOp.op_str[op.op], op.rhs, sign_format
  206. )
  207. def visit_float_op(self, op: FloatOp) -> str:
  208. return self.format("%r = %r %s %r", op, op.lhs, FloatOp.op_str[op.op], op.rhs)
  209. def visit_float_neg(self, op: FloatNeg) -> str:
  210. return self.format("%r = -%r", op, op.src)
  211. def visit_float_comparison_op(self, op: FloatComparisonOp) -> str:
  212. return self.format("%r = %r %s %r", op, op.lhs, op.op_str[op.op], op.rhs)
  213. def visit_load_mem(self, op: LoadMem) -> str:
  214. return self.format("%r = load_mem %r :: %t*", op, op.src, op.type)
  215. def visit_set_mem(self, op: SetMem) -> str:
  216. return self.format("set_mem %r, %r :: %t*", op.dest, op.src, op.dest_type)
  217. def visit_get_element_ptr(self, op: GetElementPtr) -> str:
  218. return self.format("%r = get_element_ptr %r %s :: %t", op, op.src, op.field, op.src_type)
  219. def visit_load_address(self, op: LoadAddress) -> str:
  220. if isinstance(op.src, Register):
  221. return self.format("%r = load_address %r", op, op.src)
  222. elif isinstance(op.src, LoadStatic):
  223. name = op.src.identifier
  224. if op.src.module_name is not None:
  225. name = f"{op.src.module_name}.{name}"
  226. return self.format("%r = load_address %s :: %s", op, name, op.src.namespace)
  227. else:
  228. return self.format("%r = load_address %s", op, op.src)
  229. def visit_keep_alive(self, op: KeepAlive) -> str:
  230. return self.format("keep_alive %s" % ", ".join(self.format("%r", v) for v in op.src))
  231. # Helpers
  232. def format(self, fmt: str, *args: Any) -> str:
  233. """Helper for formatting strings.
  234. These format sequences are supported in fmt:
  235. %s: arbitrary object converted to string using str()
  236. %r: name of IR value/register
  237. %d: int
  238. %f: float
  239. %l: BasicBlock (formatted as label 'Ln')
  240. %t: RType
  241. """
  242. result = []
  243. i = 0
  244. arglist = list(args)
  245. while i < len(fmt):
  246. n = fmt.find("%", i)
  247. if n < 0:
  248. n = len(fmt)
  249. result.append(fmt[i:n])
  250. if n < len(fmt):
  251. typespec = fmt[n + 1]
  252. arg = arglist.pop(0)
  253. if typespec == "r":
  254. # Register/value
  255. assert isinstance(arg, Value)
  256. if isinstance(arg, Integer):
  257. result.append(str(arg.value))
  258. elif isinstance(arg, Float):
  259. result.append(repr(arg.value))
  260. else:
  261. result.append(self.names[arg])
  262. elif typespec == "d":
  263. # Integer
  264. result.append("%d" % arg)
  265. elif typespec == "f":
  266. # Float
  267. result.append("%f" % arg)
  268. elif typespec == "l":
  269. # Basic block (label)
  270. assert isinstance(arg, BasicBlock)
  271. result.append("L%s" % arg.label)
  272. elif typespec == "t":
  273. # RType
  274. assert isinstance(arg, RType)
  275. result.append(arg.name)
  276. elif typespec == "s":
  277. # String
  278. result.append(str(arg))
  279. else:
  280. raise ValueError(f"Invalid format sequence %{typespec}")
  281. i = n + 2
  282. else:
  283. i = n
  284. return "".join(result)
  285. def format_registers(func_ir: FuncIR, names: dict[Value, str]) -> list[str]:
  286. result = []
  287. i = 0
  288. regs = all_values_full(func_ir.arg_regs, func_ir.blocks)
  289. while i < len(regs):
  290. i0 = i
  291. group = [names[regs[i0]]]
  292. while i + 1 < len(regs) and regs[i + 1].type == regs[i0].type:
  293. i += 1
  294. group.append(names[regs[i]])
  295. i += 1
  296. result.append("{} :: {}".format(", ".join(group), regs[i0].type))
  297. return result
  298. def format_blocks(
  299. blocks: list[BasicBlock],
  300. names: dict[Value, str],
  301. source_to_error: dict[ErrorSource, list[str]],
  302. ) -> list[str]:
  303. """Format a list of IR basic blocks into a human-readable form."""
  304. # First label all of the blocks
  305. for i, block in enumerate(blocks):
  306. block.label = i
  307. handler_map: dict[BasicBlock, list[BasicBlock]] = {}
  308. for b in blocks:
  309. if b.error_handler:
  310. handler_map.setdefault(b.error_handler, []).append(b)
  311. visitor = IRPrettyPrintVisitor(names)
  312. lines = []
  313. for i, block in enumerate(blocks):
  314. handler_msg = ""
  315. if block in handler_map:
  316. labels = sorted("L%d" % b.label for b in handler_map[block])
  317. handler_msg = " (handler for {})".format(", ".join(labels))
  318. lines.append("L%d:%s" % (block.label, handler_msg))
  319. if block in source_to_error:
  320. for error in source_to_error[block]:
  321. lines.append(f" ERR: {error}")
  322. ops = block.ops
  323. if (
  324. isinstance(ops[-1], Goto)
  325. and i + 1 < len(blocks)
  326. and ops[-1].label == blocks[i + 1]
  327. and not source_to_error.get(ops[-1], [])
  328. ):
  329. # Hide the last goto if it just goes to the next basic block,
  330. # and there are no assocatiated errors with the op.
  331. ops = ops[:-1]
  332. for op in ops:
  333. line = " " + op.accept(visitor)
  334. lines.append(line)
  335. if op in source_to_error:
  336. for error in source_to_error[op]:
  337. lines.append(f" ERR: {error}")
  338. if not isinstance(block.ops[-1], (Goto, Branch, Return, Unreachable)):
  339. # Each basic block needs to exit somewhere.
  340. lines.append(" [MISSING BLOCK EXIT OPCODE]")
  341. return lines
  342. def format_func(fn: FuncIR, errors: Sequence[tuple[ErrorSource, str]] = ()) -> list[str]:
  343. lines = []
  344. cls_prefix = fn.class_name + "." if fn.class_name else ""
  345. lines.append(
  346. "def {}{}({}):".format(cls_prefix, fn.name, ", ".join(arg.name for arg in fn.args))
  347. )
  348. names = generate_names_for_ir(fn.arg_regs, fn.blocks)
  349. for line in format_registers(fn, names):
  350. lines.append(" " + line)
  351. source_to_error = defaultdict(list)
  352. for source, error in errors:
  353. source_to_error[source].append(error)
  354. code = format_blocks(fn.blocks, names, source_to_error)
  355. lines.extend(code)
  356. return lines
  357. def format_modules(modules: ModuleIRs) -> list[str]:
  358. ops = []
  359. for module in modules.values():
  360. for fn in module.functions:
  361. ops.extend(format_func(fn))
  362. ops.append("")
  363. return ops
  364. def generate_names_for_ir(args: list[Register], blocks: list[BasicBlock]) -> dict[Value, str]:
  365. """Generate unique names for IR values.
  366. Give names such as 'r5' to temp values in IR which are useful when
  367. pretty-printing or generating C. Ensure generated names are unique.
  368. """
  369. names: dict[Value, str] = {}
  370. used_names = set()
  371. temp_index = 0
  372. for arg in args:
  373. names[arg] = arg.name
  374. used_names.add(arg.name)
  375. for block in blocks:
  376. for op in block.ops:
  377. values = []
  378. for source in op.sources():
  379. if source not in names:
  380. values.append(source)
  381. if isinstance(op, (Assign, AssignMulti)):
  382. values.append(op.dest)
  383. elif isinstance(op, ControlOp) or op.is_void:
  384. continue
  385. elif op not in names:
  386. values.append(op)
  387. for value in values:
  388. if value in names:
  389. continue
  390. if isinstance(value, Register) and value.name:
  391. name = value.name
  392. elif isinstance(value, (Integer, Float)):
  393. continue
  394. else:
  395. name = "r%d" % temp_index
  396. temp_index += 1
  397. # Append _2, _3, ... if needed to make the name unique.
  398. if name in used_names:
  399. n = 2
  400. while True:
  401. candidate = "%s_%d" % (name, n)
  402. if candidate not in used_names:
  403. name = candidate
  404. break
  405. n += 1
  406. names[value] = name
  407. used_names.add(name)
  408. return names