pprint.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. """Utilities for pretty-printing IR in a human-readable form."""
  2. from __future__ import annotations
  3. from collections import defaultdict
  4. from typing import Any, Sequence, Union
  5. from typing_extensions import Final
  6. from mypyc.common import short_name
  7. from mypyc.ir.func_ir import FuncIR, all_values_full
  8. from mypyc.ir.module_ir import ModuleIRs
  9. from mypyc.ir.ops import (
  10. ERR_NEVER,
  11. Assign,
  12. AssignMulti,
  13. BasicBlock,
  14. Box,
  15. Branch,
  16. Call,
  17. CallC,
  18. Cast,
  19. ComparisonOp,
  20. ControlOp,
  21. DecRef,
  22. Extend,
  23. Float,
  24. FloatComparisonOp,
  25. FloatNeg,
  26. FloatOp,
  27. GetAttr,
  28. GetElementPtr,
  29. Goto,
  30. IncRef,
  31. InitStatic,
  32. Integer,
  33. IntOp,
  34. KeepAlive,
  35. LoadAddress,
  36. LoadErrorValue,
  37. LoadGlobal,
  38. LoadLiteral,
  39. LoadMem,
  40. LoadStatic,
  41. MethodCall,
  42. Op,
  43. OpVisitor,
  44. RaiseStandardError,
  45. Register,
  46. Return,
  47. SetAttr,
  48. SetMem,
  49. Truncate,
  50. TupleGet,
  51. TupleSet,
  52. Unbox,
  53. Unreachable,
  54. Value,
  55. )
  56. from mypyc.ir.rtypes import RType, is_bool_rprimitive, is_int_rprimitive
  57. ErrorSource = Union[BasicBlock, Op]
  58. class IRPrettyPrintVisitor(OpVisitor[str]):
  59. """Internal visitor that pretty-prints ops."""
  60. def __init__(self, names: dict[Value, str]) -> None:
  61. # This should contain a name for all values that are shown as
  62. # registers in the output. This is not just for Register
  63. # instances -- all Ops that produce values need (generated) names.
  64. self.names = names
  65. def visit_goto(self, op: Goto) -> str:
  66. return self.format("goto %l", op.label)
  67. branch_op_names: Final = {Branch.BOOL: ("%r", "bool"), Branch.IS_ERROR: ("is_error(%r)", "")}
  68. def visit_branch(self, op: Branch) -> str:
  69. fmt, typ = self.branch_op_names[op.op]
  70. if op.negated:
  71. fmt = f"not {fmt}"
  72. cond = self.format(fmt, op.value)
  73. tb = ""
  74. if op.traceback_entry:
  75. tb = " (error at %s:%d)" % op.traceback_entry
  76. fmt = f"if {cond} goto %l{tb} else goto %l"
  77. if typ:
  78. fmt += f" :: {typ}"
  79. return self.format(fmt, op.true, op.false)
  80. def visit_return(self, op: Return) -> str:
  81. return self.format("return %r", op.value)
  82. def visit_unreachable(self, op: Unreachable) -> str:
  83. return "unreachable"
  84. def visit_assign(self, op: Assign) -> str:
  85. return self.format("%r = %r", op.dest, op.src)
  86. def visit_assign_multi(self, op: AssignMulti) -> str:
  87. return self.format("%r = [%s]", op.dest, ", ".join(self.format("%r", v) for v in op.src))
  88. def visit_load_error_value(self, op: LoadErrorValue) -> str:
  89. return self.format("%r = <error> :: %s", op, op.type)
  90. def visit_load_literal(self, op: LoadLiteral) -> str:
  91. prefix = ""
  92. # For values that have a potential unboxed representation, make
  93. # it explicit that this is a Python object.
  94. if isinstance(op.value, int):
  95. prefix = "object "
  96. rvalue = repr(op.value)
  97. if isinstance(op.value, frozenset):
  98. # We need to generate a string representation that won't vary
  99. # run-to-run because sets are unordered, otherwise we may get
  100. # spurious irbuild test failures.
  101. #
  102. # Sorting by the item's string representation is a bit of a
  103. # hack, but it's stable and won't cause TypeErrors.
  104. formatted_items = [repr(i) for i in sorted(op.value, key=str)]
  105. rvalue = "frozenset({" + ", ".join(formatted_items) + "})"
  106. return self.format("%r = %s%s", op, prefix, rvalue)
  107. def visit_get_attr(self, op: GetAttr) -> str:
  108. return self.format("%r = %s%r.%s", op, self.borrow_prefix(op), op.obj, op.attr)
  109. def borrow_prefix(self, op: Op) -> str:
  110. if op.is_borrowed:
  111. return "borrow "
  112. return ""
  113. def visit_set_attr(self, op: SetAttr) -> str:
  114. if op.is_init:
  115. assert op.error_kind == ERR_NEVER
  116. if op.error_kind == ERR_NEVER:
  117. # Initialization and direct struct access can never fail
  118. return self.format("%r.%s = %r", op.obj, op.attr, op.src)
  119. else:
  120. return self.format("%r.%s = %r; %r = is_error", op.obj, op.attr, op.src, op)
  121. def visit_load_static(self, op: LoadStatic) -> str:
  122. ann = f" ({repr(op.ann)})" if op.ann else ""
  123. name = op.identifier
  124. if op.module_name is not None:
  125. name = f"{op.module_name}.{name}"
  126. return self.format("%r = %s :: %s%s", op, name, op.namespace, ann)
  127. def visit_init_static(self, op: InitStatic) -> str:
  128. name = op.identifier
  129. if op.module_name is not None:
  130. name = f"{op.module_name}.{name}"
  131. return self.format("%s = %r :: %s", name, op.value, op.namespace)
  132. def visit_tuple_get(self, op: TupleGet) -> str:
  133. return self.format("%r = %r[%d]", op, op.src, op.index)
  134. def visit_tuple_set(self, op: TupleSet) -> str:
  135. item_str = ", ".join(self.format("%r", item) for item in op.items)
  136. return self.format("%r = (%s)", op, item_str)
  137. def visit_inc_ref(self, op: IncRef) -> str:
  138. s = self.format("inc_ref %r", op.src)
  139. # TODO: Remove bool check (it's unboxed)
  140. if is_bool_rprimitive(op.src.type) or is_int_rprimitive(op.src.type):
  141. s += f" :: {short_name(op.src.type.name)}"
  142. return s
  143. def visit_dec_ref(self, op: DecRef) -> str:
  144. s = self.format("%sdec_ref %r", "x" if op.is_xdec else "", op.src)
  145. # TODO: Remove bool check (it's unboxed)
  146. if is_bool_rprimitive(op.src.type) or is_int_rprimitive(op.src.type):
  147. s += f" :: {short_name(op.src.type.name)}"
  148. return s
  149. def visit_call(self, op: Call) -> str:
  150. args = ", ".join(self.format("%r", arg) for arg in op.args)
  151. # TODO: Display long name?
  152. short_name = op.fn.shortname
  153. s = f"{short_name}({args})"
  154. if not op.is_void:
  155. s = self.format("%r = ", op) + s
  156. return s
  157. def visit_method_call(self, op: MethodCall) -> str:
  158. args = ", ".join(self.format("%r", arg) for arg in op.args)
  159. s = self.format("%r.%s(%s)", op.obj, op.method, args)
  160. if not op.is_void:
  161. s = self.format("%r = ", op) + s
  162. return s
  163. def visit_cast(self, op: Cast) -> str:
  164. return self.format("%r = %scast(%s, %r)", op, self.borrow_prefix(op), op.type, op.src)
  165. def visit_box(self, op: Box) -> str:
  166. return self.format("%r = box(%s, %r)", op, op.src.type, op.src)
  167. def visit_unbox(self, op: Unbox) -> str:
  168. return self.format("%r = unbox(%s, %r)", op, op.type, op.src)
  169. def visit_raise_standard_error(self, op: RaiseStandardError) -> str:
  170. if op.value is not None:
  171. if isinstance(op.value, str):
  172. return self.format("%r = raise %s(%s)", op, op.class_name, repr(op.value))
  173. elif isinstance(op.value, Value):
  174. return self.format("%r = raise %s(%r)", op, op.class_name, op.value)
  175. else:
  176. assert False, "value type must be either str or Value"
  177. else:
  178. return self.format("%r = raise %s", op, op.class_name)
  179. def visit_call_c(self, op: CallC) -> str:
  180. args_str = ", ".join(self.format("%r", arg) for arg in op.args)
  181. if op.is_void:
  182. return self.format("%s(%s)", op.function_name, args_str)
  183. else:
  184. return self.format("%r = %s(%s)", op, op.function_name, args_str)
  185. def visit_truncate(self, op: Truncate) -> str:
  186. return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type)
  187. def visit_extend(self, op: Extend) -> str:
  188. if op.signed:
  189. extra = " signed"
  190. else:
  191. extra = ""
  192. return self.format("%r = extend%s %r: %t to %t", op, extra, op.src, op.src_type, op.type)
  193. def visit_load_global(self, op: LoadGlobal) -> str:
  194. ann = f" ({repr(op.ann)})" if op.ann else ""
  195. return self.format("%r = load_global %s :: static%s", op, op.identifier, ann)
  196. def visit_int_op(self, op: IntOp) -> str:
  197. return self.format("%r = %r %s %r", op, op.lhs, IntOp.op_str[op.op], op.rhs)
  198. def visit_comparison_op(self, op: ComparisonOp) -> str:
  199. if op.op in (ComparisonOp.SLT, ComparisonOp.SGT, ComparisonOp.SLE, ComparisonOp.SGE):
  200. sign_format = " :: signed"
  201. elif op.op in (ComparisonOp.ULT, ComparisonOp.UGT, ComparisonOp.ULE, ComparisonOp.UGE):
  202. sign_format = " :: unsigned"
  203. else:
  204. sign_format = ""
  205. return self.format(
  206. "%r = %r %s %r%s", op, op.lhs, ComparisonOp.op_str[op.op], op.rhs, sign_format
  207. )
  208. def visit_float_op(self, op: FloatOp) -> str:
  209. return self.format("%r = %r %s %r", op, op.lhs, FloatOp.op_str[op.op], op.rhs)
  210. def visit_float_neg(self, op: FloatNeg) -> str:
  211. return self.format("%r = -%r", op, op.src)
  212. def visit_float_comparison_op(self, op: FloatComparisonOp) -> str:
  213. return self.format("%r = %r %s %r", op, op.lhs, op.op_str[op.op], op.rhs)
  214. def visit_load_mem(self, op: LoadMem) -> str:
  215. return self.format("%r = load_mem %r :: %t*", op, op.src, op.type)
  216. def visit_set_mem(self, op: SetMem) -> str:
  217. return self.format("set_mem %r, %r :: %t*", op.dest, op.src, op.dest_type)
  218. def visit_get_element_ptr(self, op: GetElementPtr) -> str:
  219. return self.format("%r = get_element_ptr %r %s :: %t", op, op.src, op.field, op.src_type)
  220. def visit_load_address(self, op: LoadAddress) -> str:
  221. if isinstance(op.src, Register):
  222. return self.format("%r = load_address %r", op, op.src)
  223. elif isinstance(op.src, LoadStatic):
  224. name = op.src.identifier
  225. if op.src.module_name is not None:
  226. name = f"{op.src.module_name}.{name}"
  227. return self.format("%r = load_address %s :: %s", op, name, op.src.namespace)
  228. else:
  229. return self.format("%r = load_address %s", op, op.src)
  230. def visit_keep_alive(self, op: KeepAlive) -> str:
  231. return self.format("keep_alive %s" % ", ".join(self.format("%r", v) for v in op.src))
  232. # Helpers
  233. def format(self, fmt: str, *args: Any) -> str:
  234. """Helper for formatting strings.
  235. These format sequences are supported in fmt:
  236. %s: arbitrary object converted to string using str()
  237. %r: name of IR value/register
  238. %d: int
  239. %f: float
  240. %l: BasicBlock (formatted as label 'Ln')
  241. %t: RType
  242. """
  243. result = []
  244. i = 0
  245. arglist = list(args)
  246. while i < len(fmt):
  247. n = fmt.find("%", i)
  248. if n < 0:
  249. n = len(fmt)
  250. result.append(fmt[i:n])
  251. if n < len(fmt):
  252. typespec = fmt[n + 1]
  253. arg = arglist.pop(0)
  254. if typespec == "r":
  255. # Register/value
  256. assert isinstance(arg, Value)
  257. if isinstance(arg, Integer):
  258. result.append(str(arg.value))
  259. elif isinstance(arg, Float):
  260. result.append(repr(arg.value))
  261. else:
  262. result.append(self.names[arg])
  263. elif typespec == "d":
  264. # Integer
  265. result.append("%d" % arg)
  266. elif typespec == "f":
  267. # Float
  268. result.append("%f" % arg)
  269. elif typespec == "l":
  270. # Basic block (label)
  271. assert isinstance(arg, BasicBlock)
  272. result.append("L%s" % arg.label)
  273. elif typespec == "t":
  274. # RType
  275. assert isinstance(arg, RType)
  276. result.append(arg.name)
  277. elif typespec == "s":
  278. # String
  279. result.append(str(arg))
  280. else:
  281. raise ValueError(f"Invalid format sequence %{typespec}")
  282. i = n + 2
  283. else:
  284. i = n
  285. return "".join(result)
  286. def format_registers(func_ir: FuncIR, names: dict[Value, str]) -> list[str]:
  287. result = []
  288. i = 0
  289. regs = all_values_full(func_ir.arg_regs, func_ir.blocks)
  290. while i < len(regs):
  291. i0 = i
  292. group = [names[regs[i0]]]
  293. while i + 1 < len(regs) and regs[i + 1].type == regs[i0].type:
  294. i += 1
  295. group.append(names[regs[i]])
  296. i += 1
  297. result.append("{} :: {}".format(", ".join(group), regs[i0].type))
  298. return result
  299. def format_blocks(
  300. blocks: list[BasicBlock],
  301. names: dict[Value, str],
  302. source_to_error: dict[ErrorSource, list[str]],
  303. ) -> list[str]:
  304. """Format a list of IR basic blocks into a human-readable form."""
  305. # First label all of the blocks
  306. for i, block in enumerate(blocks):
  307. block.label = i
  308. handler_map: dict[BasicBlock, list[BasicBlock]] = {}
  309. for b in blocks:
  310. if b.error_handler:
  311. handler_map.setdefault(b.error_handler, []).append(b)
  312. visitor = IRPrettyPrintVisitor(names)
  313. lines = []
  314. for i, block in enumerate(blocks):
  315. handler_msg = ""
  316. if block in handler_map:
  317. labels = sorted("L%d" % b.label for b in handler_map[block])
  318. handler_msg = " (handler for {})".format(", ".join(labels))
  319. lines.append("L%d:%s" % (block.label, handler_msg))
  320. if block in source_to_error:
  321. for error in source_to_error[block]:
  322. lines.append(f" ERR: {error}")
  323. ops = block.ops
  324. if (
  325. isinstance(ops[-1], Goto)
  326. and i + 1 < len(blocks)
  327. and ops[-1].label == blocks[i + 1]
  328. and not source_to_error.get(ops[-1], [])
  329. ):
  330. # Hide the last goto if it just goes to the next basic block,
  331. # and there are no assocatiated errors with the op.
  332. ops = ops[:-1]
  333. for op in ops:
  334. line = " " + op.accept(visitor)
  335. lines.append(line)
  336. if op in source_to_error:
  337. for error in source_to_error[op]:
  338. lines.append(f" ERR: {error}")
  339. if not isinstance(block.ops[-1], (Goto, Branch, Return, Unreachable)):
  340. # Each basic block needs to exit somewhere.
  341. lines.append(" [MISSING BLOCK EXIT OPCODE]")
  342. return lines
  343. def format_func(fn: FuncIR, errors: Sequence[tuple[ErrorSource, str]] = ()) -> list[str]:
  344. lines = []
  345. cls_prefix = fn.class_name + "." if fn.class_name else ""
  346. lines.append(
  347. "def {}{}({}):".format(cls_prefix, fn.name, ", ".join(arg.name for arg in fn.args))
  348. )
  349. names = generate_names_for_ir(fn.arg_regs, fn.blocks)
  350. for line in format_registers(fn, names):
  351. lines.append(" " + line)
  352. source_to_error = defaultdict(list)
  353. for source, error in errors:
  354. source_to_error[source].append(error)
  355. code = format_blocks(fn.blocks, names, source_to_error)
  356. lines.extend(code)
  357. return lines
  358. def format_modules(modules: ModuleIRs) -> list[str]:
  359. ops = []
  360. for module in modules.values():
  361. for fn in module.functions:
  362. ops.extend(format_func(fn))
  363. ops.append("")
  364. return ops
  365. def generate_names_for_ir(args: list[Register], blocks: list[BasicBlock]) -> dict[Value, str]:
  366. """Generate unique names for IR values.
  367. Give names such as 'r5' to temp values in IR which are useful when
  368. pretty-printing or generating C. Ensure generated names are unique.
  369. """
  370. names: dict[Value, str] = {}
  371. used_names = set()
  372. temp_index = 0
  373. for arg in args:
  374. names[arg] = arg.name
  375. used_names.add(arg.name)
  376. for block in blocks:
  377. for op in block.ops:
  378. values = []
  379. for source in op.sources():
  380. if source not in names:
  381. values.append(source)
  382. if isinstance(op, (Assign, AssignMulti)):
  383. values.append(op.dest)
  384. elif isinstance(op, ControlOp) or op.is_void:
  385. continue
  386. elif op not in names:
  387. values.append(op)
  388. for value in values:
  389. if value in names:
  390. continue
  391. if isinstance(value, Register) and value.name:
  392. name = value.name
  393. elif isinstance(value, (Integer, Float)):
  394. continue
  395. else:
  396. name = "r%d" % temp_index
  397. temp_index += 1
  398. # Append _2, _3, ... if needed to make the name unique.
  399. if name in used_names:
  400. n = 2
  401. while True:
  402. candidate = "%s_%d" % (name, n)
  403. if candidate not in used_names:
  404. name = candidate
  405. break
  406. n += 1
  407. names[value] = name
  408. used_names.add(name)
  409. return names