prebuildvisitor.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. from __future__ import annotations
  2. from mypy.nodes import (
  3. Block,
  4. Decorator,
  5. Expression,
  6. FuncDef,
  7. FuncItem,
  8. Import,
  9. LambdaExpr,
  10. MemberExpr,
  11. MypyFile,
  12. NameExpr,
  13. Node,
  14. SymbolNode,
  15. Var,
  16. )
  17. from mypy.traverser import ExtendedTraverserVisitor
  18. from mypyc.errors import Errors
  19. class PreBuildVisitor(ExtendedTraverserVisitor):
  20. """Mypy file AST visitor run before building the IR.
  21. This collects various things, including:
  22. * Determine relationships between nested functions and functions that
  23. contain nested functions
  24. * Find non-local variables (free variables)
  25. * Find property setters
  26. * Find decorators of functions
  27. * Find module import groups
  28. The main IR build pass uses this information.
  29. """
  30. def __init__(
  31. self,
  32. errors: Errors,
  33. current_file: MypyFile,
  34. decorators_to_remove: dict[FuncDef, list[int]],
  35. ) -> None:
  36. super().__init__()
  37. # Dict from a function to symbols defined directly in the
  38. # function that are used as non-local (free) variables within a
  39. # nested function.
  40. self.free_variables: dict[FuncItem, set[SymbolNode]] = {}
  41. # Intermediate data structure used to find the function where
  42. # a SymbolNode is declared. Initially this may point to a
  43. # function nested inside the function with the declaration,
  44. # but we'll eventually update this to refer to the function
  45. # with the declaration.
  46. self.symbols_to_funcs: dict[SymbolNode, FuncItem] = {}
  47. # Stack representing current function nesting.
  48. self.funcs: list[FuncItem] = []
  49. # All property setters encountered so far.
  50. self.prop_setters: set[FuncDef] = set()
  51. # A map from any function that contains nested functions to
  52. # a set of all the functions that are nested within it.
  53. self.encapsulating_funcs: dict[FuncItem, list[FuncItem]] = {}
  54. # Map nested function to its parent/encapsulating function.
  55. self.nested_funcs: dict[FuncItem, FuncItem] = {}
  56. # Map function to its non-special decorators.
  57. self.funcs_to_decorators: dict[FuncDef, list[Expression]] = {}
  58. # Map function to indices of decorators to remove
  59. self.decorators_to_remove: dict[FuncDef, list[int]] = decorators_to_remove
  60. # A mapping of import groups (a series of Import nodes with
  61. # nothing inbetween) where each group is keyed by its first
  62. # import node.
  63. self.module_import_groups: dict[Import, list[Import]] = {}
  64. self._current_import_group: Import | None = None
  65. self.errors: Errors = errors
  66. self.current_file: MypyFile = current_file
  67. def visit(self, o: Node) -> bool:
  68. if not isinstance(o, Import):
  69. self._current_import_group = None
  70. return True
  71. def visit_block(self, block: Block) -> None:
  72. self._current_import_group = None
  73. super().visit_block(block)
  74. self._current_import_group = None
  75. def visit_decorator(self, dec: Decorator) -> None:
  76. if dec.decorators:
  77. # Only add the function being decorated if there exist
  78. # (ordinary) decorators in the decorator list. Certain
  79. # decorators (such as @property, @abstractmethod) are
  80. # special cased and removed from this list by
  81. # mypy. Functions decorated only by special decorators
  82. # (and property setters) are not treated as decorated
  83. # functions by the IR builder.
  84. if isinstance(dec.decorators[0], MemberExpr) and dec.decorators[0].name == "setter":
  85. # Property setters are not treated as decorated methods.
  86. self.prop_setters.add(dec.func)
  87. else:
  88. decorators_to_store = dec.decorators.copy()
  89. if dec.func in self.decorators_to_remove:
  90. to_remove = self.decorators_to_remove[dec.func]
  91. for i in reversed(to_remove):
  92. del decorators_to_store[i]
  93. # if all of the decorators are removed, we shouldn't treat this as a decorated
  94. # function because there aren't any decorators to apply
  95. if not decorators_to_store:
  96. return
  97. self.funcs_to_decorators[dec.func] = decorators_to_store
  98. super().visit_decorator(dec)
  99. def visit_func_def(self, fdef: FuncItem) -> None:
  100. # TODO: What about overloaded functions?
  101. self.visit_func(fdef)
  102. def visit_lambda_expr(self, expr: LambdaExpr) -> None:
  103. self.visit_func(expr)
  104. def visit_func(self, func: FuncItem) -> None:
  105. # If there were already functions or lambda expressions
  106. # defined in the function stack, then note the previous
  107. # FuncItem as containing a nested function and the current
  108. # FuncItem as being a nested function.
  109. if self.funcs:
  110. # Add the new func to the set of nested funcs within the
  111. # func at top of the func stack.
  112. self.encapsulating_funcs.setdefault(self.funcs[-1], []).append(func)
  113. # Add the func at top of the func stack as the parent of
  114. # new func.
  115. self.nested_funcs[func] = self.funcs[-1]
  116. self.funcs.append(func)
  117. super().visit_func(func)
  118. self.funcs.pop()
  119. def visit_import(self, imp: Import) -> None:
  120. if self._current_import_group is not None:
  121. self.module_import_groups[self._current_import_group].append(imp)
  122. else:
  123. self.module_import_groups[imp] = [imp]
  124. self._current_import_group = imp
  125. super().visit_import(imp)
  126. def visit_name_expr(self, expr: NameExpr) -> None:
  127. if isinstance(expr.node, (Var, FuncDef)):
  128. self.visit_symbol_node(expr.node)
  129. def visit_var(self, var: Var) -> None:
  130. self.visit_symbol_node(var)
  131. def visit_symbol_node(self, symbol: SymbolNode) -> None:
  132. if not self.funcs:
  133. # We are not inside a function and hence do not need to do
  134. # anything regarding free variables.
  135. return
  136. if symbol in self.symbols_to_funcs:
  137. orig_func = self.symbols_to_funcs[symbol]
  138. if self.is_parent(self.funcs[-1], orig_func):
  139. # The function in which the symbol was previously seen is
  140. # nested within the function currently being visited. Thus
  141. # the current function is a better candidate to contain the
  142. # declaration.
  143. self.symbols_to_funcs[symbol] = self.funcs[-1]
  144. # TODO: Remove from the orig_func free_variables set?
  145. self.free_variables.setdefault(self.funcs[-1], set()).add(symbol)
  146. elif self.is_parent(orig_func, self.funcs[-1]):
  147. # The SymbolNode instance has already been visited
  148. # before in a parent function, thus it's a non-local
  149. # symbol.
  150. self.add_free_variable(symbol)
  151. else:
  152. # This is the first time the SymbolNode is being
  153. # visited. We map the SymbolNode to the current FuncDef
  154. # being visited to note where it was first visited.
  155. self.symbols_to_funcs[symbol] = self.funcs[-1]
  156. def is_parent(self, fitem: FuncItem, child: FuncItem) -> bool:
  157. # Check if child is nested within fdef (possibly indirectly
  158. # within multiple nested functions).
  159. if child not in self.nested_funcs:
  160. return False
  161. parent = self.nested_funcs[child]
  162. return parent == fitem or self.is_parent(fitem, parent)
  163. def add_free_variable(self, symbol: SymbolNode) -> None:
  164. # Find the function where the symbol was (likely) first declared,
  165. # and mark is as a non-local symbol within that function.
  166. func = self.symbols_to_funcs[symbol]
  167. self.free_variables.setdefault(func, set()).add(symbol)