# test_ircheck.py (6.7 KB)
  1. from __future__ import annotations
  2. import unittest
  3. from mypyc.analysis.ircheck import FnError, can_coerce_to, check_func_ir
  4. from mypyc.ir.class_ir import ClassIR
  5. from mypyc.ir.func_ir import FuncDecl, FuncIR, FuncSignature
  6. from mypyc.ir.ops import (
  7. Assign,
  8. BasicBlock,
  9. Goto,
  10. Integer,
  11. LoadAddress,
  12. LoadLiteral,
  13. Op,
  14. Register,
  15. Return,
  16. )
  17. from mypyc.ir.pprint import format_func
  18. from mypyc.ir.rtypes import (
  19. RInstance,
  20. RType,
  21. RUnion,
  22. bytes_rprimitive,
  23. int32_rprimitive,
  24. int64_rprimitive,
  25. none_rprimitive,
  26. object_rprimitive,
  27. pointer_rprimitive,
  28. str_rprimitive,
  29. )
  30. def assert_has_error(fn: FuncIR, error: FnError) -> None:
  31. errors = check_func_ir(fn)
  32. assert errors == [error]
  33. def assert_no_errors(fn: FuncIR) -> None:
  34. assert not check_func_ir(fn)
  35. NONE_VALUE = Integer(0, rtype=none_rprimitive)
  36. class TestIrcheck(unittest.TestCase):
  37. def setUp(self) -> None:
  38. self.label = 0
  39. def basic_block(self, ops: list[Op]) -> BasicBlock:
  40. self.label += 1
  41. block = BasicBlock(self.label)
  42. block.ops = ops
  43. return block
  44. def func_decl(self, name: str, ret_type: RType | None = None) -> FuncDecl:
  45. if ret_type is None:
  46. ret_type = none_rprimitive
  47. return FuncDecl(
  48. name=name,
  49. class_name=None,
  50. module_name="module",
  51. sig=FuncSignature(args=[], ret_type=ret_type),
  52. )
  53. def test_valid_fn(self) -> None:
  54. assert_no_errors(
  55. FuncIR(
  56. decl=self.func_decl(name="func_1"),
  57. arg_regs=[],
  58. blocks=[self.basic_block(ops=[Return(value=NONE_VALUE)])],
  59. )
  60. )
  61. def test_block_not_terminated_empty_block(self) -> None:
  62. block = self.basic_block([])
  63. fn = FuncIR(decl=self.func_decl(name="func_1"), arg_regs=[], blocks=[block])
  64. assert_has_error(fn, FnError(source=block, desc="Block not terminated"))
  65. def test_valid_goto(self) -> None:
  66. block_1 = self.basic_block([Return(value=NONE_VALUE)])
  67. block_2 = self.basic_block([Goto(label=block_1)])
  68. fn = FuncIR(decl=self.func_decl(name="func_1"), arg_regs=[], blocks=[block_1, block_2])
  69. assert_no_errors(fn)
  70. def test_invalid_goto(self) -> None:
  71. block_1 = self.basic_block([Return(value=NONE_VALUE)])
  72. goto = Goto(label=block_1)
  73. block_2 = self.basic_block([goto])
  74. fn = FuncIR(
  75. decl=self.func_decl(name="func_1"),
  76. arg_regs=[],
  77. # block_1 omitted
  78. blocks=[block_2],
  79. )
  80. assert_has_error(fn, FnError(source=goto, desc="Invalid control operation target: 1"))
  81. def test_invalid_register_source(self) -> None:
  82. ret = Return(value=Register(type=none_rprimitive, name="r1"))
  83. block = self.basic_block([ret])
  84. fn = FuncIR(decl=self.func_decl(name="func_1"), arg_regs=[], blocks=[block])
  85. assert_has_error(fn, FnError(source=ret, desc="Invalid op reference to register 'r1'"))
  86. def test_invalid_op_source(self) -> None:
  87. ret = Return(value=LoadLiteral(value="foo", rtype=str_rprimitive))
  88. block = self.basic_block([ret])
  89. fn = FuncIR(decl=self.func_decl(name="func_1"), arg_regs=[], blocks=[block])
  90. assert_has_error(
  91. fn, FnError(source=ret, desc="Invalid op reference to op of type LoadLiteral")
  92. )
  93. def test_invalid_return_type(self) -> None:
  94. ret = Return(value=Integer(value=5, rtype=int32_rprimitive))
  95. fn = FuncIR(
  96. decl=self.func_decl(name="func_1", ret_type=int64_rprimitive),
  97. arg_regs=[],
  98. blocks=[self.basic_block([ret])],
  99. )
  100. assert_has_error(
  101. fn, FnError(source=ret, desc="Cannot coerce source type i32 to dest type i64")
  102. )
  103. def test_invalid_assign(self) -> None:
  104. arg_reg = Register(type=int64_rprimitive, name="r1")
  105. assign = Assign(dest=arg_reg, src=Integer(value=5, rtype=int32_rprimitive))
  106. ret = Return(value=NONE_VALUE)
  107. fn = FuncIR(
  108. decl=self.func_decl(name="func_1"),
  109. arg_regs=[arg_reg],
  110. blocks=[self.basic_block([assign, ret])],
  111. )
  112. assert_has_error(
  113. fn, FnError(source=assign, desc="Cannot coerce source type i32 to dest type i64")
  114. )
  115. def test_can_coerce_to(self) -> None:
  116. cls = ClassIR(name="Cls", module_name="cls")
  117. valid_cases = [
  118. (int64_rprimitive, int64_rprimitive),
  119. (str_rprimitive, str_rprimitive),
  120. (str_rprimitive, object_rprimitive),
  121. (object_rprimitive, str_rprimitive),
  122. (RUnion([bytes_rprimitive, str_rprimitive]), str_rprimitive),
  123. (str_rprimitive, RUnion([bytes_rprimitive, str_rprimitive])),
  124. (RInstance(cls), object_rprimitive),
  125. ]
  126. invalid_cases = [
  127. (int64_rprimitive, int32_rprimitive),
  128. (RInstance(cls), str_rprimitive),
  129. (str_rprimitive, bytes_rprimitive),
  130. ]
  131. for src, dest in valid_cases:
  132. assert can_coerce_to(src, dest)
  133. for src, dest in invalid_cases:
  134. assert not can_coerce_to(src, dest)
  135. def test_duplicate_op(self) -> None:
  136. arg_reg = Register(type=int32_rprimitive, name="r1")
  137. assign = Assign(dest=arg_reg, src=Integer(value=5, rtype=int32_rprimitive))
  138. block = self.basic_block([assign, assign, Return(value=NONE_VALUE)])
  139. fn = FuncIR(decl=self.func_decl(name="func_1"), arg_regs=[], blocks=[block])
  140. assert_has_error(fn, FnError(source=assign, desc="Func has a duplicate op"))
  141. def test_pprint(self) -> None:
  142. block_1 = self.basic_block([Return(value=NONE_VALUE)])
  143. goto = Goto(label=block_1)
  144. block_2 = self.basic_block([goto])
  145. fn = FuncIR(
  146. decl=self.func_decl(name="func_1"),
  147. arg_regs=[],
  148. # block_1 omitted
  149. blocks=[block_2],
  150. )
  151. errors = [(goto, "Invalid control operation target: 1")]
  152. formatted = format_func(fn, errors)
  153. assert formatted == [
  154. "def func_1():",
  155. "L0:",
  156. " goto L1",
  157. " ERR: Invalid control operation target: 1",
  158. ]
  159. def test_load_address_declares_register(self) -> None:
  160. rx = Register(str_rprimitive, "x")
  161. ry = Register(pointer_rprimitive, "y")
  162. load_addr = LoadAddress(pointer_rprimitive, rx)
  163. assert_no_errors(
  164. FuncIR(
  165. decl=self.func_decl(name="func_1"),
  166. arg_regs=[],
  167. blocks=[
  168. self.basic_block(
  169. ops=[load_addr, Assign(ry, load_addr), Return(value=NONE_VALUE)]
  170. )
  171. ],
  172. )
  173. )