exceptions.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
"""Transform that inserts error checks after opcodes.

When the IR is initially built, the code doesn't perform error checks
for exceptions. This module is used to insert all required error checks
afterwards. Each Op describes how it indicates an error condition (if
at all).

We need to split basic blocks on each error check since branches can
only be placed at the end of a basic block.
"""
  9. from __future__ import annotations
  10. from mypyc.ir.func_ir import FuncIR
  11. from mypyc.ir.ops import (
  12. ERR_ALWAYS,
  13. ERR_FALSE,
  14. ERR_MAGIC,
  15. ERR_MAGIC_OVERLAPPING,
  16. ERR_NEVER,
  17. NO_TRACEBACK_LINE_NO,
  18. BasicBlock,
  19. Branch,
  20. CallC,
  21. ComparisonOp,
  22. Float,
  23. GetAttr,
  24. Integer,
  25. LoadErrorValue,
  26. Op,
  27. RegisterOp,
  28. Return,
  29. SetAttr,
  30. TupleGet,
  31. Value,
  32. )
  33. from mypyc.ir.rtypes import RTuple, bool_rprimitive, is_float_rprimitive
  34. from mypyc.primitives.exc_ops import err_occurred_op
  35. from mypyc.primitives.registry import CFunctionDescription
  36. def insert_exception_handling(ir: FuncIR) -> None:
  37. # Generate error block if any ops may raise an exception. If an op
  38. # fails without its own error handler, we'll branch to this
  39. # block. The block just returns an error value.
  40. error_label: BasicBlock | None = None
  41. for block in ir.blocks:
  42. adjust_error_kinds(block)
  43. if error_label is None and any(op.can_raise() for op in block.ops):
  44. error_label = add_default_handler_block(ir)
  45. if error_label:
  46. ir.blocks = split_blocks_at_errors(ir.blocks, error_label, ir.traceback_name)
  47. def add_default_handler_block(ir: FuncIR) -> BasicBlock:
  48. block = BasicBlock()
  49. ir.blocks.append(block)
  50. op = LoadErrorValue(ir.ret_type)
  51. block.ops.append(op)
  52. block.ops.append(Return(op))
  53. return block
  54. def split_blocks_at_errors(
  55. blocks: list[BasicBlock], default_error_handler: BasicBlock, func_name: str | None
  56. ) -> list[BasicBlock]:
  57. new_blocks: list[BasicBlock] = []
  58. # First split blocks on ops that may raise.
  59. for block in blocks:
  60. ops = block.ops
  61. block.ops = []
  62. cur_block = block
  63. new_blocks.append(cur_block)
  64. # If the block has an error handler specified, use it. Otherwise
  65. # fall back to the default.
  66. error_label = block.error_handler or default_error_handler
  67. block.error_handler = None
  68. for op in ops:
  69. target: Value = op
  70. cur_block.ops.append(op)
  71. if isinstance(op, RegisterOp) and op.error_kind != ERR_NEVER:
  72. # Split
  73. new_block = BasicBlock()
  74. new_blocks.append(new_block)
  75. if op.error_kind == ERR_MAGIC:
  76. # Op returns an error value on error that depends on result RType.
  77. variant = Branch.IS_ERROR
  78. negated = False
  79. elif op.error_kind == ERR_FALSE:
  80. # Op returns a C false value on error.
  81. variant = Branch.BOOL
  82. negated = True
  83. elif op.error_kind == ERR_ALWAYS:
  84. variant = Branch.BOOL
  85. negated = True
  86. # this is a hack to represent the always fail
  87. # semantics, using a temporary bool with value false
  88. target = Integer(0, bool_rprimitive)
  89. elif op.error_kind == ERR_MAGIC_OVERLAPPING:
  90. comp = insert_overlapping_error_value_check(cur_block.ops, target)
  91. new_block2 = BasicBlock()
  92. new_blocks.append(new_block2)
  93. branch = Branch(
  94. comp,
  95. true_label=new_block2,
  96. false_label=new_block,
  97. op=Branch.BOOL,
  98. rare=True,
  99. )
  100. cur_block.ops.append(branch)
  101. cur_block = new_block2
  102. target = primitive_call(err_occurred_op, [], target.line)
  103. cur_block.ops.append(target)
  104. variant = Branch.IS_ERROR
  105. negated = True
  106. else:
  107. assert False, "unknown error kind %d" % op.error_kind
  108. # Void ops can't generate errors since error is always
  109. # indicated by a special value stored in a register.
  110. if op.error_kind != ERR_ALWAYS:
  111. assert not op.is_void, "void op generating errors?"
  112. branch = Branch(
  113. target, true_label=error_label, false_label=new_block, op=variant, line=op.line
  114. )
  115. branch.negated = negated
  116. if op.line != NO_TRACEBACK_LINE_NO and func_name is not None:
  117. branch.traceback_entry = (func_name, op.line)
  118. cur_block.ops.append(branch)
  119. cur_block = new_block
  120. return new_blocks
  121. def primitive_call(desc: CFunctionDescription, args: list[Value], line: int) -> CallC:
  122. return CallC(
  123. desc.c_function_name,
  124. [],
  125. desc.return_type,
  126. desc.steals,
  127. desc.is_borrowed,
  128. desc.error_kind,
  129. line,
  130. )
  131. def adjust_error_kinds(block: BasicBlock) -> None:
  132. """Infer more precise error_kind attributes for ops.
  133. We have access here to more information than what was available
  134. when the IR was initially built.
  135. """
  136. for op in block.ops:
  137. if isinstance(op, GetAttr):
  138. if op.class_type.class_ir.is_always_defined(op.attr):
  139. op.error_kind = ERR_NEVER
  140. if isinstance(op, SetAttr):
  141. if op.class_type.class_ir.is_always_defined(op.attr):
  142. op.error_kind = ERR_NEVER
  143. def insert_overlapping_error_value_check(ops: list[Op], target: Value) -> ComparisonOp:
  144. """Append to ops to check for an overlapping error value."""
  145. typ = target.type
  146. if isinstance(typ, RTuple):
  147. item = TupleGet(target, 0)
  148. ops.append(item)
  149. return insert_overlapping_error_value_check(ops, item)
  150. else:
  151. errvalue: Value
  152. if is_float_rprimitive(target.type):
  153. errvalue = Float(float(typ.c_undefined))
  154. else:
  155. errvalue = Integer(int(typ.c_undefined), rtype=typ)
  156. op = ComparisonOp(target, errvalue, ComparisonOp.EQ)
  157. ops.append(op)
  158. return op