treetransform.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800
  1. """Base visitor that implements an identity AST transform.
  2. Subclass TransformVisitor to perform non-trivial transformations.
  3. """
  4. from __future__ import annotations
  5. from typing import Iterable, Optional, cast
  6. from mypy.nodes import (
  7. GDEF,
  8. REVEAL_TYPE,
  9. Argument,
  10. AssertStmt,
  11. AssertTypeExpr,
  12. AssignmentExpr,
  13. AssignmentStmt,
  14. AwaitExpr,
  15. Block,
  16. BreakStmt,
  17. BytesExpr,
  18. CallExpr,
  19. CastExpr,
  20. ClassDef,
  21. ComparisonExpr,
  22. ComplexExpr,
  23. ConditionalExpr,
  24. ContinueStmt,
  25. Decorator,
  26. DelStmt,
  27. DictExpr,
  28. DictionaryComprehension,
  29. EllipsisExpr,
  30. EnumCallExpr,
  31. Expression,
  32. ExpressionStmt,
  33. FloatExpr,
  34. ForStmt,
  35. FuncDef,
  36. FuncItem,
  37. GeneratorExpr,
  38. GlobalDecl,
  39. IfStmt,
  40. Import,
  41. ImportAll,
  42. ImportFrom,
  43. IndexExpr,
  44. IntExpr,
  45. LambdaExpr,
  46. ListComprehension,
  47. ListExpr,
  48. MatchStmt,
  49. MemberExpr,
  50. MypyFile,
  51. NamedTupleExpr,
  52. NameExpr,
  53. NewTypeExpr,
  54. Node,
  55. NonlocalDecl,
  56. OperatorAssignmentStmt,
  57. OpExpr,
  58. OverloadedFuncDef,
  59. OverloadPart,
  60. ParamSpecExpr,
  61. PassStmt,
  62. PromoteExpr,
  63. RaiseStmt,
  64. RefExpr,
  65. ReturnStmt,
  66. RevealExpr,
  67. SetComprehension,
  68. SetExpr,
  69. SliceExpr,
  70. StarExpr,
  71. Statement,
  72. StrExpr,
  73. SuperExpr,
  74. SymbolTable,
  75. TempNode,
  76. TryStmt,
  77. TupleExpr,
  78. TypeAliasExpr,
  79. TypeApplication,
  80. TypedDictExpr,
  81. TypeVarExpr,
  82. TypeVarTupleExpr,
  83. UnaryExpr,
  84. Var,
  85. WhileStmt,
  86. WithStmt,
  87. YieldExpr,
  88. YieldFromExpr,
  89. )
  90. from mypy.patterns import (
  91. AsPattern,
  92. ClassPattern,
  93. MappingPattern,
  94. OrPattern,
  95. Pattern,
  96. SequencePattern,
  97. SingletonPattern,
  98. StarredPattern,
  99. ValuePattern,
  100. )
  101. from mypy.traverser import TraverserVisitor
  102. from mypy.types import FunctionLike, ProperType, Type
  103. from mypy.util import replace_object_state
  104. from mypy.visitor import NodeVisitor
  class TransformVisitor(NodeVisitor[Node]):
      """Transform a semantically analyzed AST (or subtree) to an identical copy.

      Use the node() method to transform an AST node.

      Subclass to perform a non-identity transform.

      Notes:

       * This can only be used to transform functions or classes, not top-level
         statements and/or modules as a whole (visit_mypy_file() asserts that the
         test_only flag is set).
       * TypeInfo nodes are not duplicated; copies share the original TypeInfo
         objects. Duplicating them would generally not be desirable.
       * Only name-binding cross-references that target Var, Decorator or FuncDef
         nodes are updated (see copy_ref()). References to global (GDEF)
         variables and to ClassDef or TypeInfo nodes keep pointing at the
         original nodes.
       * Types are not transformed, but you can override type() to also perform
         type transformation.

      TODO nested classes and functions have not been tested well enough
      """

      def __init__(self) -> None:
          # To simplify testing, set this flag to True if you want to transform
          # all statements in a file (this is prohibited in normal mode).
          self.test_only = False
          # There may be multiple references to a Var node. Keep track of Var
          # translations using a dictionary so every reference to the same
          # original Var is mapped to the same copy.
          self.var_map: dict[Var, Var] = {}
          # These are uninitialized placeholder nodes used temporarily for nested
          # functions while we are transforming a top-level function. This maps an
          # untransformed node to a placeholder (which will later become the
          # transformed node).
          self.func_placeholder_map: dict[FuncDef, FuncDef] = {}

      def visit_mypy_file(self, node: MypyFile) -> MypyFile:
          """Copy a whole module; only permitted when ``test_only`` is set."""
          assert self.test_only, "This visitor should not be used for whole files."
          # NOTE: The 'names' and 'imports' instance variables will be empty!
          ignored_lines = {line: codes.copy() for line, codes in node.ignored_lines.items()}
          new = MypyFile(self.statements(node.defs), [], node.is_bom, ignored_lines=ignored_lines)
          new._fullname = node._fullname
          new.path = node.path
          new.names = SymbolTable()
          return new

      def visit_import(self, node: Import) -> Import:
          return Import(node.ids.copy())

      def visit_import_from(self, node: ImportFrom) -> ImportFrom:
          return ImportFrom(node.id, node.relative, node.names.copy())

      def visit_import_all(self, node: ImportAll) -> ImportAll:
          return ImportAll(node.id, node.relative)

      def copy_argument(self, argument: Argument) -> Argument:
          """Copy a function argument, mapping its variable through visit_var().

          The type annotation and the initializer expression are shared with the
          original argument; they are not transformed.
          """
          arg = Argument(
              self.visit_var(argument.variable),
              argument.type_annotation,
              argument.initializer,
              argument.kind,
          )
          # Refresh lines of the inner things
          arg.set_line(argument)
          return arg

      def visit_func_def(self, node: FuncDef) -> FuncDef:
          """Copy a function definition (a FuncDef must become a FuncDef).

          If a placeholder was registered for *node* (because an enclosing
          function being transformed refers to it), the placeholder object is
          filled in with the copy's state and returned, so that references
          already pointing at the placeholder see the transformed function.
          """
          # These contortions are needed to handle the case of recursive
          # references inside the function being transformed.
          #
          # Set up placeholder nodes for references within this function
          # to other functions defined inside it.
          # Don't create an entry for this function itself though,
          # since we want self-references to point to the original
          # function if this is the top-level node we are transforming.
          init = FuncMapInitializer(self)
          for stmt in node.body.body:
              stmt.accept(init)

          new = FuncDef(
              node.name,
              [self.copy_argument(arg) for arg in node.arguments],
              self.block(node.body),
              cast(Optional[FunctionLike], self.optional_type(node.type)),
          )

          self.copy_function_attributes(new, node)

          new._fullname = node._fullname
          new.is_decorated = node.is_decorated
          new.is_conditional = node.is_conditional
          new.abstract_status = node.abstract_status
          new.is_static = node.is_static
          new.is_class = node.is_class
          new.is_property = node.is_property
          new.is_final = node.is_final
          new.original_def = node.original_def

          if node in self.func_placeholder_map:
              # There is a placeholder definition for this function. Replace
              # the attributes of the placeholder with those from the transformed
              # function. We know that the classes will be identical (otherwise
              # this wouldn't work).
              result = self.func_placeholder_map[node]
              replace_object_state(result, new)
              return result
          return new

      def visit_lambda_expr(self, node: LambdaExpr) -> LambdaExpr:
          new = LambdaExpr(
              [self.copy_argument(arg) for arg in node.arguments],
              self.block(node.body),
              cast(Optional[FunctionLike], self.optional_type(node.type)),
          )
          self.copy_function_attributes(new, node)
          return new

      def copy_function_attributes(self, new: FuncItem, original: FuncItem) -> None:
          """Copy the FuncItem-level flags and line number from *original* to *new*."""
          new.info = original.info
          new.min_args = original.min_args
          new.max_pos = original.max_pos
          new.is_overload = original.is_overload
          new.is_generator = original.is_generator
          new.is_coroutine = original.is_coroutine
          new.is_async_generator = original.is_async_generator
          new.is_awaitable_coroutine = original.is_awaitable_coroutine
          new.line = original.line

      def visit_overloaded_func_def(self, node: OverloadedFuncDef) -> OverloadedFuncDef:
          """Copy an overloaded definition, including its items and implementation."""
          items = [cast(OverloadPart, item.accept(self)) for item in node.items]
          for newitem, olditem in zip(items, node.items):
              newitem.line = olditem.line
          new = OverloadedFuncDef(items)
          new._fullname = node._fullname
          new_type = self.optional_type(node.type)
          # optional_type() yields None when the original carries no type; copy
          # that through unchanged instead of failing the assertion.
          assert new_type is None or isinstance(new_type, ProperType)
          new.type = new_type
          new.info = node.info
          new.is_static = node.is_static
          new.is_class = node.is_class
          new.is_property = node.is_property
          new.is_final = node.is_final
          if node.impl is not None:
              new.impl = cast(OverloadPart, node.impl.accept(self))
          return new

      def visit_class_def(self, node: ClassDef) -> ClassDef:
          """Copy a class definition; the TypeInfo is shared, not duplicated."""
          new = ClassDef(
              node.name,
              self.block(node.defs),
              node.type_vars,
              self.expressions(node.base_type_exprs),
              self.optional_expr(node.metaclass),
          )
          new.fullname = node.fullname
          new.info = node.info
          new.decorators = [self.expr(decorator) for decorator in node.decorators]
          return new

      def visit_global_decl(self, node: GlobalDecl) -> GlobalDecl:
          return GlobalDecl(node.names.copy())

      def visit_nonlocal_decl(self, node: NonlocalDecl) -> NonlocalDecl:
          return NonlocalDecl(node.names.copy())

      def visit_block(self, node: Block) -> Block:
          return Block(self.statements(node.body))

      def visit_decorator(self, node: Decorator) -> Decorator:
          # Note that a Decorator must be transformed to a Decorator.
          func = self.visit_func_def(node.func)
          func.line = node.func.line
          new = Decorator(func, self.expressions(node.decorators), self.visit_var(node.var))
          new.is_overload = node.is_overload
          return new

      def visit_var(self, node: Var) -> Var:
          """Return the copy of *node*, creating and memoizing it on first use.

          Note that a Var must be transformed to a Var.
          """
          if node in self.var_map:
              return self.var_map[node]

          new = Var(node.name, self.optional_type(node.type))
          new.line = node.line
          new._fullname = node._fullname
          new.info = node.info
          new.is_self = node.is_self
          new.is_ready = node.is_ready
          new.is_initialized_in_class = node.is_initialized_in_class
          new.is_staticmethod = node.is_staticmethod
          new.is_classmethod = node.is_classmethod
          new.is_property = node.is_property
          new.is_final = node.is_final
          new.final_value = node.final_value
          new.final_unset_in_class = node.final_unset_in_class
          new.final_set_in_init = node.final_set_in_init
          new.set_line(node)
          self.var_map[node] = new
          return new

      def visit_expression_stmt(self, node: ExpressionStmt) -> ExpressionStmt:
          return ExpressionStmt(self.expr(node.expr))

      def visit_assignment_stmt(self, node: AssignmentStmt) -> AssignmentStmt:
          return self.duplicate_assignment(node)

      def duplicate_assignment(self, node: AssignmentStmt) -> AssignmentStmt:
          """Copy an assignment statement, preserving its syntax/alias flags."""
          new = AssignmentStmt(
              self.expressions(node.lvalues),
              self.expr(node.rvalue),
              self.optional_type(node.unanalyzed_type),
              # Preserve whether the annotation used "x: T = ..." syntax rather
              # than a type comment; the original dropped this flag.
              new_syntax=node.new_syntax,
          )
          new.line = node.line
          new.is_final_def = node.is_final_def
          new.is_alias_def = node.is_alias_def
          new.type = self.optional_type(node.type)
          return new

      def visit_operator_assignment_stmt(
          self, node: OperatorAssignmentStmt
      ) -> OperatorAssignmentStmt:
          return OperatorAssignmentStmt(node.op, self.expr(node.lvalue), self.expr(node.rvalue))

      def visit_while_stmt(self, node: WhileStmt) -> WhileStmt:
          return WhileStmt(
              self.expr(node.expr), self.block(node.body), self.optional_block(node.else_body)
          )

      def visit_for_stmt(self, node: ForStmt) -> ForStmt:
          new = ForStmt(
              self.expr(node.index),
              self.expr(node.expr),
              self.block(node.body),
              self.optional_block(node.else_body),
              self.optional_type(node.unanalyzed_index_type),
          )
          new.is_async = node.is_async
          new.index_type = self.optional_type(node.index_type)
          return new

      def visit_return_stmt(self, node: ReturnStmt) -> ReturnStmt:
          return ReturnStmt(self.optional_expr(node.expr))

      def visit_assert_stmt(self, node: AssertStmt) -> AssertStmt:
          return AssertStmt(self.expr(node.expr), self.optional_expr(node.msg))

      def visit_del_stmt(self, node: DelStmt) -> DelStmt:
          return DelStmt(self.expr(node.expr))

      def visit_if_stmt(self, node: IfStmt) -> IfStmt:
          return IfStmt(
              self.expressions(node.expr),
              self.blocks(node.body),
              self.optional_block(node.else_body),
          )

      def visit_break_stmt(self, node: BreakStmt) -> BreakStmt:
          return BreakStmt()

      def visit_continue_stmt(self, node: ContinueStmt) -> ContinueStmt:
          return ContinueStmt()

      def visit_pass_stmt(self, node: PassStmt) -> PassStmt:
          return PassStmt()

      def visit_raise_stmt(self, node: RaiseStmt) -> RaiseStmt:
          return RaiseStmt(self.optional_expr(node.expr), self.optional_expr(node.from_expr))

      def visit_try_stmt(self, node: TryStmt) -> TryStmt:
          new = TryStmt(
              self.block(node.body),
              self.optional_names(node.vars),
              self.optional_expressions(node.types),
              self.blocks(node.handlers),
              self.optional_block(node.else_body),
              self.optional_block(node.finally_body),
          )
          new.is_star = node.is_star
          return new

      def visit_with_stmt(self, node: WithStmt) -> WithStmt:
          new = WithStmt(
              self.expressions(node.expr),
              self.optional_expressions(node.target),
              self.block(node.body),
              self.optional_type(node.unanalyzed_type),
          )
          new.is_async = node.is_async
          new.analyzed_types = [self.type(typ) for typ in node.analyzed_types]
          return new

      # Patterns (match statements)

      def visit_as_pattern(self, p: AsPattern) -> AsPattern:
          return AsPattern(
              pattern=self.pattern(p.pattern) if p.pattern is not None else None,
              name=self.duplicate_name(p.name) if p.name is not None else None,
          )

      def visit_or_pattern(self, p: OrPattern) -> OrPattern:
          return OrPattern([self.pattern(pat) for pat in p.patterns])

      def visit_value_pattern(self, p: ValuePattern) -> ValuePattern:
          return ValuePattern(self.expr(p.expr))

      def visit_singleton_pattern(self, p: SingletonPattern) -> SingletonPattern:
          return SingletonPattern(p.value)

      def visit_sequence_pattern(self, p: SequencePattern) -> SequencePattern:
          return SequencePattern([self.pattern(pat) for pat in p.patterns])

      def visit_starred_pattern(self, p: StarredPattern) -> StarredPattern:
          return StarredPattern(self.duplicate_name(p.capture) if p.capture is not None else None)

      def visit_mapping_pattern(self, p: MappingPattern) -> MappingPattern:
          return MappingPattern(
              keys=[self.expr(expr) for expr in p.keys],
              values=[self.pattern(pat) for pat in p.values],
              rest=self.duplicate_name(p.rest) if p.rest is not None else None,
          )

      def visit_class_pattern(self, p: ClassPattern) -> ClassPattern:
          # Go through expr() rather than a bare accept() so the copied class
          # reference also inherits the original's line information.
          class_ref = self.expr(p.class_ref)
          assert isinstance(class_ref, RefExpr)
          return ClassPattern(
              class_ref=class_ref,
              positionals=[self.pattern(pat) for pat in p.positionals],
              keyword_keys=list(p.keyword_keys),
              keyword_values=[self.pattern(pat) for pat in p.keyword_values],
          )

      def visit_match_stmt(self, o: MatchStmt) -> MatchStmt:
          return MatchStmt(
              subject=self.expr(o.subject),
              patterns=[self.pattern(p) for p in o.patterns],
              guards=self.optional_expressions(o.guards),
              bodies=self.blocks(o.bodies),
          )

      # Expressions

      def visit_star_expr(self, node: StarExpr) -> StarExpr:
          # Transform the starred operand as well; otherwise the copy would share
          # the original sub-tree and keep referring to the original
          # (untransformed) Var nodes.
          new = StarExpr(self.expr(node.expr))
          new.valid = node.valid
          return new

      def visit_int_expr(self, node: IntExpr) -> IntExpr:
          return IntExpr(node.value)

      def visit_str_expr(self, node: StrExpr) -> StrExpr:
          return StrExpr(node.value)

      def visit_bytes_expr(self, node: BytesExpr) -> BytesExpr:
          return BytesExpr(node.value)

      def visit_float_expr(self, node: FloatExpr) -> FloatExpr:
          return FloatExpr(node.value)

      def visit_complex_expr(self, node: ComplexExpr) -> ComplexExpr:
          return ComplexExpr(node.value)

      def visit_ellipsis(self, node: EllipsisExpr) -> EllipsisExpr:
          return EllipsisExpr()

      def visit_name_expr(self, node: NameExpr) -> NameExpr:
          return self.duplicate_name(node)

      def duplicate_name(self, node: NameExpr) -> NameExpr:
          # This method is used when the transform result must be a NameExpr.
          # visit_name_expr() is used when there is no such restriction.
          new = NameExpr(node.name)
          self.copy_ref(new, node)
          new.is_special_form = node.is_special_form
          return new

      def visit_member_expr(self, node: MemberExpr) -> MemberExpr:
          member = MemberExpr(self.expr(node.expr), node.name)
          if node.def_var is not None:
              # This refers to an attribute and we don't transform attributes by
              # default, just normal variables.
              member.def_var = node.def_var
          self.copy_ref(member, node)
          return member

      def copy_ref(self, new: RefExpr, original: RefExpr) -> None:
          """Copy reference metadata from *original* to *new*, remapping the target.

          Local Var targets and Decorator targets are mapped to their copies via
          visit_var(); FuncDef targets are mapped to a placeholder if one exists.
          Any other target (and global Vars) is shared with the original.
          """
          new.kind = original.kind
          new.fullname = original.fullname
          target = original.node
          if isinstance(target, Var):
              # Do not transform references to global variables. See
              # testGenericFunctionAliasExpand for an example where this is important.
              if original.kind != GDEF:
                  target = self.visit_var(target)
          elif isinstance(target, Decorator):
              target = self.visit_var(target.var)
          elif isinstance(target, FuncDef):
              # Use a placeholder node for the function if it exists.
              target = self.func_placeholder_map.get(target, target)
          new.node = target
          new.is_new_def = original.is_new_def
          new.is_inferred_def = original.is_inferred_def

      def visit_yield_from_expr(self, node: YieldFromExpr) -> YieldFromExpr:
          return YieldFromExpr(self.expr(node.expr))

      def visit_yield_expr(self, node: YieldExpr) -> YieldExpr:
          return YieldExpr(self.optional_expr(node.expr))

      def visit_await_expr(self, node: AwaitExpr) -> AwaitExpr:
          return AwaitExpr(self.expr(node.expr))

      def visit_call_expr(self, node: CallExpr) -> CallExpr:
          return CallExpr(
              self.expr(node.callee),
              self.expressions(node.args),
              node.arg_kinds.copy(),
              node.arg_names.copy(),
              self.optional_expr(node.analyzed),
          )

      def visit_op_expr(self, node: OpExpr) -> OpExpr:
          new = OpExpr(
              node.op,
              self.expr(node.left),
              self.expr(node.right),
              cast(Optional[TypeAliasExpr], self.optional_expr(node.analyzed)),
          )
          new.method_type = self.optional_type(node.method_type)
          return new

      def visit_comparison_expr(self, node: ComparisonExpr) -> ComparisonExpr:
          # Copy the operator list so the two trees do not share mutable state.
          new = ComparisonExpr(list(node.operators), self.expressions(node.operands))
          new.method_types = [self.optional_type(t) for t in node.method_types]
          return new

      def visit_cast_expr(self, node: CastExpr) -> CastExpr:
          return CastExpr(self.expr(node.expr), self.type(node.type))

      def visit_assert_type_expr(self, node: AssertTypeExpr) -> AssertTypeExpr:
          return AssertTypeExpr(self.expr(node.expr), self.type(node.type))

      def visit_reveal_expr(self, node: RevealExpr) -> RevealExpr:
          if node.kind == REVEAL_TYPE:
              assert node.expr is not None
              return RevealExpr(kind=REVEAL_TYPE, expr=self.expr(node.expr))
          # Reveal locals expressions don't have any sub expressions, so the
          # original node is returned as-is (it is shared, not copied).
          return node

      def visit_super_expr(self, node: SuperExpr) -> SuperExpr:
          call = self.expr(node.call)
          assert isinstance(call, CallExpr)
          new = SuperExpr(node.name, call)
          new.info = node.info
          return new

      def visit_assignment_expr(self, node: AssignmentExpr) -> AssignmentExpr:
          return AssignmentExpr(self.expr(node.target), self.expr(node.value))

      def visit_unary_expr(self, node: UnaryExpr) -> UnaryExpr:
          new = UnaryExpr(node.op, self.expr(node.expr))
          new.method_type = self.optional_type(node.method_type)
          return new

      def visit_list_expr(self, node: ListExpr) -> ListExpr:
          return ListExpr(self.expressions(node.items))

      def visit_dict_expr(self, node: DictExpr) -> DictExpr:
          # A None key marks a "**mapping" entry; keep it as None.
          return DictExpr(
              [
                  (self.expr(key) if key is not None else None, self.expr(value))
                  for key, value in node.items
              ]
          )

      def visit_tuple_expr(self, node: TupleExpr) -> TupleExpr:
          return TupleExpr(self.expressions(node.items))

      def visit_set_expr(self, node: SetExpr) -> SetExpr:
          return SetExpr(self.expressions(node.items))

      def visit_index_expr(self, node: IndexExpr) -> IndexExpr:
          new = IndexExpr(self.expr(node.base), self.expr(node.index))
          if node.method_type is not None:
              new.method_type = self.type(node.method_type)
          if node.analyzed is not None:
              if isinstance(node.analyzed, TypeApplication):
                  new.analyzed = self.visit_type_application(node.analyzed)
              else:
                  new.analyzed = self.visit_type_alias_expr(node.analyzed)
              new.analyzed.set_line(node.analyzed)
          return new

      def visit_type_application(self, node: TypeApplication) -> TypeApplication:
          return TypeApplication(self.expr(node.expr), self.types(node.types))

      def visit_list_comprehension(self, node: ListComprehension) -> ListComprehension:
          generator = self.duplicate_generator(node.generator)
          generator.set_line(node.generator)
          return ListComprehension(generator)

      def visit_set_comprehension(self, node: SetComprehension) -> SetComprehension:
          generator = self.duplicate_generator(node.generator)
          generator.set_line(node.generator)
          return SetComprehension(generator)

      def visit_dictionary_comprehension(
          self, node: DictionaryComprehension
      ) -> DictionaryComprehension:
          return DictionaryComprehension(
              self.expr(node.key),
              self.expr(node.value),
              [self.expr(index) for index in node.indices],
              [self.expr(s) for s in node.sequences],
              [[self.expr(cond) for cond in conditions] for conditions in node.condlists],
              node.is_async,
          )

      def visit_generator_expr(self, node: GeneratorExpr) -> GeneratorExpr:
          return self.duplicate_generator(node)

      def duplicate_generator(self, node: GeneratorExpr) -> GeneratorExpr:
          """Copy a generator expression (also used for list/set comprehensions)."""
          return GeneratorExpr(
              self.expr(node.left_expr),
              [self.expr(index) for index in node.indices],
              [self.expr(s) for s in node.sequences],
              [[self.expr(cond) for cond in conditions] for conditions in node.condlists],
              node.is_async,
          )

      def visit_slice_expr(self, node: SliceExpr) -> SliceExpr:
          return SliceExpr(
              self.optional_expr(node.begin_index),
              self.optional_expr(node.end_index),
              self.optional_expr(node.stride),
          )

      def visit_conditional_expr(self, node: ConditionalExpr) -> ConditionalExpr:
          return ConditionalExpr(
              self.expr(node.cond), self.expr(node.if_expr), self.expr(node.else_expr)
          )

      def visit_type_var_expr(self, node: TypeVarExpr) -> TypeVarExpr:
          return TypeVarExpr(
              node.name,
              node.fullname,
              self.types(node.values),
              self.type(node.upper_bound),
              self.type(node.default),
              variance=node.variance,
          )

      def visit_paramspec_expr(self, node: ParamSpecExpr) -> ParamSpecExpr:
          return ParamSpecExpr(
              node.name,
              node.fullname,
              self.type(node.upper_bound),
              self.type(node.default),
              variance=node.variance,
          )

      def visit_type_var_tuple_expr(self, node: TypeVarTupleExpr) -> TypeVarTupleExpr:
          return TypeVarTupleExpr(
              node.name,
              node.fullname,
              self.type(node.upper_bound),
              node.tuple_fallback,
              self.type(node.default),
              variance=node.variance,
          )

      def visit_type_alias_expr(self, node: TypeAliasExpr) -> TypeAliasExpr:
          return TypeAliasExpr(node.node)

      def visit_newtype_expr(self, node: NewTypeExpr) -> NewTypeExpr:
          res = NewTypeExpr(node.name, node.old_type, line=node.line, column=node.column)
          res.info = node.info
          return res

      def visit_namedtuple_expr(self, node: NamedTupleExpr) -> NamedTupleExpr:
          return NamedTupleExpr(node.info)

      def visit_enum_call_expr(self, node: EnumCallExpr) -> EnumCallExpr:
          # Copy the item/value lists so the two trees do not share mutable state.
          return EnumCallExpr(node.info, list(node.items), list(node.values))

      def visit_typeddict_expr(self, node: TypedDictExpr) -> TypedDictExpr:
          return TypedDictExpr(node.info)

      def visit__promote_expr(self, node: PromoteExpr) -> PromoteExpr:
          return PromoteExpr(node.type)

      def visit_temp_node(self, node: TempNode) -> TempNode:
          # Preserve no_rhs so a bare declaration ("x: int") is still recognizable
          # as having no right-hand side in the copy.
          return TempNode(self.type(node.type), no_rhs=node.no_rhs)

      # Entry points: transform a node and copy its line information.

      def node(self, node: Node) -> Node:
          new = node.accept(self)
          new.set_line(node)
          return new

      def mypyfile(self, node: MypyFile) -> MypyFile:
          new = node.accept(self)
          assert isinstance(new, MypyFile)
          new.set_line(node)
          return new

      def expr(self, expr: Expression) -> Expression:
          new = expr.accept(self)
          assert isinstance(new, Expression)
          new.set_line(expr)
          return new

      def stmt(self, stmt: Statement) -> Statement:
          new = stmt.accept(self)
          assert isinstance(new, Statement)
          new.set_line(stmt)
          return new

      def pattern(self, pattern: Pattern) -> Pattern:
          new = pattern.accept(self)
          assert isinstance(new, Pattern)
          new.set_line(pattern)
          return new

      # Helpers
      #
      # All the node helpers also propagate line numbers. The optional_* helpers
      # map None to None and transform everything else.

      def optional_expr(self, expr: Expression | None) -> Expression | None:
          if expr is not None:
              return self.expr(expr)
          return None

      def block(self, block: Block) -> Block:
          new = self.visit_block(block)
          new.line = block.line
          return new

      def optional_block(self, block: Block | None) -> Block | None:
          if block is not None:
              return self.block(block)
          return None

      def statements(self, statements: list[Statement]) -> list[Statement]:
          return [self.stmt(stmt) for stmt in statements]

      def expressions(self, expressions: list[Expression]) -> list[Expression]:
          return [self.expr(expr) for expr in expressions]

      def optional_expressions(
          self, expressions: Iterable[Expression | None]
      ) -> list[Expression | None]:
          return [self.optional_expr(expr) for expr in expressions]

      def blocks(self, blocks: list[Block]) -> list[Block]:
          return [self.block(block) for block in blocks]

      def names(self, names: list[NameExpr]) -> list[NameExpr]:
          return [self.duplicate_name(name) for name in names]

      def optional_names(self, names: Iterable[NameExpr | None]) -> list[NameExpr | None]:
          return [self.duplicate_name(name) if name is not None else None for name in names]

      def type(self, type: Type) -> Type:
          # Override this method to transform types; the default is the identity.
          return type

      def optional_type(self, type: Type | None) -> Type | None:
          if type is not None:
              return self.type(type)
          return None

      def types(self, types: list[Type]) -> list[Type]:
          return [self.type(type) for type in types]
  class FuncMapInitializer(TraverserVisitor):
      """This traverser creates mappings from nested FuncDefs to placeholder FuncDefs.

      The placeholders will later be replaced with transformed nodes: each one
      is registered in the transformer's func_placeholder_map, and
      TransformVisitor.visit_func_def() overwrites its state with the real copy
      via replace_object_state().
      """

      def __init__(self, transformer: TransformVisitor) -> None:
          # Run the base traverser's initializer instead of silently skipping it.
          super().__init__()
          # The transformer whose func_placeholder_map this traverser populates.
          self.transformer = transformer

      def visit_func_def(self, node: FuncDef) -> None:
          """Register a placeholder for *node* (once), then visit nested functions."""
          if node not in self.transformer.func_placeholder_map:
              # Haven't seen this FuncDef before, so create a placeholder node.
              # It temporarily shares the original's arguments and body and has
              # no type; its state is replaced once the real copy is built.
              self.transformer.func_placeholder_map[node] = FuncDef(
                  node.name, node.arguments, node.body, None
              )
          super().visit_func_def(node)