| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623 |
- """Data-flow analyses."""
- from __future__ import annotations
- from abc import abstractmethod
- from typing import Dict, Generic, Iterable, Iterator, Set, Tuple, TypeVar
- from mypyc.ir.func_ir import all_values
- from mypyc.ir.ops import (
- Assign,
- AssignMulti,
- BasicBlock,
- Box,
- Branch,
- Call,
- CallC,
- Cast,
- ComparisonOp,
- ControlOp,
- Extend,
- Float,
- FloatComparisonOp,
- FloatNeg,
- FloatOp,
- GetAttr,
- GetElementPtr,
- Goto,
- InitStatic,
- Integer,
- IntOp,
- KeepAlive,
- LoadAddress,
- LoadErrorValue,
- LoadGlobal,
- LoadLiteral,
- LoadMem,
- LoadStatic,
- MethodCall,
- Op,
- OpVisitor,
- RaiseStandardError,
- RegisterOp,
- Return,
- SetAttr,
- SetMem,
- Truncate,
- TupleGet,
- TupleSet,
- Unbox,
- Unreachable,
- Value,
- )
class CFG:
    """Control-flow graph over basic blocks.

    Node 0 is always assumed to be the entry point. There must be a
    non-empty set of exits.

    Attributes:
        succ: Maps each block to the blocks control may continue to.
        pred: Maps each block to the blocks that may continue into it.
        exits: Exit blocks (when built by get_cfg: blocks whose terminator
            has no targets).
    """

    def __init__(
        self,
        succ: dict[BasicBlock, list[BasicBlock]],
        pred: dict[BasicBlock, list[BasicBlock]],
        exits: set[BasicBlock],
    ) -> None:
        # The graph is required to have at least one exit (see class docstring).
        assert exits
        self.succ = succ
        self.pred = pred
        self.exits = exits

    def __str__(self) -> str:
        # Exits live in a set; order them by label so the dump is deterministic.
        ordered_exits = sorted(self.exits, key=lambda blk: int(blk.label))
        sections = [
            f"exits: {ordered_exits}",
            f"succ: {self.succ}",
            f"pred: {self.pred}",
        ]
        return "\n".join(sections)
def get_cfg(blocks: list[BasicBlock]) -> CFG:
    """Calculate the basic block control-flow graph.

    Args:
        blocks: All basic blocks of the function. Every jump target and every
            error handler referenced from these blocks must itself be in
            ``blocks``; otherwise building the predecessor lists raises
            ``KeyError``.

    Returns:
        A ``CFG`` whose ``succ`` and ``pred`` map each block to its successor
        and predecessor blocks. A block's successors are its terminator's
        targets, followed by its own error handler and the error handlers of
        those targets (see the comment in the body). ``exits`` holds the
        blocks whose terminator has no targets.
    """
    succ_map: dict[BasicBlock, list[BasicBlock]] = {}
    pred_map: dict[BasicBlock, list[BasicBlock]] = {}
    exits: set[BasicBlock] = set()
    for block in blocks:
        assert not any(
            isinstance(op, ControlOp) for op in block.ops[:-1]
        ), "Control-flow ops must be at the end of blocks"
        succ = list(block.terminator.targets())
        # Exits are decided by normal control flow only, before error
        # handlers are added as successors below.
        if not succ:
            exits.add(block)
        # Errors can occur anywhere inside a block, which means that
        # we can't assume that the entire block has executed before
        # jumping to the error handler. In our CFG construction, we
        # model this as saying that a block can jump to its error
        # handler or the error handlers of any of its normal
        # successors (to represent an error before that next block
        # completes). This works well for analyses like "must
        # defined", where it implies that registers assigned in a
        # block may be undefined in its error handler, but is in
        # general not a precise representation of reality; any
        # analyses that require more fidelity must wait until after
        # exception insertion.
        # ``[block] + succ`` is a fresh list, so appending to ``succ`` inside
        # the loop does not affect the iteration.
        for error_point in [block] + succ:
            if error_point.error_handler:
                succ.append(error_point.error_handler)
        succ_map[block] = succ
        pred_map[block] = []
    # Invert the successor relation to get predecessors.
    for prev, nxt in succ_map.items():
        for label in nxt:
            pred_map[label].append(prev)
    return CFG(succ_map, pred_map, exits)
def get_real_target(label: BasicBlock) -> BasicBlock:
    """Return where a jump to ``label`` effectively lands.

    If ``label`` consists of exactly one op and that op is a ``Goto``, the
    goto's destination is returned. Only a single hop is followed. Any other
    block is returned unchanged.
    """
    block_ops = label.ops
    if len(block_ops) != 1:
        return label
    sole_op = block_ops[-1]
    return sole_op.label if isinstance(sole_op, Goto) else label
def cleanup_cfg(blocks: list[BasicBlock]) -> None:
    """Simplify the control-flow graph in place.

    Each pass first redirects every jump that lands on a goto-only block one
    hop onward to that goto's destination. It then removes every block,
    other than the entry block at index 0, that is left with no
    predecessors. Passes repeat until one removes nothing.

    There is a lot more that could be done.
    """
    while True:
        # Short-circuit jumps that pass through a block holding only a goto.
        for block in blocks:
            terminator = block.terminator
            for position, target in enumerate(terminator.targets()):
                terminator.set_target(position, get_real_target(target))
        # Drop unreachable blocks; index 0 is the entry and is always kept.
        cfg = get_cfg(blocks)
        survivors = [
            block
            for position, block in enumerate(blocks)
            if position == 0 or cfg.pred[block]
        ]
        pruned_any = len(survivors) != len(blocks)
        # Mutate the caller's list rather than rebinding it.
        blocks[:] = survivors
        if not pruned_any:
            return
# Type of the facts tracked by an analysis; the visitors in this module
# instantiate it with Value.
T = TypeVar("T")
# Per-op analysis results: (block, index of the op in block.ops) -> set of facts.
AnalysisDict = Dict[Tuple[BasicBlock, int], Set[T]]
class AnalysisResult(Generic[T]):
    """Per-op output of a data-flow analysis.

    Attributes:
        before: Maps ``(block, op index)`` to the set of facts holding just
            before that op.
        after: Maps ``(block, op index)`` to the set of facts holding just
            after that op.
    """

    def __init__(self, before: AnalysisDict[T], after: AnalysisDict[T]) -> None:
        self.before, self.after = before, after

    def __str__(self) -> str:
        return f"before: {self.before}\nafter: {self.after}\n"
# (gen, kill) pair of fact sets that an analysis visitor returns for one op.
GenAndKill = Tuple[Set[T], Set[T]]
class BaseAnalysisVisitor(OpVisitor[GenAndKill[T]]):
    """Common base for op visitors that compute ``(gen, kill)`` sets.

    Concrete analyses must provide ``visit_register_op``, ``visit_assign``,
    ``visit_assign_multi`` and ``visit_set_mem``. Every other value-producing
    op kind listed here is forwarded to ``visit_register_op``, so a single
    rule covers all of them. The subclasses in this module additionally
    define ``visit_branch``, ``visit_return`` and ``visit_unreachable``,
    which this base does not provide.
    """

    def visit_goto(self, op: Goto) -> GenAndKill[T]:
        # An unconditional jump contributes nothing to either set.
        return (set(), set())

    # ---- Hooks each concrete analysis must implement ----

    @abstractmethod
    def visit_register_op(self, op: RegisterOp) -> GenAndKill[T]:
        raise NotImplementedError

    @abstractmethod
    def visit_assign(self, op: Assign) -> GenAndKill[T]:
        raise NotImplementedError

    @abstractmethod
    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[T]:
        raise NotImplementedError

    @abstractmethod
    def visit_set_mem(self, op: SetMem) -> GenAndKill[T]:
        raise NotImplementedError

    # ---- Op kinds that all share the visit_register_op rule ----

    def visit_call(self, op: Call) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_method_call(self, op: MethodCall) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_load_error_value(self, op: LoadErrorValue) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_load_literal(self, op: LoadLiteral) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_get_attr(self, op: GetAttr) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_set_attr(self, op: SetAttr) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_load_static(self, op: LoadStatic) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_init_static(self, op: InitStatic) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_tuple_get(self, op: TupleGet) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_tuple_set(self, op: TupleSet) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_box(self, op: Box) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_unbox(self, op: Unbox) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_cast(self, op: Cast) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_call_c(self, op: CallC) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_truncate(self, op: Truncate) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_extend(self, op: Extend) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_load_global(self, op: LoadGlobal) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_int_op(self, op: IntOp) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_float_op(self, op: FloatOp) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_float_neg(self, op: FloatNeg) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_comparison_op(self, op: ComparisonOp) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_float_comparison_op(self, op: FloatComparisonOp) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_load_mem(self, op: LoadMem) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_get_element_ptr(self, op: GetElementPtr) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_load_address(self, op: LoadAddress) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_keep_alive(self, op: KeepAlive) -> GenAndKill[T]:
        return self.visit_register_op(op)
class DefinedVisitor(BaseAnalysisVisitor[Value]):
    """Gen/kill rules for tracking which registers currently hold a value.

    Only destinations of ``Assign`` are tracked; temporaries are ignored on
    the assumption that they are never read while possibly undefined.

    When ``strict_errors`` is true, assigning any ``LoadErrorValue`` makes
    the destination undefined. Otherwise only error values whose
    ``undefines`` flag is set do. This lets uninitialized-variable checking
    look only at what it cares about, while reference counting still sees
    every possibly undefined register.
    """

    def __init__(self, strict_errors: bool = False) -> None:
        self.strict_errors = strict_errors

    # Control ops never define or undefine registers.

    def visit_branch(self, op: Branch) -> GenAndKill[Value]:
        return set(), set()

    def visit_return(self, op: Return) -> GenAndKill[Value]:
        return set(), set()

    def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]:
        return set(), set()

    # Temporaries produced by register ops are not tracked.

    def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]:
        return set(), set()

    def visit_assign(self, op: Assign) -> GenAndKill[Value]:
        source = op.src
        makes_undefined = isinstance(source, LoadErrorValue) and (
            source.undefines or self.strict_errors
        )
        if not makes_undefined:
            # An ordinary assignment defines its destination.
            return {op.dest}, set()
        # Storing an error value leaves the destination undefined.
        return set(), {op.dest}

    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]:
        # Definedness of array registers is deliberately not tracked.
        return set(), set()

    def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]:
        return set(), set()
def analyze_maybe_defined_regs(
    blocks: list[BasicBlock], cfg: CFG, initial_defined: set[Value]
) -> AnalysisResult[Value]:
    """Compute, for every op, the registers that are possibly defined there.

    A register counts as possibly defined if it holds a value along at least
    one path from the initial location. This is a forward "maybe" analysis,
    so predecessor results are unioned.

    Args:
        blocks: All basic blocks.
        cfg: Control-flow graph for ``blocks``.
        initial_defined: Registers considered defined on entry.
    """
    visitor = DefinedVisitor()
    return run_analysis(
        blocks=blocks,
        cfg=cfg,
        gen_and_kill=visitor,
        initial=initial_defined,
        kind=MAYBE_ANALYSIS,
        backward=False,
    )
def analyze_must_defined_regs(
    blocks: list[BasicBlock],
    cfg: CFG,
    initial_defined: set[Value],
    regs: Iterable[Value],
    strict_errors: bool = False,
) -> AnalysisResult[Value]:
    """Compute, for every op, the registers that are definitely defined there.

    A register counts as defined only if it holds a value along every path
    from the initial location. The analysis may run before exception
    insertion: assuming that registers assigned in a block might be
    uninitialized in that block's error handler is sound.

    Args:
        blocks: All basic blocks.
        cfg: Control-flow graph for ``blocks``.
        initial_defined: Registers considered defined on entry.
        regs: Every register that may be tracked. It is the starting point
            that the "must" analysis narrows down.
        strict_errors: Passed through to ``DefinedVisitor``.
    """
    visitor = DefinedVisitor(strict_errors=strict_errors)
    every_reg = set(regs)
    return run_analysis(
        blocks=blocks,
        cfg=cfg,
        gen_and_kill=visitor,
        initial=initial_defined,
        kind=MUST_ANALYSIS,
        backward=False,
        universe=every_reg,
    )
class BorrowedArgumentsVisitor(BaseAnalysisVisitor[Value]):
    """Gen/kill rules for arguments that can keep a reference borrowed from the caller.

    Nothing is ever generated. Assigning to one of ``args`` kills it, because
    the argument then no longer holds the caller's reference.

    Args:
        args: The arguments whose borrowed status is being tracked.
    """

    def __init__(self, args: set[Value]) -> None:
        self.args = args

    def visit_branch(self, op: Branch) -> GenAndKill[Value]:
        return set(), set()

    def visit_return(self, op: Return) -> GenAndKill[Value]:
        return set(), set()

    def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]:
        return set(), set()

    def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]:
        return set(), set()

    def visit_assign(self, op: Assign) -> GenAndKill[Value]:
        target = op.dest
        killed: set[Value] = {target} if target in self.args else set()
        return set(), killed

    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]:
        return set(), set()

    def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]:
        return set(), set()
def analyze_borrowed_arguments(
    blocks: list[BasicBlock], cfg: CFG, borrowed: set[Value]
) -> AnalysisResult[Value]:
    """Compute which arguments may still use references borrowed from the caller.

    An argument stops being borrowed once it is assigned to. This is a
    forward "must" analysis in which ``borrowed`` serves both as the entry
    state and as the universe.

    Args:
        blocks: All basic blocks.
        cfg: Control-flow graph for ``blocks``.
        borrowed: Arguments that start out borrowed.
    """
    visitor = BorrowedArgumentsVisitor(borrowed)
    return run_analysis(
        blocks=blocks,
        cfg=cfg,
        gen_and_kill=visitor,
        initial=borrowed,
        kind=MUST_ANALYSIS,
        backward=False,
        universe=borrowed,
    )
class UndefinedVisitor(BaseAnalysisVisitor[Value]):
    """Gen/kill rules for the "possibly undefined" analysis.

    No op ever generates undefinedness. An op that produces a value, or that
    writes a register, kills that value from the undefined set.
    """

    def visit_branch(self, op: Branch) -> GenAndKill[Value]:
        return set(), set()

    def visit_return(self, op: Return) -> GenAndKill[Value]:
        return set(), set()

    def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]:
        return set(), set()

    def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]:
        if op.is_void:
            # A void op produces no value, so nothing becomes defined.
            return set(), set()
        return set(), {op}

    def visit_assign(self, op: Assign) -> GenAndKill[Value]:
        return set(), {op.dest}

    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]:
        return set(), {op.dest}

    def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]:
        return set(), set()
def analyze_undefined_regs(
    blocks: list[BasicBlock], cfg: CFG, initial_defined: set[Value]
) -> AnalysisResult[Value]:
    """Compute, for every op, the registers that may be undefined there.

    A register counts as possibly undefined if some path from the initial
    block reaches that point without giving it a value. Everything in
    ``initial_defined`` (e.g. function arguments) is treated as defined on
    entry. Every other value found in ``blocks`` starts out undefined.
    """
    every_value = set(all_values([], blocks))
    initial_undefined = every_value - initial_defined
    visitor = UndefinedVisitor()
    return run_analysis(
        blocks=blocks,
        cfg=cfg,
        gen_and_kill=visitor,
        initial=initial_undefined,
        kind=MAYBE_ANALYSIS,
        backward=False,
    )
def non_trivial_sources(op: Op) -> set[Value]:
    """Return the operands of ``op``, excluding ``Integer`` and ``Float`` constants."""
    return {src for src in op.sources() if not isinstance(src, (Integer, Float))}
class LivenessVisitor(BaseAnalysisVisitor[Value]):
    """Gen/kill rules for register liveness.

    Values an op reads are generated (they become live). The value an op
    writes is killed. ``Integer`` and ``Float`` constants are never tracked.
    """

    def visit_branch(self, op: Branch) -> GenAndKill[Value]:
        return non_trivial_sources(op), set()

    def visit_return(self, op: Return) -> GenAndKill[Value]:
        returned = op.value
        if isinstance(returned, (Integer, Float)):
            return set(), set()
        return {returned}, set()

    def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]:
        return set(), set()

    def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]:
        read = non_trivial_sources(op)
        if op.is_void:
            return read, set()
        # The op's own result is written here, so it is dead just above.
        return read, {op}

    def visit_assign(self, op: Assign) -> GenAndKill[Value]:
        return non_trivial_sources(op), {op.dest}

    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]:
        return non_trivial_sources(op), {op.dest}

    def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]:
        return non_trivial_sources(op), set()
def analyze_live_regs(blocks: list[BasicBlock], cfg: CFG) -> AnalysisResult[Value]:
    """Compute the live registers at each CFG location.

    A register is live at a location if some CFG path starting there reads
    it. This is a backward "maybe" analysis in which nothing is live at the
    exit blocks.
    """
    visitor = LivenessVisitor()
    return run_analysis(
        blocks=blocks,
        cfg=cfg,
        gen_and_kill=visitor,
        initial=set(),
        kind=MAYBE_ANALYSIS,
        backward=True,
    )
# Analysis kinds, passed to run_analysis as ``kind``.
# MUST_ANALYSIS: facts from multiple predecessors are intersected, so a fact
#   holds only if it holds on every incoming path.
# MAYBE_ANALYSIS: facts from multiple predecessors are unioned, so a fact
#   holds if it holds on some incoming path.
MUST_ANALYSIS = 0
MAYBE_ANALYSIS = 1
def run_analysis(
    blocks: list[BasicBlock],
    cfg: CFG,
    gen_and_kill: OpVisitor[GenAndKill[T]],
    initial: set[T],
    kind: int,
    backward: bool,
    universe: set[T] | None = None,
) -> AnalysisResult[T]:
    """Run a general set-based data flow analysis.

    Args:
        blocks: All basic blocks
        cfg: Control-flow graph for the code
        gen_and_kill: Implementation of gen and kill functions for each op
        initial: Value of the analysis at blocks that have no predecessors in
            the analysis direction: entry (and unreachable) blocks for a
            forward analysis, exit blocks for a backward analysis
        kind: MUST_ANALYSIS or MAYBE_ANALYSIS. Any value other than
            MAYBE_ANALYSIS is treated as a must analysis.
        backward: If False, the analysis is a forward analysis; it's backward otherwise
        universe: For a must analysis, the set of all possible values. This is the starting
            value for the work list algorithm, which will narrow this down until reaching a
            fixed point. For a maybe analysis the iteration always starts from an empty set
            and this argument is ignored.

    Returns:
        An AnalysisResult whose ``before`` and ``after`` map ``(block, op index)``
        to the set of facts holding immediately before and after that op, in
        program order (for backward analyses too).

    Raises:
        AssertionError: If this is a must analysis and ``universe`` is None.
    """
    block_gen = {}
    block_kill = {}
    # Calculate kill and gen sets for entire basic blocks.
    # Ops are composed in analysis order: forward for forward analyses,
    # reversed for backward ones.
    for block in blocks:
        gen: set[T] = set()
        kill: set[T] = set()
        ops = block.ops
        if backward:
            ops = list(reversed(ops))
        for op in ops:
            opgen, opkill = op.accept(gen_and_kill)
            gen = (gen - opkill) | opgen
            kill = (kill - opgen) | opkill
        block_gen[block] = gen
        block_kill[block] = kill
    # Set up initial state for worklist algorithm.
    # pop() takes from the end, so after this a forward analysis first visits
    # blocks in list order and a backward one in reverse list order.
    worklist = list(blocks)
    if not backward:
        worklist = worklist[::-1]  # Reverse for a small performance improvement
    # Mirrors worklist for O(1) membership tests.
    workset = set(worklist)
    before: dict[BasicBlock, set[T]] = {}
    after: dict[BasicBlock, set[T]] = {}
    for block in blocks:
        if kind == MAYBE_ANALYSIS:
            before[block] = set()
            after[block] = set()
        else:
            assert universe is not None, "Universe must be defined for a must analysis"
            before[block] = set(universe)
            after[block] = set(universe)
    # In a backward analysis the roles of predecessors and successors swap.
    if backward:
        pred_map = cfg.succ
        succ_map = cfg.pred
    else:
        pred_map = cfg.pred
        succ_map = cfg.succ
    # Run work list algorithm to generate in and out sets for each basic block.
    while worklist:
        label = worklist.pop()
        workset.remove(label)
        if pred_map[label]:
            # Meet over predecessors: union for maybe, intersection for must.
            new_before: set[T] | None = None
            for pred in pred_map[label]:
                if new_before is None:
                    new_before = set(after[pred])
                elif kind == MAYBE_ANALYSIS:
                    new_before |= after[pred]
                else:
                    new_before &= after[pred]
            assert new_before is not None
        else:
            # No predecessors in the analysis direction: start from `initial`.
            new_before = set(initial)
        before[label] = new_before
        new_after = (new_before - block_kill[label]) | block_gen[label]
        # Only re-queue successors when this block's output changed.
        if new_after != after[label]:
            for succ in succ_map[label]:
                if succ not in workset:
                    worklist.append(succ)
                    workset.add(succ)
        after[label] = new_after
    # Run algorithm for each basic block to generate opcode-level sets.
    # NOTE(review): gen_and_kill is invoked a second time for every op here;
    # this assumes the visitors are free of side effects.
    op_before: dict[tuple[BasicBlock, int], set[T]] = {}
    op_after: dict[tuple[BasicBlock, int], set[T]] = {}
    for block in blocks:
        label = block
        cur = before[label]
        ops_enum: Iterator[tuple[int, Op]] = enumerate(block.ops)
        if backward:
            ops_enum = reversed(list(ops_enum))
        for idx, op in ops_enum:
            op_before[label, idx] = cur
            opgen, opkill = op.accept(gen_and_kill)
            cur = (cur - opkill) | opgen
            op_after[label, idx] = cur
    # In a backward analysis the sets were recorded in analysis order, which
    # is the reverse of program order. Swap so `before`/`after` follow program order.
    if backward:
        op_after, op_before = op_before, op_after
    return AnalysisResult(op_before, op_after)
|