refcount.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. """Transformation for inserting reference count inc/dec opcodes.
  2. This transformation happens towards the end of compilation. Before this
  3. transformation, reference count management is not explicitly handled at all.
  4. By postponing this pass, the previous passes are simpler as they don't have
  5. to update reference count opcodes.
  6. The approach is to decrement reference counts soon after a value is no
  7. longer live, to quickly free memory (and call __del__ methods), though
  8. there are no strict guarantees -- other than that local variables are
  9. freed before return from a function.
  10. Function arguments are a little special. They are initially considered
  11. 'borrowed' from the caller and their reference counts don't need to be
  12. decremented before returning. An assignment to a borrowed value turns it
  13. into a regular, owned reference that needs to be freed before return.
  14. """
  15. from __future__ import annotations
  16. from typing import Dict, Iterable, Tuple
  17. from mypyc.analysis.dataflow import (
  18. AnalysisDict,
  19. analyze_borrowed_arguments,
  20. analyze_live_regs,
  21. analyze_must_defined_regs,
  22. cleanup_cfg,
  23. get_cfg,
  24. )
  25. from mypyc.ir.func_ir import FuncIR, all_values
  26. from mypyc.ir.ops import (
  27. Assign,
  28. BasicBlock,
  29. Branch,
  30. ControlOp,
  31. DecRef,
  32. Goto,
  33. IncRef,
  34. Integer,
  35. KeepAlive,
  36. LoadAddress,
  37. Op,
  38. Register,
  39. RegisterOp,
  40. Value,
  41. )
# Values to decref in an inserted block, as (value, is_xdec) pairs; the flag
# is forwarded to DecRef as its is_xdec argument.
Decs = Tuple[Tuple[Value, bool], ...]
# Values to incref in an inserted block.
Incs = Tuple[Value, ...]

# A cache of basic blocks that decrement and increment specific values
# and then jump to some target block. This lets us cut down on how
# much code we generate in some circumstances.
# Keyed by (target block, decs, incs).
BlockCache = Dict[Tuple[BasicBlock, Decs, Incs], BasicBlock]
  48. def insert_ref_count_opcodes(ir: FuncIR) -> None:
  49. """Insert reference count inc/dec opcodes to a function.
  50. This is the entry point to this module.
  51. """
  52. cfg = get_cfg(ir.blocks)
  53. values = all_values(ir.arg_regs, ir.blocks)
  54. borrowed = {value for value in values if value.is_borrowed}
  55. args: set[Value] = set(ir.arg_regs)
  56. live = analyze_live_regs(ir.blocks, cfg)
  57. borrow = analyze_borrowed_arguments(ir.blocks, cfg, borrowed)
  58. defined = analyze_must_defined_regs(ir.blocks, cfg, args, values, strict_errors=True)
  59. ordering = make_value_ordering(ir)
  60. cache: BlockCache = {}
  61. for block in ir.blocks.copy():
  62. if isinstance(block.ops[-1], (Branch, Goto)):
  63. insert_branch_inc_and_decrefs(
  64. block,
  65. cache,
  66. ir.blocks,
  67. live.before,
  68. borrow.before,
  69. borrow.after,
  70. defined.after,
  71. ordering,
  72. )
  73. transform_block(block, live.before, live.after, borrow.before, defined.after)
  74. cleanup_cfg(ir.blocks)
  75. def is_maybe_undefined(post_must_defined: set[Value], src: Value) -> bool:
  76. return isinstance(src, Register) and src not in post_must_defined
  77. def maybe_append_dec_ref(
  78. ops: list[Op], dest: Value, defined: AnalysisDict[Value], key: tuple[BasicBlock, int]
  79. ) -> None:
  80. if dest.type.is_refcounted and not isinstance(dest, Integer):
  81. ops.append(DecRef(dest, is_xdec=is_maybe_undefined(defined[key], dest)))
  82. def maybe_append_inc_ref(ops: list[Op], dest: Value) -> None:
  83. if dest.type.is_refcounted:
  84. ops.append(IncRef(dest))
def transform_block(
    block: BasicBlock,
    pre_live: AnalysisDict[Value],
    post_live: AnalysisDict[Value],
    pre_borrow: AnalysisDict[Value],
    post_must_defined: AnalysisDict[Value],
) -> None:
    """Rebuild block.ops with IncRef/DecRef ops inserted around each op.

    A new op list is built from the existing ops and assigned back to
    block.ops. KeepAlive ops are left out of the new list, and nothing is
    inserted after a ControlOp.

    Args:
        block: Basic block whose op list is replaced.
        pre_live: Sets of values keyed by (block, op index). The caller in
            this module passes the "before" results of the live-register
            analysis.
        post_live: Same keying; the caller passes the "after" results of
            the live-register analysis.
        pre_borrow: Same keying; the caller passes the "before" results of
            the borrowed-argument analysis.
        post_must_defined: Same keying; the caller passes the "after"
            results of the must-defined analysis. It is only forwarded to
            maybe_append_dec_ref, which uses it to choose is_xdec.
    """
    old_ops = block.ops
    ops: list[Op] = []
    for i, op in enumerate(old_ops):
        # All analysis results are looked up per op with this key.
        key = (block, i)

        assert op not in pre_live[key]
        # For an Assign the value being defined is op.dest; for any other op
        # it is the op itself.
        dest = op.dest if isinstance(op, Assign) else op
        stolen = op.stolen()

        # Incref any references that are being stolen that stay live, were borrowed,
        # or are stolen more than once by this operation.
        for j, src in enumerate(stolen):
            if src in post_live[key] or src in pre_borrow[key] or src in stolen[:j]:
                maybe_append_inc_ref(ops, src)
                # For assignments to registers that were already live,
                # decref the old value.
                # NOTE(review): the text this was recovered from carried no
                # indentation; this check is placed under the incref branch
                # above -- confirm against the canonical file.
                if dest not in pre_borrow[key] and dest in pre_live[key]:
                    assert isinstance(op, Assign)
                    maybe_append_dec_ref(ops, dest, post_must_defined, key)

        # Strip KeepAlive. Its only purpose is to help with this transform.
        if not isinstance(op, KeepAlive):
            ops.append(op)

        # Control ops don't have any space to insert ops after them, so
        # their inc/decrefs get inserted by insert_branch_inc_and_decrefs.
        if isinstance(op, ControlOp):
            continue

        for src in op.unique_sources():
            # Decrement source that won't be live afterwards.
            # Borrowed and stolen sources are skipped.
            if src not in post_live[key] and src not in pre_borrow[key] and src not in stolen:
                maybe_append_dec_ref(ops, src, post_must_defined, key)
        # Decrement the destination if it is dead after the op and
        # wasn't a borrowed RegisterOp
        if (
            not dest.is_void
            and dest not in post_live[key]
            and not (isinstance(op, RegisterOp) and dest.is_borrowed)
        ):
            maybe_append_dec_ref(ops, dest, post_must_defined, key)
    # Replace the block's op list with the rewritten one.
    block.ops = ops
  129. def insert_branch_inc_and_decrefs(
  130. block: BasicBlock,
  131. cache: BlockCache,
  132. blocks: list[BasicBlock],
  133. pre_live: AnalysisDict[Value],
  134. pre_borrow: AnalysisDict[Value],
  135. post_borrow: AnalysisDict[Value],
  136. post_must_defined: AnalysisDict[Value],
  137. ordering: dict[Value, int],
  138. ) -> None:
  139. """Insert inc_refs and/or dec_refs after a branch/goto.
  140. Add dec_refs for registers that become dead after a branch.
  141. Add inc_refs for registers that become unborrowed after a branch or goto.
  142. Branches are special as the true and false targets may have a different
  143. live and borrowed register sets. Add new blocks before the true/false target
  144. blocks that tweak reference counts.
  145. Example where we need to add an inc_ref:
  146. def f(a: int) -> None
  147. if a:
  148. a = 1
  149. return a # a is borrowed if condition is false and unborrowed if true
  150. """
  151. prev_key = (block, len(block.ops) - 1)
  152. source_live_regs = pre_live[prev_key]
  153. source_borrowed = post_borrow[prev_key]
  154. source_defined = post_must_defined[prev_key]
  155. term = block.terminator
  156. for i, target in enumerate(term.targets()):
  157. # HAX: After we've checked against an error value the value we must not touch the
  158. # refcount since it will be a null pointer. The correct way to do this would be
  159. # to perform data flow analysis on whether a value can be null (or is always
  160. # null).
  161. omitted: Iterable[Value]
  162. if isinstance(term, Branch) and term.op == Branch.IS_ERROR and i == 0:
  163. omitted = (term.value,)
  164. else:
  165. omitted = ()
  166. decs = after_branch_decrefs(
  167. target, pre_live, source_defined, source_borrowed, source_live_regs, ordering, omitted
  168. )
  169. incs = after_branch_increfs(target, pre_live, pre_borrow, source_borrowed, ordering)
  170. term.set_target(i, add_block(decs, incs, cache, blocks, target))
  171. def after_branch_decrefs(
  172. label: BasicBlock,
  173. pre_live: AnalysisDict[Value],
  174. source_defined: set[Value],
  175. source_borrowed: set[Value],
  176. source_live_regs: set[Value],
  177. ordering: dict[Value, int],
  178. omitted: Iterable[Value],
  179. ) -> tuple[tuple[Value, bool], ...]:
  180. target_pre_live = pre_live[label, 0]
  181. decref = source_live_regs - target_pre_live - source_borrowed
  182. if decref:
  183. return tuple(
  184. (reg, is_maybe_undefined(source_defined, reg))
  185. for reg in sorted(decref, key=lambda r: ordering[r])
  186. if reg.type.is_refcounted and reg not in omitted
  187. )
  188. return ()
  189. def after_branch_increfs(
  190. label: BasicBlock,
  191. pre_live: AnalysisDict[Value],
  192. pre_borrow: AnalysisDict[Value],
  193. source_borrowed: set[Value],
  194. ordering: dict[Value, int],
  195. ) -> tuple[Value, ...]:
  196. target_pre_live = pre_live[label, 0]
  197. target_borrowed = pre_borrow[label, 0]
  198. incref = (source_borrowed - target_borrowed) & target_pre_live
  199. if incref:
  200. return tuple(
  201. reg for reg in sorted(incref, key=lambda r: ordering[r]) if reg.type.is_refcounted
  202. )
  203. return ()
  204. def add_block(
  205. decs: Decs, incs: Incs, cache: BlockCache, blocks: list[BasicBlock], label: BasicBlock
  206. ) -> BasicBlock:
  207. if not decs and not incs:
  208. return label
  209. # TODO: be able to share *partial* results
  210. if (label, decs, incs) in cache:
  211. return cache[label, decs, incs]
  212. block = BasicBlock()
  213. blocks.append(block)
  214. block.ops.extend(DecRef(reg, is_xdec=xdec) for reg, xdec in decs)
  215. block.ops.extend(IncRef(reg) for reg in incs)
  216. block.ops.append(Goto(label))
  217. cache[label, decs, incs] = block
  218. return block
  219. def make_value_ordering(ir: FuncIR) -> dict[Value, int]:
  220. """Create a ordering of values that allows them to be sorted.
  221. This omits registers that are only ever read.
  222. """
  223. # TODO: Never initialized values??
  224. result: dict[Value, int] = {}
  225. n = 0
  226. for arg in ir.arg_regs:
  227. result[arg] = n
  228. n += 1
  229. for block in ir.blocks:
  230. for op in block.ops:
  231. if (
  232. isinstance(op, LoadAddress)
  233. and isinstance(op.src, Register)
  234. and op.src not in result
  235. ):
  236. # Taking the address of a register allows initialization.
  237. result[op.src] = n
  238. n += 1
  239. if isinstance(op, Assign):
  240. if op.dest not in result:
  241. result[op.dest] = n
  242. n += 1
  243. elif op not in result:
  244. result[op] = n
  245. n += 1
  246. return result