dataflow.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623
  1. """Data-flow analyses."""
  2. from __future__ import annotations
  3. from abc import abstractmethod
  4. from typing import Dict, Generic, Iterable, Iterator, Set, Tuple, TypeVar
  5. from mypyc.ir.func_ir import all_values
  6. from mypyc.ir.ops import (
  7. Assign,
  8. AssignMulti,
  9. BasicBlock,
  10. Box,
  11. Branch,
  12. Call,
  13. CallC,
  14. Cast,
  15. ComparisonOp,
  16. ControlOp,
  17. Extend,
  18. Float,
  19. FloatComparisonOp,
  20. FloatNeg,
  21. FloatOp,
  22. GetAttr,
  23. GetElementPtr,
  24. Goto,
  25. InitStatic,
  26. Integer,
  27. IntOp,
  28. KeepAlive,
  29. LoadAddress,
  30. LoadErrorValue,
  31. LoadGlobal,
  32. LoadLiteral,
  33. LoadMem,
  34. LoadStatic,
  35. MethodCall,
  36. Op,
  37. OpVisitor,
  38. RaiseStandardError,
  39. RegisterOp,
  40. Return,
  41. SetAttr,
  42. SetMem,
  43. Truncate,
  44. TupleGet,
  45. TupleSet,
  46. Unbox,
  47. Unreachable,
  48. Value,
  49. )
  50. class CFG:
  51. """Control-flow graph.
  52. Node 0 is always assumed to be the entry point. There must be a
  53. non-empty set of exits.
  54. """
  55. def __init__(
  56. self,
  57. succ: dict[BasicBlock, list[BasicBlock]],
  58. pred: dict[BasicBlock, list[BasicBlock]],
  59. exits: set[BasicBlock],
  60. ) -> None:
  61. assert exits
  62. self.succ = succ
  63. self.pred = pred
  64. self.exits = exits
  65. def __str__(self) -> str:
  66. lines = []
  67. lines.append("exits: %s" % sorted(self.exits, key=lambda e: int(e.label)))
  68. lines.append("succ: %s" % self.succ)
  69. lines.append("pred: %s" % self.pred)
  70. return "\n".join(lines)
  71. def get_cfg(blocks: list[BasicBlock]) -> CFG:
  72. """Calculate basic block control-flow graph.
  73. The result is a dictionary like this:
  74. basic block index -> (successors blocks, predecesssor blocks)
  75. """
  76. succ_map = {}
  77. pred_map: dict[BasicBlock, list[BasicBlock]] = {}
  78. exits = set()
  79. for block in blocks:
  80. assert not any(
  81. isinstance(op, ControlOp) for op in block.ops[:-1]
  82. ), "Control-flow ops must be at the end of blocks"
  83. succ = list(block.terminator.targets())
  84. if not succ:
  85. exits.add(block)
  86. # Errors can occur anywhere inside a block, which means that
  87. # we can't assume that the entire block has executed before
  88. # jumping to the error handler. In our CFG construction, we
  89. # model this as saying that a block can jump to its error
  90. # handler or the error handlers of any of its normal
  91. # successors (to represent an error before that next block
  92. # completes). This works well for analyses like "must
  93. # defined", where it implies that registers assigned in a
  94. # block may be undefined in its error handler, but is in
  95. # general not a precise representation of reality; any
  96. # analyses that require more fidelity must wait until after
  97. # exception insertion.
  98. for error_point in [block] + succ:
  99. if error_point.error_handler:
  100. succ.append(error_point.error_handler)
  101. succ_map[block] = succ
  102. pred_map[block] = []
  103. for prev, nxt in succ_map.items():
  104. for label in nxt:
  105. pred_map[label].append(prev)
  106. return CFG(succ_map, pred_map, exits)
  107. def get_real_target(label: BasicBlock) -> BasicBlock:
  108. if len(label.ops) == 1 and isinstance(label.ops[-1], Goto):
  109. label = label.ops[-1].label
  110. return label
  111. def cleanup_cfg(blocks: list[BasicBlock]) -> None:
  112. """Cleanup the control flow graph.
  113. This eliminates obviously dead basic blocks and eliminates blocks that contain
  114. nothing but a single jump.
  115. There is a lot more that could be done.
  116. """
  117. changed = True
  118. while changed:
  119. # First collapse any jumps to basic block that only contain a goto
  120. for block in blocks:
  121. for i, tgt in enumerate(block.terminator.targets()):
  122. block.terminator.set_target(i, get_real_target(tgt))
  123. # Then delete any blocks that have no predecessors
  124. changed = False
  125. cfg = get_cfg(blocks)
  126. orig_blocks = blocks.copy()
  127. blocks.clear()
  128. for i, block in enumerate(orig_blocks):
  129. if i == 0 or cfg.pred[block]:
  130. blocks.append(block)
  131. else:
  132. changed = True
  133. T = TypeVar("T")
  134. AnalysisDict = Dict[Tuple[BasicBlock, int], Set[T]]
  135. class AnalysisResult(Generic[T]):
  136. def __init__(self, before: AnalysisDict[T], after: AnalysisDict[T]) -> None:
  137. self.before = before
  138. self.after = after
  139. def __str__(self) -> str:
  140. return f"before: {self.before}\nafter: {self.after}\n"
  141. GenAndKill = Tuple[Set[T], Set[T]]
  142. class BaseAnalysisVisitor(OpVisitor[GenAndKill[T]]):
  143. def visit_goto(self, op: Goto) -> GenAndKill[T]:
  144. return set(), set()
  145. @abstractmethod
  146. def visit_register_op(self, op: RegisterOp) -> GenAndKill[T]:
  147. raise NotImplementedError
  148. @abstractmethod
  149. def visit_assign(self, op: Assign) -> GenAndKill[T]:
  150. raise NotImplementedError
  151. @abstractmethod
  152. def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[T]:
  153. raise NotImplementedError
  154. @abstractmethod
  155. def visit_set_mem(self, op: SetMem) -> GenAndKill[T]:
  156. raise NotImplementedError
  157. def visit_call(self, op: Call) -> GenAndKill[T]:
  158. return self.visit_register_op(op)
  159. def visit_method_call(self, op: MethodCall) -> GenAndKill[T]:
  160. return self.visit_register_op(op)
  161. def visit_load_error_value(self, op: LoadErrorValue) -> GenAndKill[T]:
  162. return self.visit_register_op(op)
  163. def visit_load_literal(self, op: LoadLiteral) -> GenAndKill[T]:
  164. return self.visit_register_op(op)
  165. def visit_get_attr(self, op: GetAttr) -> GenAndKill[T]:
  166. return self.visit_register_op(op)
  167. def visit_set_attr(self, op: SetAttr) -> GenAndKill[T]:
  168. return self.visit_register_op(op)
  169. def visit_load_static(self, op: LoadStatic) -> GenAndKill[T]:
  170. return self.visit_register_op(op)
  171. def visit_init_static(self, op: InitStatic) -> GenAndKill[T]:
  172. return self.visit_register_op(op)
  173. def visit_tuple_get(self, op: TupleGet) -> GenAndKill[T]:
  174. return self.visit_register_op(op)
  175. def visit_tuple_set(self, op: TupleSet) -> GenAndKill[T]:
  176. return self.visit_register_op(op)
  177. def visit_box(self, op: Box) -> GenAndKill[T]:
  178. return self.visit_register_op(op)
  179. def visit_unbox(self, op: Unbox) -> GenAndKill[T]:
  180. return self.visit_register_op(op)
  181. def visit_cast(self, op: Cast) -> GenAndKill[T]:
  182. return self.visit_register_op(op)
  183. def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill[T]:
  184. return self.visit_register_op(op)
  185. def visit_call_c(self, op: CallC) -> GenAndKill[T]:
  186. return self.visit_register_op(op)
  187. def visit_truncate(self, op: Truncate) -> GenAndKill[T]:
  188. return self.visit_register_op(op)
  189. def visit_extend(self, op: Extend) -> GenAndKill[T]:
  190. return self.visit_register_op(op)
  191. def visit_load_global(self, op: LoadGlobal) -> GenAndKill[T]:
  192. return self.visit_register_op(op)
  193. def visit_int_op(self, op: IntOp) -> GenAndKill[T]:
  194. return self.visit_register_op(op)
  195. def visit_float_op(self, op: FloatOp) -> GenAndKill[T]:
  196. return self.visit_register_op(op)
  197. def visit_float_neg(self, op: FloatNeg) -> GenAndKill[T]:
  198. return self.visit_register_op(op)
  199. def visit_comparison_op(self, op: ComparisonOp) -> GenAndKill[T]:
  200. return self.visit_register_op(op)
  201. def visit_float_comparison_op(self, op: FloatComparisonOp) -> GenAndKill[T]:
  202. return self.visit_register_op(op)
  203. def visit_load_mem(self, op: LoadMem) -> GenAndKill[T]:
  204. return self.visit_register_op(op)
  205. def visit_get_element_ptr(self, op: GetElementPtr) -> GenAndKill[T]:
  206. return self.visit_register_op(op)
  207. def visit_load_address(self, op: LoadAddress) -> GenAndKill[T]:
  208. return self.visit_register_op(op)
  209. def visit_keep_alive(self, op: KeepAlive) -> GenAndKill[T]:
  210. return self.visit_register_op(op)
  211. class DefinedVisitor(BaseAnalysisVisitor[Value]):
  212. """Visitor for finding defined registers.
  213. Note that this only deals with registers and not temporaries, on
  214. the assumption that we never access temporaries when they might be
  215. undefined.
  216. If strict_errors is True, then we regard any use of LoadErrorValue
  217. as making a register undefined. Otherwise we only do if
  218. `undefines` is set on the error value.
  219. This lets us only consider the things we care about during
  220. uninitialized variable checking while capturing all possibly
  221. undefined things for refcounting.
  222. """
  223. def __init__(self, strict_errors: bool = False) -> None:
  224. self.strict_errors = strict_errors
  225. def visit_branch(self, op: Branch) -> GenAndKill[Value]:
  226. return set(), set()
  227. def visit_return(self, op: Return) -> GenAndKill[Value]:
  228. return set(), set()
  229. def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]:
  230. return set(), set()
  231. def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]:
  232. return set(), set()
  233. def visit_assign(self, op: Assign) -> GenAndKill[Value]:
  234. # Loading an error value may undefine the register.
  235. if isinstance(op.src, LoadErrorValue) and (op.src.undefines or self.strict_errors):
  236. return set(), {op.dest}
  237. else:
  238. return {op.dest}, set()
  239. def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]:
  240. # Array registers are special and we don't track the definedness of them.
  241. return set(), set()
  242. def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]:
  243. return set(), set()
  244. def analyze_maybe_defined_regs(
  245. blocks: list[BasicBlock], cfg: CFG, initial_defined: set[Value]
  246. ) -> AnalysisResult[Value]:
  247. """Calculate potentially defined registers at each CFG location.
  248. A register is defined if it has a value along some path from the initial location.
  249. """
  250. return run_analysis(
  251. blocks=blocks,
  252. cfg=cfg,
  253. gen_and_kill=DefinedVisitor(),
  254. initial=initial_defined,
  255. backward=False,
  256. kind=MAYBE_ANALYSIS,
  257. )
  258. def analyze_must_defined_regs(
  259. blocks: list[BasicBlock],
  260. cfg: CFG,
  261. initial_defined: set[Value],
  262. regs: Iterable[Value],
  263. strict_errors: bool = False,
  264. ) -> AnalysisResult[Value]:
  265. """Calculate always defined registers at each CFG location.
  266. This analysis can work before exception insertion, since it is a
  267. sound assumption that registers defined in a block might not be
  268. initialized in its error handler.
  269. A register is defined if it has a value along all paths from the
  270. initial location.
  271. """
  272. return run_analysis(
  273. blocks=blocks,
  274. cfg=cfg,
  275. gen_and_kill=DefinedVisitor(strict_errors=strict_errors),
  276. initial=initial_defined,
  277. backward=False,
  278. kind=MUST_ANALYSIS,
  279. universe=set(regs),
  280. )
  281. class BorrowedArgumentsVisitor(BaseAnalysisVisitor[Value]):
  282. def __init__(self, args: set[Value]) -> None:
  283. self.args = args
  284. def visit_branch(self, op: Branch) -> GenAndKill[Value]:
  285. return set(), set()
  286. def visit_return(self, op: Return) -> GenAndKill[Value]:
  287. return set(), set()
  288. def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]:
  289. return set(), set()
  290. def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]:
  291. return set(), set()
  292. def visit_assign(self, op: Assign) -> GenAndKill[Value]:
  293. if op.dest in self.args:
  294. return set(), {op.dest}
  295. return set(), set()
  296. def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]:
  297. return set(), set()
  298. def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]:
  299. return set(), set()
  300. def analyze_borrowed_arguments(
  301. blocks: list[BasicBlock], cfg: CFG, borrowed: set[Value]
  302. ) -> AnalysisResult[Value]:
  303. """Calculate arguments that can use references borrowed from the caller.
  304. When assigning to an argument, it no longer is borrowed.
  305. """
  306. return run_analysis(
  307. blocks=blocks,
  308. cfg=cfg,
  309. gen_and_kill=BorrowedArgumentsVisitor(borrowed),
  310. initial=borrowed,
  311. backward=False,
  312. kind=MUST_ANALYSIS,
  313. universe=borrowed,
  314. )
  315. class UndefinedVisitor(BaseAnalysisVisitor[Value]):
  316. def visit_branch(self, op: Branch) -> GenAndKill[Value]:
  317. return set(), set()
  318. def visit_return(self, op: Return) -> GenAndKill[Value]:
  319. return set(), set()
  320. def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]:
  321. return set(), set()
  322. def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]:
  323. return set(), {op} if not op.is_void else set()
  324. def visit_assign(self, op: Assign) -> GenAndKill[Value]:
  325. return set(), {op.dest}
  326. def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]:
  327. return set(), {op.dest}
  328. def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]:
  329. return set(), set()
  330. def analyze_undefined_regs(
  331. blocks: list[BasicBlock], cfg: CFG, initial_defined: set[Value]
  332. ) -> AnalysisResult[Value]:
  333. """Calculate potentially undefined registers at each CFG location.
  334. A register is undefined if there is some path from initial block
  335. where it has an undefined value.
  336. Function arguments are assumed to be always defined.
  337. """
  338. initial_undefined = set(all_values([], blocks)) - initial_defined
  339. return run_analysis(
  340. blocks=blocks,
  341. cfg=cfg,
  342. gen_and_kill=UndefinedVisitor(),
  343. initial=initial_undefined,
  344. backward=False,
  345. kind=MAYBE_ANALYSIS,
  346. )
  347. def non_trivial_sources(op: Op) -> set[Value]:
  348. result = set()
  349. for source in op.sources():
  350. if not isinstance(source, (Integer, Float)):
  351. result.add(source)
  352. return result
  353. class LivenessVisitor(BaseAnalysisVisitor[Value]):
  354. def visit_branch(self, op: Branch) -> GenAndKill[Value]:
  355. return non_trivial_sources(op), set()
  356. def visit_return(self, op: Return) -> GenAndKill[Value]:
  357. if not isinstance(op.value, (Integer, Float)):
  358. return {op.value}, set()
  359. else:
  360. return set(), set()
  361. def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]:
  362. return set(), set()
  363. def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]:
  364. gen = non_trivial_sources(op)
  365. if not op.is_void:
  366. return gen, {op}
  367. else:
  368. return gen, set()
  369. def visit_assign(self, op: Assign) -> GenAndKill[Value]:
  370. return non_trivial_sources(op), {op.dest}
  371. def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]:
  372. return non_trivial_sources(op), {op.dest}
  373. def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]:
  374. return non_trivial_sources(op), set()
  375. def analyze_live_regs(blocks: list[BasicBlock], cfg: CFG) -> AnalysisResult[Value]:
  376. """Calculate live registers at each CFG location.
  377. A register is live at a location if it can be read along some CFG path starting
  378. from the location.
  379. """
  380. return run_analysis(
  381. blocks=blocks,
  382. cfg=cfg,
  383. gen_and_kill=LivenessVisitor(),
  384. initial=set(),
  385. backward=True,
  386. kind=MAYBE_ANALYSIS,
  387. )
  388. # Analysis kinds
  389. MUST_ANALYSIS = 0
  390. MAYBE_ANALYSIS = 1
  391. def run_analysis(
  392. blocks: list[BasicBlock],
  393. cfg: CFG,
  394. gen_and_kill: OpVisitor[GenAndKill[T]],
  395. initial: set[T],
  396. kind: int,
  397. backward: bool,
  398. universe: set[T] | None = None,
  399. ) -> AnalysisResult[T]:
  400. """Run a general set-based data flow analysis.
  401. Args:
  402. blocks: All basic blocks
  403. cfg: Control-flow graph for the code
  404. gen_and_kill: Implementation of gen and kill functions for each op
  405. initial: Value of analysis for the entry points (for a forward analysis) or the
  406. exit points (for a backward analysis)
  407. kind: MUST_ANALYSIS or MAYBE_ANALYSIS
  408. backward: If False, the analysis is a forward analysis; it's backward otherwise
  409. universe: For a must analysis, the set of all possible values. This is the starting
  410. value for the work list algorithm, which will narrow this down until reaching a
  411. fixed point. For a maybe analysis the iteration always starts from an empty set
  412. and this argument is ignored.
  413. Return analysis results: (before, after)
  414. """
  415. block_gen = {}
  416. block_kill = {}
  417. # Calculate kill and gen sets for entire basic blocks.
  418. for block in blocks:
  419. gen: set[T] = set()
  420. kill: set[T] = set()
  421. ops = block.ops
  422. if backward:
  423. ops = list(reversed(ops))
  424. for op in ops:
  425. opgen, opkill = op.accept(gen_and_kill)
  426. gen = (gen - opkill) | opgen
  427. kill = (kill - opgen) | opkill
  428. block_gen[block] = gen
  429. block_kill[block] = kill
  430. # Set up initial state for worklist algorithm.
  431. worklist = list(blocks)
  432. if not backward:
  433. worklist = worklist[::-1] # Reverse for a small performance improvement
  434. workset = set(worklist)
  435. before: dict[BasicBlock, set[T]] = {}
  436. after: dict[BasicBlock, set[T]] = {}
  437. for block in blocks:
  438. if kind == MAYBE_ANALYSIS:
  439. before[block] = set()
  440. after[block] = set()
  441. else:
  442. assert universe is not None, "Universe must be defined for a must analysis"
  443. before[block] = set(universe)
  444. after[block] = set(universe)
  445. if backward:
  446. pred_map = cfg.succ
  447. succ_map = cfg.pred
  448. else:
  449. pred_map = cfg.pred
  450. succ_map = cfg.succ
  451. # Run work list algorithm to generate in and out sets for each basic block.
  452. while worklist:
  453. label = worklist.pop()
  454. workset.remove(label)
  455. if pred_map[label]:
  456. new_before: set[T] | None = None
  457. for pred in pred_map[label]:
  458. if new_before is None:
  459. new_before = set(after[pred])
  460. elif kind == MAYBE_ANALYSIS:
  461. new_before |= after[pred]
  462. else:
  463. new_before &= after[pred]
  464. assert new_before is not None
  465. else:
  466. new_before = set(initial)
  467. before[label] = new_before
  468. new_after = (new_before - block_kill[label]) | block_gen[label]
  469. if new_after != after[label]:
  470. for succ in succ_map[label]:
  471. if succ not in workset:
  472. worklist.append(succ)
  473. workset.add(succ)
  474. after[label] = new_after
  475. # Run algorithm for each basic block to generate opcode-level sets.
  476. op_before: dict[tuple[BasicBlock, int], set[T]] = {}
  477. op_after: dict[tuple[BasicBlock, int], set[T]] = {}
  478. for block in blocks:
  479. label = block
  480. cur = before[label]
  481. ops_enum: Iterator[tuple[int, Op]] = enumerate(block.ops)
  482. if backward:
  483. ops_enum = reversed(list(ops_enum))
  484. for idx, op in ops_enum:
  485. op_before[label, idx] = cur
  486. opgen, opkill = op.accept(gen_and_kill)
  487. cur = (cur - opkill) | opgen
  488. op_after[label, idx] = cur
  489. if backward:
  490. op_after, op_before = op_before, op_after
  491. return AnalysisResult(op_before, op_after)