| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800 |
- """Base visitor that implements an identity AST transform.
- Subclass TransformVisitor to perform non-trivial transformations.
- """
- from __future__ import annotations
- from typing import Iterable, Optional, cast
- from mypy.nodes import (
- GDEF,
- REVEAL_TYPE,
- Argument,
- AssertStmt,
- AssertTypeExpr,
- AssignmentExpr,
- AssignmentStmt,
- AwaitExpr,
- Block,
- BreakStmt,
- BytesExpr,
- CallExpr,
- CastExpr,
- ClassDef,
- ComparisonExpr,
- ComplexExpr,
- ConditionalExpr,
- ContinueStmt,
- Decorator,
- DelStmt,
- DictExpr,
- DictionaryComprehension,
- EllipsisExpr,
- EnumCallExpr,
- Expression,
- ExpressionStmt,
- FloatExpr,
- ForStmt,
- FuncDef,
- FuncItem,
- GeneratorExpr,
- GlobalDecl,
- IfStmt,
- Import,
- ImportAll,
- ImportFrom,
- IndexExpr,
- IntExpr,
- LambdaExpr,
- ListComprehension,
- ListExpr,
- MatchStmt,
- MemberExpr,
- MypyFile,
- NamedTupleExpr,
- NameExpr,
- NewTypeExpr,
- Node,
- NonlocalDecl,
- OperatorAssignmentStmt,
- OpExpr,
- OverloadedFuncDef,
- OverloadPart,
- ParamSpecExpr,
- PassStmt,
- PromoteExpr,
- RaiseStmt,
- RefExpr,
- ReturnStmt,
- RevealExpr,
- SetComprehension,
- SetExpr,
- SliceExpr,
- StarExpr,
- Statement,
- StrExpr,
- SuperExpr,
- SymbolTable,
- TempNode,
- TryStmt,
- TupleExpr,
- TypeAliasExpr,
- TypeApplication,
- TypedDictExpr,
- TypeVarExpr,
- TypeVarTupleExpr,
- UnaryExpr,
- Var,
- WhileStmt,
- WithStmt,
- YieldExpr,
- YieldFromExpr,
- )
- from mypy.patterns import (
- AsPattern,
- ClassPattern,
- MappingPattern,
- OrPattern,
- Pattern,
- SequencePattern,
- SingletonPattern,
- StarredPattern,
- ValuePattern,
- )
- from mypy.traverser import TraverserVisitor
- from mypy.types import FunctionLike, ProperType, Type
- from mypy.util import replace_object_state
- from mypy.visitor import NodeVisitor
class TransformVisitor(NodeVisitor[Node]):
    """Transform a semantically analyzed AST (or subtree) to an identical copy.

    Use the node() method to transform an AST node.

    Subclass to perform a non-identity transform.

    Notes:

     * This can only be used to transform functions or classes, not top-level
       statements, and/or modules as a whole (visit_mypy_file() refuses to run
       unless test_only is set).
     * Do not duplicate TypeInfo nodes. This would generally not be desirable.
     * Only some name binding cross-references are updated: those that refer
       to Var, Decorator or FuncDef nodes. References to ClassDef or TypeInfo
       nodes (and to global Vars) keep pointing at the original nodes.
     * Types are not transformed, but you can override type() to also perform
       type transformation.

    TODO nested classes and functions have not been tested well enough
    """

    def __init__(self) -> None:
        # To simplify testing, set this flag to True if you want to transform
        # all statements in a file (this is prohibited in normal mode).
        self.test_only = False
        # There may be multiple references to a Var node. Keep track of
        # Var translations using a dictionary.
        self.var_map: dict[Var, Var] = {}
        # These are uninitialized placeholder nodes used temporarily for nested
        # functions while we are transforming a top-level function. This maps an
        # untransformed node to a placeholder (which will later become the
        # transformed node).
        self.func_placeholder_map: dict[FuncDef, FuncDef] = {}

    def visit_mypy_file(self, node: MypyFile) -> MypyFile:
        assert self.test_only, "This visitor should not be used for whole files."
        # NOTE: The 'names' and 'imports' instance variables will be empty!
        ignored_lines = {line: codes.copy() for line, codes in node.ignored_lines.items()}
        new = MypyFile(self.statements(node.defs), [], node.is_bom, ignored_lines=ignored_lines)
        new._fullname = node._fullname
        new.path = node.path
        new.names = SymbolTable()
        return new

    def visit_import(self, node: Import) -> Import:
        return Import(node.ids.copy())

    def visit_import_from(self, node: ImportFrom) -> ImportFrom:
        return ImportFrom(node.id, node.relative, node.names.copy())

    def visit_import_all(self, node: ImportAll) -> ImportAll:
        return ImportAll(node.id, node.relative)

    def copy_argument(self, argument: Argument) -> Argument:
        """Copy a function argument, mapping its Var through visit_var().

        The type annotation and the initializer expression are shared with the
        original argument rather than transformed.
        """
        arg = Argument(
            self.visit_var(argument.variable),
            argument.type_annotation,
            argument.initializer,
            argument.kind,
        )

        # Refresh lines of the inner things
        arg.set_line(argument)

        return arg

    def visit_func_def(self, node: FuncDef) -> FuncDef:
        # Note that a FuncDef must be transformed to a FuncDef.

        # These contortions are needed to handle the case of recursive
        # references inside the function being transformed.
        # Set up placeholder nodes for references within this function
        # to other functions defined inside it.
        # Don't create an entry for this function itself though,
        # since we want self-references to point to the original
        # function if this is the top-level node we are transforming.
        init = FuncMapInitializer(self)
        for stmt in node.body.body:
            stmt.accept(init)

        new = FuncDef(
            node.name,
            [self.copy_argument(arg) for arg in node.arguments],
            self.block(node.body),
            cast(Optional[FunctionLike], self.optional_type(node.type)),
        )

        self.copy_function_attributes(new, node)

        new._fullname = node._fullname
        new.is_decorated = node.is_decorated
        new.is_conditional = node.is_conditional
        new.abstract_status = node.abstract_status
        new.is_static = node.is_static
        new.is_class = node.is_class
        new.is_property = node.is_property
        new.is_final = node.is_final
        new.original_def = node.original_def

        if node in self.func_placeholder_map:
            # There is a placeholder definition for this function. Replace
            # the attributes of the placeholder with those from the transformed
            # function. We know that the classes will be identical (otherwise
            # this wouldn't work).
            result = self.func_placeholder_map[node]
            replace_object_state(result, new)
            return result
        else:
            return new

    def visit_lambda_expr(self, node: LambdaExpr) -> LambdaExpr:
        new = LambdaExpr(
            [self.copy_argument(arg) for arg in node.arguments],
            self.block(node.body),
            cast(Optional[FunctionLike], self.optional_type(node.type)),
        )
        self.copy_function_attributes(new, node)
        return new

    def copy_function_attributes(self, new: FuncItem, original: FuncItem) -> None:
        """Copy the flags shared by all FuncItem kinds from original to new."""
        new.info = original.info
        new.min_args = original.min_args
        new.max_pos = original.max_pos
        new.is_overload = original.is_overload
        new.is_generator = original.is_generator
        new.is_coroutine = original.is_coroutine
        new.is_async_generator = original.is_async_generator
        new.is_awaitable_coroutine = original.is_awaitable_coroutine
        new.line = original.line

    def visit_overloaded_func_def(self, node: OverloadedFuncDef) -> OverloadedFuncDef:
        items = [cast(OverloadPart, item.accept(self)) for item in node.items]
        for newitem, olditem in zip(items, node.items):
            newitem.line = olditem.line
        new = OverloadedFuncDef(items)
        new._fullname = node._fullname
        new_type = self.optional_type(node.type)
        # An overload may not carry a type at all; only a present type is
        # required to still be a ProperType after transformation.
        if new_type is not None:
            assert isinstance(new_type, ProperType)
        new.type = new_type
        new.info = node.info
        new.is_static = node.is_static
        new.is_class = node.is_class
        new.is_property = node.is_property
        new.is_final = node.is_final
        if node.impl:
            new.impl = cast(OverloadPart, node.impl.accept(self))
        return new

    def visit_class_def(self, node: ClassDef) -> ClassDef:
        new = ClassDef(
            node.name,
            self.block(node.defs),
            node.type_vars,
            self.expressions(node.base_type_exprs),
            self.optional_expr(node.metaclass),
        )
        new.fullname = node.fullname
        # TypeInfo nodes are shared, never duplicated (see class docstring).
        new.info = node.info
        new.decorators = [self.expr(decorator) for decorator in node.decorators]
        return new

    def visit_global_decl(self, node: GlobalDecl) -> GlobalDecl:
        return GlobalDecl(node.names.copy())

    def visit_nonlocal_decl(self, node: NonlocalDecl) -> NonlocalDecl:
        return NonlocalDecl(node.names.copy())

    def visit_block(self, node: Block) -> Block:
        return Block(self.statements(node.body))

    def visit_decorator(self, node: Decorator) -> Decorator:
        # Note that a Decorator must be transformed to a Decorator.
        func = self.visit_func_def(node.func)
        func.line = node.func.line
        new = Decorator(func, self.expressions(node.decorators), self.visit_var(node.var))
        new.is_overload = node.is_overload
        return new

    def visit_var(self, node: Var) -> Var:
        # Note that a Var must be transformed to a Var.
        # Every reference to the same Var must map to the same copy.
        if node in self.var_map:
            return self.var_map[node]
        new = Var(node.name, self.optional_type(node.type))
        new._fullname = node._fullname
        new.info = node.info
        new.is_self = node.is_self
        new.is_ready = node.is_ready
        new.is_initialized_in_class = node.is_initialized_in_class
        new.is_staticmethod = node.is_staticmethod
        new.is_classmethod = node.is_classmethod
        new.is_property = node.is_property
        new.is_final = node.is_final
        new.final_value = node.final_value
        new.final_unset_in_class = node.final_unset_in_class
        new.final_set_in_init = node.final_set_in_init
        # Copies line (and the rest of the position info) from the original.
        new.set_line(node)
        self.var_map[node] = new
        return new

    def visit_expression_stmt(self, node: ExpressionStmt) -> ExpressionStmt:
        return ExpressionStmt(self.expr(node.expr))

    def visit_assignment_stmt(self, node: AssignmentStmt) -> AssignmentStmt:
        return self.duplicate_assignment(node)

    def duplicate_assignment(self, node: AssignmentStmt) -> AssignmentStmt:
        """Copy an assignment; used when the result must be an AssignmentStmt."""
        new = AssignmentStmt(
            self.expressions(node.lvalues),
            self.expr(node.rvalue),
            self.optional_type(node.unanalyzed_type),
        )
        new.line = node.line
        new.is_final_def = node.is_final_def
        new.type = self.optional_type(node.type)
        return new

    def visit_operator_assignment_stmt(
        self, node: OperatorAssignmentStmt
    ) -> OperatorAssignmentStmt:
        return OperatorAssignmentStmt(node.op, self.expr(node.lvalue), self.expr(node.rvalue))

    def visit_while_stmt(self, node: WhileStmt) -> WhileStmt:
        return WhileStmt(
            self.expr(node.expr), self.block(node.body), self.optional_block(node.else_body)
        )

    def visit_for_stmt(self, node: ForStmt) -> ForStmt:
        new = ForStmt(
            self.expr(node.index),
            self.expr(node.expr),
            self.block(node.body),
            self.optional_block(node.else_body),
            self.optional_type(node.unanalyzed_index_type),
        )
        new.is_async = node.is_async
        new.index_type = self.optional_type(node.index_type)
        return new

    def visit_return_stmt(self, node: ReturnStmt) -> ReturnStmt:
        return ReturnStmt(self.optional_expr(node.expr))

    def visit_assert_stmt(self, node: AssertStmt) -> AssertStmt:
        return AssertStmt(self.expr(node.expr), self.optional_expr(node.msg))

    def visit_del_stmt(self, node: DelStmt) -> DelStmt:
        return DelStmt(self.expr(node.expr))

    def visit_if_stmt(self, node: IfStmt) -> IfStmt:
        return IfStmt(
            self.expressions(node.expr),
            self.blocks(node.body),
            self.optional_block(node.else_body),
        )

    def visit_break_stmt(self, node: BreakStmt) -> BreakStmt:
        return BreakStmt()

    def visit_continue_stmt(self, node: ContinueStmt) -> ContinueStmt:
        return ContinueStmt()

    def visit_pass_stmt(self, node: PassStmt) -> PassStmt:
        return PassStmt()

    def visit_raise_stmt(self, node: RaiseStmt) -> RaiseStmt:
        return RaiseStmt(self.optional_expr(node.expr), self.optional_expr(node.from_expr))

    def visit_try_stmt(self, node: TryStmt) -> TryStmt:
        new = TryStmt(
            self.block(node.body),
            self.optional_names(node.vars),
            self.optional_expressions(node.types),
            self.blocks(node.handlers),
            self.optional_block(node.else_body),
            self.optional_block(node.finally_body),
        )
        new.is_star = node.is_star
        return new

    def visit_with_stmt(self, node: WithStmt) -> WithStmt:
        new = WithStmt(
            self.expressions(node.expr),
            self.optional_expressions(node.target),
            self.block(node.body),
            self.optional_type(node.unanalyzed_type),
        )
        new.is_async = node.is_async
        new.analyzed_types = [self.type(typ) for typ in node.analyzed_types]
        return new

    def visit_as_pattern(self, p: AsPattern) -> AsPattern:
        return AsPattern(
            pattern=self.pattern(p.pattern) if p.pattern is not None else None,
            name=self.duplicate_name(p.name) if p.name is not None else None,
        )

    def visit_or_pattern(self, p: OrPattern) -> OrPattern:
        return OrPattern([self.pattern(pat) for pat in p.patterns])

    def visit_value_pattern(self, p: ValuePattern) -> ValuePattern:
        return ValuePattern(self.expr(p.expr))

    def visit_singleton_pattern(self, p: SingletonPattern) -> SingletonPattern:
        return SingletonPattern(p.value)

    def visit_sequence_pattern(self, p: SequencePattern) -> SequencePattern:
        return SequencePattern([self.pattern(pat) for pat in p.patterns])

    def visit_starred_pattern(self, p: StarredPattern) -> StarredPattern:
        return StarredPattern(self.duplicate_name(p.capture) if p.capture is not None else None)

    def visit_mapping_pattern(self, p: MappingPattern) -> MappingPattern:
        return MappingPattern(
            keys=[self.expr(expr) for expr in p.keys],
            values=[self.pattern(pat) for pat in p.values],
            rest=self.duplicate_name(p.rest) if p.rest is not None else None,
        )

    def visit_class_pattern(self, p: ClassPattern) -> ClassPattern:
        class_ref = p.class_ref.accept(self)
        assert isinstance(class_ref, RefExpr)
        return ClassPattern(
            class_ref=class_ref,
            positionals=[self.pattern(pat) for pat in p.positionals],
            keyword_keys=list(p.keyword_keys),
            keyword_values=[self.pattern(pat) for pat in p.keyword_values],
        )

    def visit_match_stmt(self, o: MatchStmt) -> MatchStmt:
        return MatchStmt(
            subject=self.expr(o.subject),
            patterns=[self.pattern(p) for p in o.patterns],
            guards=self.optional_expressions(o.guards),
            bodies=self.blocks(o.bodies),
        )

    def visit_star_expr(self, node: StarExpr) -> StarExpr:
        # Transform the starred operand as well, so that references inside it
        # are remapped like any other expression and the copy does not share
        # the operand subtree with the original.
        new = StarExpr(self.expr(node.expr))
        new.valid = node.valid
        return new

    def visit_int_expr(self, node: IntExpr) -> IntExpr:
        return IntExpr(node.value)

    def visit_str_expr(self, node: StrExpr) -> StrExpr:
        return StrExpr(node.value)

    def visit_bytes_expr(self, node: BytesExpr) -> BytesExpr:
        return BytesExpr(node.value)

    def visit_float_expr(self, node: FloatExpr) -> FloatExpr:
        return FloatExpr(node.value)

    def visit_complex_expr(self, node: ComplexExpr) -> ComplexExpr:
        return ComplexExpr(node.value)

    def visit_ellipsis(self, node: EllipsisExpr) -> EllipsisExpr:
        return EllipsisExpr()

    def visit_name_expr(self, node: NameExpr) -> NameExpr:
        return self.duplicate_name(node)

    def duplicate_name(self, node: NameExpr) -> NameExpr:
        # This method is used when the transform result must be a NameExpr.
        # visit_name_expr() is used when there is no such restriction.
        new = NameExpr(node.name)
        self.copy_ref(new, node)
        new.is_special_form = node.is_special_form
        return new

    def visit_member_expr(self, node: MemberExpr) -> MemberExpr:
        member = MemberExpr(self.expr(node.expr), node.name)
        if node.def_var:
            # This refers to an attribute and we don't transform attributes by default,
            # just normal variables.
            member.def_var = node.def_var
        self.copy_ref(member, node)
        return member

    def copy_ref(self, new: RefExpr, original: RefExpr) -> None:
        """Copy reference information from original into new.

        The referenced node is remapped as follows: non-global Vars go through
        visit_var(), Decorators map to the copy of their Var, and FuncDefs map
        to their placeholder if one was registered. Any other target is shared.
        """
        new.kind = original.kind
        new.fullname = original.fullname
        target = original.node
        if isinstance(target, Var):
            # Do not transform references to global variables. See
            # testGenericFunctionAliasExpand for an example where this is important.
            if original.kind != GDEF:
                target = self.visit_var(target)
        elif isinstance(target, Decorator):
            target = self.visit_var(target.var)
        elif isinstance(target, FuncDef):
            # Use a placeholder node for the function if it exists.
            target = self.func_placeholder_map.get(target, target)
        new.node = target
        new.is_new_def = original.is_new_def
        new.is_inferred_def = original.is_inferred_def

    def visit_yield_from_expr(self, node: YieldFromExpr) -> YieldFromExpr:
        return YieldFromExpr(self.expr(node.expr))

    def visit_yield_expr(self, node: YieldExpr) -> YieldExpr:
        return YieldExpr(self.optional_expr(node.expr))

    def visit_await_expr(self, node: AwaitExpr) -> AwaitExpr:
        return AwaitExpr(self.expr(node.expr))

    def visit_call_expr(self, node: CallExpr) -> CallExpr:
        return CallExpr(
            self.expr(node.callee),
            self.expressions(node.args),
            node.arg_kinds.copy(),
            node.arg_names.copy(),
            self.optional_expr(node.analyzed),
        )

    def visit_op_expr(self, node: OpExpr) -> OpExpr:
        new = OpExpr(
            node.op,
            self.expr(node.left),
            self.expr(node.right),
            cast(Optional[TypeAliasExpr], self.optional_expr(node.analyzed)),
        )
        new.method_type = self.optional_type(node.method_type)
        return new

    def visit_comparison_expr(self, node: ComparisonExpr) -> ComparisonExpr:
        new = ComparisonExpr(node.operators, self.expressions(node.operands))
        new.method_types = [self.optional_type(t) for t in node.method_types]
        return new

    def visit_cast_expr(self, node: CastExpr) -> CastExpr:
        return CastExpr(self.expr(node.expr), self.type(node.type))

    def visit_assert_type_expr(self, node: AssertTypeExpr) -> AssertTypeExpr:
        return AssertTypeExpr(self.expr(node.expr), self.type(node.type))

    def visit_reveal_expr(self, node: RevealExpr) -> RevealExpr:
        if node.kind == REVEAL_TYPE:
            assert node.expr is not None
            return RevealExpr(kind=REVEAL_TYPE, expr=self.expr(node.expr))
        else:
            # Reveal locals expressions don't have any sub expressions, so the
            # original node itself is returned (it is not copied).
            return node

    def visit_super_expr(self, node: SuperExpr) -> SuperExpr:
        call = self.expr(node.call)
        assert isinstance(call, CallExpr)
        new = SuperExpr(node.name, call)
        new.info = node.info
        return new

    def visit_assignment_expr(self, node: AssignmentExpr) -> AssignmentExpr:
        return AssignmentExpr(self.expr(node.target), self.expr(node.value))

    def visit_unary_expr(self, node: UnaryExpr) -> UnaryExpr:
        new = UnaryExpr(node.op, self.expr(node.expr))
        new.method_type = self.optional_type(node.method_type)
        return new

    def visit_list_expr(self, node: ListExpr) -> ListExpr:
        return ListExpr(self.expressions(node.items))

    def visit_dict_expr(self, node: DictExpr) -> DictExpr:
        # A None key marks a `**mapping` entry; keep it as None.
        return DictExpr(
            [(self.expr(key) if key else None, self.expr(value)) for key, value in node.items]
        )

    def visit_tuple_expr(self, node: TupleExpr) -> TupleExpr:
        return TupleExpr(self.expressions(node.items))

    def visit_set_expr(self, node: SetExpr) -> SetExpr:
        return SetExpr(self.expressions(node.items))

    def visit_index_expr(self, node: IndexExpr) -> IndexExpr:
        new = IndexExpr(self.expr(node.base), self.expr(node.index))
        if node.method_type:
            new.method_type = self.type(node.method_type)
        if node.analyzed:
            if isinstance(node.analyzed, TypeApplication):
                new.analyzed = self.visit_type_application(node.analyzed)
            else:
                new.analyzed = self.visit_type_alias_expr(node.analyzed)
            new.analyzed.set_line(node.analyzed)
        return new

    def visit_type_application(self, node: TypeApplication) -> TypeApplication:
        return TypeApplication(self.expr(node.expr), self.types(node.types))

    def visit_list_comprehension(self, node: ListComprehension) -> ListComprehension:
        generator = self.duplicate_generator(node.generator)
        generator.set_line(node.generator)
        return ListComprehension(generator)

    def visit_set_comprehension(self, node: SetComprehension) -> SetComprehension:
        generator = self.duplicate_generator(node.generator)
        generator.set_line(node.generator)
        return SetComprehension(generator)

    def visit_dictionary_comprehension(
        self, node: DictionaryComprehension
    ) -> DictionaryComprehension:
        return DictionaryComprehension(
            self.expr(node.key),
            self.expr(node.value),
            [self.expr(index) for index in node.indices],
            [self.expr(s) for s in node.sequences],
            [[self.expr(cond) for cond in conditions] for conditions in node.condlists],
            node.is_async,
        )

    def visit_generator_expr(self, node: GeneratorExpr) -> GeneratorExpr:
        return self.duplicate_generator(node)

    def duplicate_generator(self, node: GeneratorExpr) -> GeneratorExpr:
        """Copy a generator; used when the result must be a GeneratorExpr."""
        return GeneratorExpr(
            self.expr(node.left_expr),
            [self.expr(index) for index in node.indices],
            [self.expr(s) for s in node.sequences],
            [[self.expr(cond) for cond in conditions] for conditions in node.condlists],
            node.is_async,
        )

    def visit_slice_expr(self, node: SliceExpr) -> SliceExpr:
        return SliceExpr(
            self.optional_expr(node.begin_index),
            self.optional_expr(node.end_index),
            self.optional_expr(node.stride),
        )

    def visit_conditional_expr(self, node: ConditionalExpr) -> ConditionalExpr:
        return ConditionalExpr(
            self.expr(node.cond), self.expr(node.if_expr), self.expr(node.else_expr)
        )

    def visit_type_var_expr(self, node: TypeVarExpr) -> TypeVarExpr:
        return TypeVarExpr(
            node.name,
            node.fullname,
            self.types(node.values),
            self.type(node.upper_bound),
            self.type(node.default),
            variance=node.variance,
        )

    def visit_paramspec_expr(self, node: ParamSpecExpr) -> ParamSpecExpr:
        return ParamSpecExpr(
            node.name,
            node.fullname,
            self.type(node.upper_bound),
            self.type(node.default),
            variance=node.variance,
        )

    def visit_type_var_tuple_expr(self, node: TypeVarTupleExpr) -> TypeVarTupleExpr:
        return TypeVarTupleExpr(
            node.name,
            node.fullname,
            self.type(node.upper_bound),
            node.tuple_fallback,
            self.type(node.default),
            variance=node.variance,
        )

    def visit_type_alias_expr(self, node: TypeAliasExpr) -> TypeAliasExpr:
        return TypeAliasExpr(node.node)

    def visit_newtype_expr(self, node: NewTypeExpr) -> NewTypeExpr:
        res = NewTypeExpr(node.name, node.old_type, line=node.line, column=node.column)
        res.info = node.info
        return res

    def visit_namedtuple_expr(self, node: NamedTupleExpr) -> NamedTupleExpr:
        return NamedTupleExpr(node.info)

    def visit_enum_call_expr(self, node: EnumCallExpr) -> EnumCallExpr:
        return EnumCallExpr(node.info, node.items, node.values)

    def visit_typeddict_expr(self, node: TypedDictExpr) -> TypedDictExpr:
        return TypedDictExpr(node.info)

    def visit__promote_expr(self, node: PromoteExpr) -> PromoteExpr:
        return PromoteExpr(node.type)

    def visit_temp_node(self, node: TempNode) -> TempNode:
        return TempNode(self.type(node.type))

    def node(self, node: Node) -> Node:
        """Transform an arbitrary node and copy its line information."""
        new = node.accept(self)
        new.set_line(node)
        return new

    def mypyfile(self, node: MypyFile) -> MypyFile:
        """Transform a whole file (only allowed when test_only is set)."""
        new = node.accept(self)
        assert isinstance(new, MypyFile)
        new.set_line(node)
        return new

    def expr(self, expr: Expression) -> Expression:
        """Transform an expression and copy its line information."""
        new = expr.accept(self)
        assert isinstance(new, Expression)
        new.set_line(expr)
        return new

    def stmt(self, stmt: Statement) -> Statement:
        """Transform a statement and copy its line information."""
        new = stmt.accept(self)
        assert isinstance(new, Statement)
        new.set_line(stmt)
        return new

    def pattern(self, pattern: Pattern) -> Pattern:
        """Transform a match pattern and copy its line information."""
        new = pattern.accept(self)
        assert isinstance(new, Pattern)
        new.set_line(pattern)
        return new

    # Helpers
    #
    # All the node helpers also propagate line numbers.

    def optional_expr(self, expr: Expression | None) -> Expression | None:
        if expr:
            return self.expr(expr)
        else:
            return None

    def block(self, block: Block) -> Block:
        new = self.visit_block(block)
        new.line = block.line
        return new

    def optional_block(self, block: Block | None) -> Block | None:
        if block:
            return self.block(block)
        else:
            return None

    def statements(self, statements: list[Statement]) -> list[Statement]:
        return [self.stmt(stmt) for stmt in statements]

    def expressions(self, expressions: list[Expression]) -> list[Expression]:
        return [self.expr(expr) for expr in expressions]

    def optional_expressions(
        self, expressions: Iterable[Expression | None]
    ) -> list[Expression | None]:
        return [self.optional_expr(expr) for expr in expressions]

    def blocks(self, blocks: list[Block]) -> list[Block]:
        return [self.block(block) for block in blocks]

    def names(self, names: list[NameExpr]) -> list[NameExpr]:
        return [self.duplicate_name(name) for name in names]

    def optional_names(self, names: Iterable[NameExpr | None]) -> list[NameExpr | None]:
        return [self.duplicate_name(name) if name else None for name in names]

    def type(self, type: Type) -> Type:
        # Override this method to transform types.
        return type

    def optional_type(self, type: Type | None) -> Type | None:
        if type:
            return self.type(type)
        else:
            return None

    def types(self, types: list[Type]) -> list[Type]:
        return [self.type(type) for type in types]
class FuncMapInitializer(TraverserVisitor):
    """This traverser creates mappings from nested FuncDefs to placeholder FuncDefs.

    The placeholders will later be replaced with transformed nodes.
    """

    def __init__(self, transformer: TransformVisitor) -> None:
        super().__init__()
        # The transformer whose func_placeholder_map this traverser populates.
        self.transformer = transformer

    def visit_func_def(self, node: FuncDef) -> None:
        """Register a placeholder for node (once) and keep traversing into it."""
        if node not in self.transformer.func_placeholder_map:
            # Haven't seen this FuncDef before, so create a placeholder node.
            # It temporarily shares the original's arguments and body; its
            # state is later overwritten with that of the transformed function.
            self.transformer.func_placeholder_map[node] = FuncDef(
                node.name, node.arguments, node.body, None
            )
        # Recurse so that functions nested deeper also get placeholders.
        super().visit_func_def(node)
|