| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202 |
- from __future__ import annotations
- from mypy.nodes import (
- Block,
- Decorator,
- Expression,
- FuncDef,
- FuncItem,
- Import,
- LambdaExpr,
- MemberExpr,
- MypyFile,
- NameExpr,
- Node,
- SymbolNode,
- Var,
- )
- from mypy.traverser import ExtendedTraverserVisitor
- from mypyc.errors import Errors
class PreBuildVisitor(ExtendedTraverserVisitor):
    """Mypy file AST visitor run before building the IR.

    This collects various things, including:

    * Relationships between nested functions and the functions that
      contain them
    * Non-local variables (free variables)
    * Property setters
    * Decorators of functions
    * Module import groups

    The main IR build pass uses this information.
    """

    def __init__(
        self,
        errors: Errors,
        current_file: MypyFile,
        decorators_to_remove: dict[FuncDef, list[int]],
    ) -> None:
        super().__init__()
        # Dict from a function to symbols defined directly in the
        # function that are used as non-local (free) variables within a
        # nested function.
        self.free_variables: dict[FuncItem, set[SymbolNode]] = {}

        # Intermediate data structure used to find the function where
        # a SymbolNode is declared. Initially this may point to a
        # function nested inside the function with the declaration,
        # but we'll eventually update this to refer to the function
        # with the declaration.
        self.symbols_to_funcs: dict[SymbolNode, FuncItem] = {}

        # Stack representing current function nesting.
        self.funcs: list[FuncItem] = []

        # All property setters encountered so far.
        self.prop_setters: set[FuncDef] = set()

        # A map from any function that contains nested functions to
        # a list of the functions nested directly within it, in the
        # order they were visited.
        self.encapsulating_funcs: dict[FuncItem, list[FuncItem]] = {}

        # Map nested function to its parent/encapsulating function.
        self.nested_funcs: dict[FuncItem, FuncItem] = {}

        # Map function to its non-special decorators.
        self.funcs_to_decorators: dict[FuncDef, list[Expression]] = {}

        # Map function to indices of decorators to remove (supplied by
        # the caller).
        self.decorators_to_remove: dict[FuncDef, list[int]] = decorators_to_remove

        # A mapping of import groups (a series of Import nodes with
        # nothing in between) where each group is keyed by its first
        # import node.
        self.module_import_groups: dict[Import, list[Import]] = {}
        # First Import of the group currently being extended, or None
        # when the previously visited node ended the group.
        self._current_import_group: Import | None = None

        self.errors: Errors = errors
        self.current_file: MypyFile = current_file

    def visit(self, o: Node) -> bool:
        # Generic per-node hook from ExtendedTraverserVisitor (assumed to
        # be invoked before descending into a node; returning True lets
        # traversal continue). Any node that is not an Import ends the
        # current import group.
        if not isinstance(o, Import):
            self._current_import_group = None
        return True

    def visit_block(self, block: Block) -> None:
        # Reset on entry and exit so that imports on either side of a
        # block boundary are never merged into the same group.
        self._current_import_group = None
        super().visit_block(block)
        self._current_import_group = None

    def visit_decorator(self, dec: Decorator) -> None:
        if dec.decorators:
            # Only add the function being decorated if there exist
            # (ordinary) decorators in the decorator list. Certain
            # decorators (such as @property, @abstractmethod) are
            # special cased and removed from this list by
            # mypy. Functions decorated only by special decorators
            # (and property setters) are not treated as decorated
            # functions by the IR builder.
            if isinstance(dec.decorators[0], MemberExpr) and dec.decorators[0].name == "setter":
                # Property setters are not treated as decorated methods.
                self.prop_setters.add(dec.func)
            else:
                decorators_to_store = dec.decorators.copy()
                if dec.func in self.decorators_to_remove:
                    to_remove = self.decorators_to_remove[dec.func]
                    # Delete from the end so that (assuming to_remove is in
                    # ascending order) earlier indices are not shifted.
                    for i in reversed(to_remove):
                        del decorators_to_store[i]
                # If all of the decorators are removed, we shouldn't treat this
                # as a decorated function because there aren't any decorators
                # to apply. Do NOT return early, though: the function body must
                # still be traversed below so that nested functions and free
                # variables inside it are recorded.
                if decorators_to_store:
                    self.funcs_to_decorators[dec.func] = decorators_to_store
        super().visit_decorator(dec)

    def visit_func_def(self, fdef: FuncItem) -> None:
        # TODO: What about overloaded functions?
        self.visit_func(fdef)

    def visit_lambda_expr(self, expr: LambdaExpr) -> None:
        self.visit_func(expr)

    def visit_func(self, func: FuncItem) -> None:
        """Record nesting information for func, then traverse its body."""
        # If there were already functions or lambda expressions
        # defined in the function stack, then note the previous
        # FuncItem as containing a nested function and the current
        # FuncItem as being a nested function.
        if self.funcs:
            parent = self.funcs[-1]
            # Add the new func to the list of nested funcs within the
            # func at top of the func stack.
            self.encapsulating_funcs.setdefault(parent, []).append(func)
            # Add the func at top of the func stack as the parent of
            # new func.
            self.nested_funcs[func] = parent

        self.funcs.append(func)
        super().visit_func(func)
        self.funcs.pop()

    def visit_import(self, imp: Import) -> None:
        """Append imp to the current import group, or start a new group."""
        if self._current_import_group is not None:
            self.module_import_groups[self._current_import_group].append(imp)
        else:
            self.module_import_groups[imp] = [imp]
            self._current_import_group = imp
        super().visit_import(imp)

    def visit_name_expr(self, expr: NameExpr) -> None:
        # Only variables and function definitions can be free variables.
        if isinstance(expr.node, (Var, FuncDef)):
            self.visit_symbol_node(expr.node)

    def visit_var(self, var: Var) -> None:
        self.visit_symbol_node(var)

    def visit_symbol_node(self, symbol: SymbolNode) -> None:
        """Update free-variable bookkeeping for a reference to symbol."""
        if not self.funcs:
            # We are not inside a function and hence do not need to do
            # anything regarding free variables.
            return

        current = self.funcs[-1]
        if symbol in self.symbols_to_funcs:
            orig_func = self.symbols_to_funcs[symbol]
            if self.is_parent(current, orig_func):
                # The function in which the symbol was previously seen is
                # nested within the function currently being visited. Thus
                # the current function is a better candidate to contain the
                # declaration.
                self.symbols_to_funcs[symbol] = current
                # TODO: Remove from the orig_func free_variables set?
                self.free_variables.setdefault(current, set()).add(symbol)
            elif self.is_parent(orig_func, current):
                # The SymbolNode instance has already been visited
                # before in a parent function, thus it's a non-local
                # symbol.
                self.add_free_variable(symbol)
        else:
            # This is the first time the SymbolNode is being
            # visited. We map the SymbolNode to the current FuncDef
            # being visited to note where it was first visited.
            self.symbols_to_funcs[symbol] = current

    def is_parent(self, fitem: FuncItem, child: FuncItem) -> bool:
        """Return True if child is nested within fitem.

        The nesting may be indirect, through several levels of nested
        functions. Walks up the nested_funcs parent chain from child.
        """
        while child in self.nested_funcs:
            child = self.nested_funcs[child]
            if child == fitem:
                return True
        return False

    def add_free_variable(self, symbol: SymbolNode) -> None:
        # Find the function where the symbol was (likely) first declared,
        # and mark it as a non-local symbol within that function.
        func = self.symbols_to_funcs[symbol]
        self.free_variables.setdefault(func, set()).add(symbol)
|