ircheck.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. """Utilities for checking that internal ir is valid and consistent."""
  2. from __future__ import annotations
  3. from mypyc.ir.func_ir import FUNC_STATICMETHOD, FuncIR
  4. from mypyc.ir.ops import (
  5. Assign,
  6. AssignMulti,
  7. BaseAssign,
  8. BasicBlock,
  9. Box,
  10. Branch,
  11. Call,
  12. CallC,
  13. Cast,
  14. ComparisonOp,
  15. ControlOp,
  16. DecRef,
  17. Extend,
  18. FloatComparisonOp,
  19. FloatNeg,
  20. FloatOp,
  21. GetAttr,
  22. GetElementPtr,
  23. Goto,
  24. IncRef,
  25. InitStatic,
  26. Integer,
  27. IntOp,
  28. KeepAlive,
  29. LoadAddress,
  30. LoadErrorValue,
  31. LoadGlobal,
  32. LoadLiteral,
  33. LoadMem,
  34. LoadStatic,
  35. MethodCall,
  36. Op,
  37. OpVisitor,
  38. RaiseStandardError,
  39. Register,
  40. Return,
  41. SetAttr,
  42. SetMem,
  43. Truncate,
  44. TupleGet,
  45. TupleSet,
  46. Unbox,
  47. Unreachable,
  48. Value,
  49. )
  50. from mypyc.ir.pprint import format_func
  51. from mypyc.ir.rtypes import (
  52. RArray,
  53. RInstance,
  54. RPrimitive,
  55. RType,
  56. RUnion,
  57. bytes_rprimitive,
  58. dict_rprimitive,
  59. int_rprimitive,
  60. is_float_rprimitive,
  61. is_object_rprimitive,
  62. list_rprimitive,
  63. range_rprimitive,
  64. set_rprimitive,
  65. str_rprimitive,
  66. tuple_rprimitive,
  67. )
  68. class FnError:
  69. def __init__(self, source: Op | BasicBlock, desc: str) -> None:
  70. self.source = source
  71. self.desc = desc
  72. def __eq__(self, other: object) -> bool:
  73. return (
  74. isinstance(other, FnError) and self.source == other.source and self.desc == other.desc
  75. )
  76. def __repr__(self) -> str:
  77. return f"FnError(source={self.source}, desc={self.desc})"
  78. def check_func_ir(fn: FuncIR) -> list[FnError]:
  79. """Applies validations to a given function ir and returns a list of errors found."""
  80. errors = []
  81. op_set = set()
  82. for block in fn.blocks:
  83. if not block.terminated:
  84. errors.append(
  85. FnError(source=block.ops[-1] if block.ops else block, desc="Block not terminated")
  86. )
  87. for op in block.ops[:-1]:
  88. if isinstance(op, ControlOp):
  89. errors.append(FnError(source=op, desc="Block has operations after control op"))
  90. if op in op_set:
  91. errors.append(FnError(source=op, desc="Func has a duplicate op"))
  92. op_set.add(op)
  93. errors.extend(check_op_sources_valid(fn))
  94. if errors:
  95. return errors
  96. op_checker = OpChecker(fn)
  97. for block in fn.blocks:
  98. for op in block.ops:
  99. op.accept(op_checker)
  100. return op_checker.errors
  101. class IrCheckException(Exception):
  102. pass
  103. def assert_func_ir_valid(fn: FuncIR) -> None:
  104. errors = check_func_ir(fn)
  105. if errors:
  106. raise IrCheckException(
  107. "Internal error: Generated invalid IR: \n"
  108. + "\n".join(format_func(fn, [(e.source, e.desc) for e in errors]))
  109. )
  110. def check_op_sources_valid(fn: FuncIR) -> list[FnError]:
  111. errors = []
  112. valid_ops: set[Op] = set()
  113. valid_registers: set[Register] = set()
  114. for block in fn.blocks:
  115. valid_ops.update(block.ops)
  116. for op in block.ops:
  117. if isinstance(op, BaseAssign):
  118. valid_registers.add(op.dest)
  119. elif isinstance(op, LoadAddress) and isinstance(op.src, Register):
  120. valid_registers.add(op.src)
  121. valid_registers.update(fn.arg_regs)
  122. for block in fn.blocks:
  123. for op in block.ops:
  124. for source in op.sources():
  125. if isinstance(source, Integer):
  126. pass
  127. elif isinstance(source, Op):
  128. if source not in valid_ops:
  129. errors.append(
  130. FnError(
  131. source=op,
  132. desc=f"Invalid op reference to op of type {type(source).__name__}",
  133. )
  134. )
  135. elif isinstance(source, Register):
  136. if source not in valid_registers:
  137. errors.append(
  138. FnError(
  139. source=op, desc=f"Invalid op reference to register {source.name!r}"
  140. )
  141. )
  142. return errors
  143. disjoint_types = {
  144. int_rprimitive.name,
  145. bytes_rprimitive.name,
  146. str_rprimitive.name,
  147. dict_rprimitive.name,
  148. list_rprimitive.name,
  149. set_rprimitive.name,
  150. tuple_rprimitive.name,
  151. range_rprimitive.name,
  152. }
  153. def can_coerce_to(src: RType, dest: RType) -> bool:
  154. """Check if src can be assigned to dest_rtype.
  155. Currently okay to have false positives.
  156. """
  157. if isinstance(dest, RUnion):
  158. return any(can_coerce_to(src, d) for d in dest.items)
  159. if isinstance(dest, RPrimitive):
  160. if isinstance(src, RPrimitive):
  161. # If either src or dest is a disjoint type, then they must both be.
  162. if src.name in disjoint_types and dest.name in disjoint_types:
  163. return src.name == dest.name
  164. return src.size == dest.size
  165. if isinstance(src, RInstance):
  166. return is_object_rprimitive(dest)
  167. if isinstance(src, RUnion):
  168. # IR doesn't have the ability to narrow unions based on
  169. # control flow, so cannot be a strict all() here.
  170. return any(can_coerce_to(s, dest) for s in src.items)
  171. return False
  172. return True
  173. class OpChecker(OpVisitor[None]):
  174. def __init__(self, parent_fn: FuncIR) -> None:
  175. self.parent_fn = parent_fn
  176. self.errors: list[FnError] = []
  177. def fail(self, source: Op, desc: str) -> None:
  178. self.errors.append(FnError(source=source, desc=desc))
  179. def check_control_op_targets(self, op: ControlOp) -> None:
  180. for target in op.targets():
  181. if target not in self.parent_fn.blocks:
  182. self.fail(source=op, desc=f"Invalid control operation target: {target.label}")
  183. def check_type_coercion(self, op: Op, src: RType, dest: RType) -> None:
  184. if not can_coerce_to(src, dest):
  185. self.fail(
  186. source=op, desc=f"Cannot coerce source type {src.name} to dest type {dest.name}"
  187. )
  188. def check_compatibility(self, op: Op, t: RType, s: RType) -> None:
  189. if not can_coerce_to(t, s) or not can_coerce_to(s, t):
  190. self.fail(source=op, desc=f"{t.name} and {s.name} are not compatible")
  191. def expect_float(self, op: Op, v: Value) -> None:
  192. if not is_float_rprimitive(v.type):
  193. self.fail(op, f"Float expected (actual type is {v.type})")
  194. def expect_non_float(self, op: Op, v: Value) -> None:
  195. if is_float_rprimitive(v.type):
  196. self.fail(op, "Float not expected")
  197. def visit_goto(self, op: Goto) -> None:
  198. self.check_control_op_targets(op)
  199. def visit_branch(self, op: Branch) -> None:
  200. self.check_control_op_targets(op)
  201. def visit_return(self, op: Return) -> None:
  202. self.check_type_coercion(op, op.value.type, self.parent_fn.decl.sig.ret_type)
  203. def visit_unreachable(self, op: Unreachable) -> None:
  204. # Unreachables are checked at a higher level since validation
  205. # requires access to the entire basic block.
  206. pass
  207. def visit_assign(self, op: Assign) -> None:
  208. self.check_type_coercion(op, op.src.type, op.dest.type)
  209. def visit_assign_multi(self, op: AssignMulti) -> None:
  210. for src in op.src:
  211. assert isinstance(op.dest.type, RArray)
  212. self.check_type_coercion(op, src.type, op.dest.type.item_type)
  213. def visit_load_error_value(self, op: LoadErrorValue) -> None:
  214. # Currently it is assumed that all types have an error value.
  215. # Once this is fixed we can validate that the rtype here actually
  216. # has an error value.
  217. pass
  218. def check_tuple_items_valid_literals(self, op: LoadLiteral, t: tuple[object, ...]) -> None:
  219. for x in t:
  220. if x is not None and not isinstance(x, (str, bytes, bool, int, float, complex, tuple)):
  221. self.fail(op, f"Invalid type for item of tuple literal: {type(x)})")
  222. if isinstance(x, tuple):
  223. self.check_tuple_items_valid_literals(op, x)
  224. def check_frozenset_items_valid_literals(self, op: LoadLiteral, s: frozenset[object]) -> None:
  225. for x in s:
  226. if x is None or isinstance(x, (str, bytes, bool, int, float, complex)):
  227. pass
  228. elif isinstance(x, tuple):
  229. self.check_tuple_items_valid_literals(op, x)
  230. else:
  231. self.fail(op, f"Invalid type for item of frozenset literal: {type(x)})")
  232. def visit_load_literal(self, op: LoadLiteral) -> None:
  233. expected_type = None
  234. if op.value is None:
  235. expected_type = "builtins.object"
  236. elif isinstance(op.value, int):
  237. expected_type = "builtins.int"
  238. elif isinstance(op.value, str):
  239. expected_type = "builtins.str"
  240. elif isinstance(op.value, bytes):
  241. expected_type = "builtins.bytes"
  242. elif isinstance(op.value, bool):
  243. expected_type = "builtins.object"
  244. elif isinstance(op.value, float):
  245. expected_type = "builtins.float"
  246. elif isinstance(op.value, complex):
  247. expected_type = "builtins.object"
  248. elif isinstance(op.value, tuple):
  249. expected_type = "builtins.tuple"
  250. self.check_tuple_items_valid_literals(op, op.value)
  251. elif isinstance(op.value, frozenset):
  252. # There's no frozenset_rprimitive type since it'd be pretty useless so we just pretend
  253. # it's a set (when it's really a frozenset).
  254. expected_type = "builtins.set"
  255. self.check_frozenset_items_valid_literals(op, op.value)
  256. assert expected_type is not None, "Missed a case for LoadLiteral check"
  257. if op.type.name not in [expected_type, "builtins.object"]:
  258. self.fail(
  259. op,
  260. f"Invalid literal value for type: value has "
  261. f"type {expected_type}, but op has type {op.type.name}",
  262. )
  263. def visit_get_attr(self, op: GetAttr) -> None:
  264. # Nothing to do.
  265. pass
  266. def visit_set_attr(self, op: SetAttr) -> None:
  267. # Nothing to do.
  268. pass
  269. # Static operations cannot be checked at the function level.
  270. def visit_load_static(self, op: LoadStatic) -> None:
  271. pass
  272. def visit_init_static(self, op: InitStatic) -> None:
  273. pass
  274. def visit_tuple_get(self, op: TupleGet) -> None:
  275. # Nothing to do.
  276. pass
  277. def visit_tuple_set(self, op: TupleSet) -> None:
  278. # Nothing to do.
  279. pass
  280. def visit_inc_ref(self, op: IncRef) -> None:
  281. # Nothing to do.
  282. pass
  283. def visit_dec_ref(self, op: DecRef) -> None:
  284. # Nothing to do.
  285. pass
  286. def visit_call(self, op: Call) -> None:
  287. # Length is checked in constructor, and return type is set
  288. # in a way that can't be incorrect
  289. for arg_value, arg_runtime in zip(op.args, op.fn.sig.args):
  290. self.check_type_coercion(op, arg_value.type, arg_runtime.type)
  291. def visit_method_call(self, op: MethodCall) -> None:
  292. # Similar to above, but we must look up method first.
  293. method_decl = op.receiver_type.class_ir.method_decl(op.method)
  294. if method_decl.kind == FUNC_STATICMETHOD:
  295. decl_index = 0
  296. else:
  297. decl_index = 1
  298. if len(op.args) + decl_index != len(method_decl.sig.args):
  299. self.fail(op, "Incorrect number of args for method call.")
  300. # Skip the receiver argument (self)
  301. for arg_value, arg_runtime in zip(op.args, method_decl.sig.args[decl_index:]):
  302. self.check_type_coercion(op, arg_value.type, arg_runtime.type)
  303. def visit_cast(self, op: Cast) -> None:
  304. pass
  305. def visit_box(self, op: Box) -> None:
  306. pass
  307. def visit_unbox(self, op: Unbox) -> None:
  308. pass
  309. def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
  310. pass
  311. def visit_call_c(self, op: CallC) -> None:
  312. pass
  313. def visit_truncate(self, op: Truncate) -> None:
  314. pass
  315. def visit_extend(self, op: Extend) -> None:
  316. pass
  317. def visit_load_global(self, op: LoadGlobal) -> None:
  318. pass
  319. def visit_int_op(self, op: IntOp) -> None:
  320. self.expect_non_float(op, op.lhs)
  321. self.expect_non_float(op, op.rhs)
  322. def visit_comparison_op(self, op: ComparisonOp) -> None:
  323. self.check_compatibility(op, op.lhs.type, op.rhs.type)
  324. self.expect_non_float(op, op.lhs)
  325. self.expect_non_float(op, op.rhs)
  326. def visit_float_op(self, op: FloatOp) -> None:
  327. self.expect_float(op, op.lhs)
  328. self.expect_float(op, op.rhs)
  329. def visit_float_neg(self, op: FloatNeg) -> None:
  330. self.expect_float(op, op.src)
  331. def visit_float_comparison_op(self, op: FloatComparisonOp) -> None:
  332. self.expect_float(op, op.lhs)
  333. self.expect_float(op, op.rhs)
  334. def visit_load_mem(self, op: LoadMem) -> None:
  335. pass
  336. def visit_set_mem(self, op: SetMem) -> None:
  337. pass
  338. def visit_get_element_ptr(self, op: GetElementPtr) -> None:
  339. pass
  340. def visit_load_address(self, op: LoadAddress) -> None:
  341. pass
  342. def visit_keep_alive(self, op: KeepAlive) -> None:
  343. pass