| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675 |
- from __future__ import annotations
- from enum import Enum
- from mypy import checker, errorcodes
- from mypy.messages import MessageBuilder
- from mypy.nodes import (
- AssertStmt,
- AssignmentExpr,
- AssignmentStmt,
- BreakStmt,
- ClassDef,
- Context,
- ContinueStmt,
- DictionaryComprehension,
- Expression,
- ExpressionStmt,
- ForStmt,
- FuncDef,
- FuncItem,
- GeneratorExpr,
- GlobalDecl,
- IfStmt,
- Import,
- ImportFrom,
- LambdaExpr,
- ListExpr,
- Lvalue,
- MatchStmt,
- MypyFile,
- NameExpr,
- NonlocalDecl,
- RaiseStmt,
- ReturnStmt,
- StarExpr,
- SymbolTable,
- TryStmt,
- TupleExpr,
- WhileStmt,
- WithStmt,
- implicit_module_attrs,
- )
- from mypy.options import Options
- from mypy.patterns import AsPattern, StarredPattern
- from mypy.reachability import ALWAYS_TRUE, infer_pattern_value
- from mypy.traverser import ExtendedTraverserVisitor
- from mypy.types import Type, UninhabitedType
class BranchState:
    """Variable-definition facts known at the end of a branching statement.

    `if` and `match` are examples of branching statements.

    `may_be_defined` holds the variables that were defined in only some branches.
    `must_be_defined` holds the variables that were defined in all branches.
    `skipped` is the flag passed by the creator of the state; it is stored as-is.
    """

    def __init__(
        self,
        must_be_defined: set[str] | None = None,
        may_be_defined: set[str] | None = None,
        skipped: bool = False,
    ) -> None:
        # Always build fresh sets so that this state never aliases a set owned by the caller.
        self.may_be_defined = set(may_be_defined or ())
        self.must_be_defined = set(must_be_defined or ())
        self.skipped = skipped

    def copy(self) -> BranchState:
        """Return an independent duplicate of this state."""
        # The constructor copies both sets, so they can be handed over directly.
        return BranchState(self.must_be_defined, self.may_be_defined, self.skipped)
class BranchStatement:
    """Definition state of one branching statement, kept as one `BranchState` per branch.

    A statement starts with a single branch seeded from `initial_state`.
    `next_branch` opens another branch from the same seed, and `done` merges
    all branches into a single `BranchState`.
    """

    def __init__(self, initial_state: BranchState | None = None) -> None:
        if initial_state is None:
            initial_state = BranchState()
        self.initial_state = initial_state
        self.branches: list[BranchState] = [self._fresh_branch()]

    def _fresh_branch(self) -> BranchState:
        """Return a new, non-skipped branch seeded from the initial state."""
        return BranchState(
            must_be_defined=self.initial_state.must_be_defined,
            may_be_defined=self.initial_state.may_be_defined,
        )

    def copy(self) -> BranchStatement:
        """Return a duplicate that shares no mutable state with this statement."""
        # Copy the initial state as well; otherwise the duplicate would keep
        # reading sets that belong to the original.
        result = BranchStatement(self.initial_state.copy())
        result.branches = [b.copy() for b in self.branches]
        return result

    def next_branch(self) -> None:
        """Open a new branch that starts from the statement's initial state."""
        self.branches.append(self._fresh_branch())

    def record_definition(self, name: str) -> None:
        """Mark `name` as definitely defined in the current branch."""
        assert len(self.branches) > 0
        current = self.branches[-1]
        current.must_be_defined.add(name)
        current.may_be_defined.discard(name)

    def delete_var(self, name: str) -> None:
        """Forget everything known about `name` in the current branch."""
        assert len(self.branches) > 0
        current = self.branches[-1]
        current.must_be_defined.discard(name)
        current.may_be_defined.discard(name)

    def record_nested_branch(self, state: BranchState) -> None:
        """Fold the merged result of a nested statement into the current branch."""
        assert len(self.branches) > 0
        current_branch = self.branches[-1]
        if state.skipped:
            # Every branch of the nested statement was skipped, so the current
            # branch is marked as skipped too and none of its names are taken over.
            current_branch.skipped = True
            return
        current_branch.must_be_defined.update(state.must_be_defined)
        current_branch.may_be_defined.update(state.may_be_defined)
        # A name that is definitely defined must not also be listed as "maybe".
        current_branch.may_be_defined.difference_update(current_branch.must_be_defined)

    def skip_branch(self) -> None:
        """Mark the current branch as skipped."""
        assert len(self.branches) > 0
        self.branches[-1].skipped = True

    def is_possibly_undefined(self, name: str) -> bool:
        """Return True if `name` is in the current branch's `may_be_defined` set."""
        assert len(self.branches) > 0
        return name in self.branches[-1].may_be_defined

    def is_undefined(self, name: str) -> bool:
        """Return True if the current branch knows nothing about `name`."""
        assert len(self.branches) > 0
        branch = self.branches[-1]
        return name not in branch.may_be_defined and name not in branch.must_be_defined

    def is_defined_in_a_branch(self, name: str) -> bool:
        """Return True if any branch lists `name` as defined or possibly defined."""
        assert len(self.branches) > 0
        return any(
            name in b.must_be_defined or name in b.may_be_defined for b in self.branches
        )

    def done(self) -> BranchState:
        """Merge all branches into one `BranchState` without modifying any branch."""
        # First, compute all vars, including skipped branches. We include skipped branches
        # because our goal is to capture all variables that semantic analyzer would
        # consider defined.
        all_vars: set[str] = set()
        for b in self.branches:
            all_vars.update(b.may_be_defined)
            all_vars.update(b.must_be_defined)
        # For the rest of the things, we only care about branches that weren't skipped.
        non_skipped_branches = [b for b in self.branches if not b.skipped]
        if non_skipped_branches:
            # `set.intersection` builds a new set. Intersecting in place on the first
            # branch's own set would shrink that branch and change the outcome of a
            # later `done()` call.
            must_be_defined = set.intersection(
                *(b.must_be_defined for b in non_skipped_branches)
            )
        else:
            must_be_defined = set()
        # Everything that wasn't defined in all branches but was defined
        # in at least one branch should be in `may_be_defined`!
        may_be_defined = all_vars.difference(must_be_defined)
        return BranchState(
            must_be_defined=must_be_defined,
            may_be_defined=may_be_defined,
            skipped=not non_skipped_branches,
        )
class ScopeType(Enum):
    """Kinds of scope that the definition tracker distinguishes."""

    # Module-level scope; the tracker always starts with one of these.
    Global = 1
    # Body of a class definition.
    Class = 2
    # Body of a function or lambda.
    Func = 3
    # Generator expression or comprehension; starts from the enclosing scope's state.
    Generator = 4
class Scope:
    """One scope: its stack of open branch statements and its unresolved references."""

    def __init__(self, stmts: list[BranchStatement], scope_type: ScopeType) -> None:
        self.branch_stmts: list[BranchStatement] = stmts
        self.scope_type = scope_type
        # Maps a name to the expressions that referenced it while it was undefined.
        self.undefined_refs: dict[str, set[NameExpr]] = {}

    def copy(self) -> Scope:
        """Return a duplicate whose statements and reference sets are independent."""
        result = Scope([s.copy() for s in self.branch_stmts], self.scope_type)
        # `dict.copy()` is shallow and would leave both scopes sharing the same
        # per-name sets, so a reference recorded on one would show up in the other.
        result.undefined_refs = {name: set(refs) for name, refs in self.undefined_refs.items()}
        return result

    def record_undefined_ref(self, o: NameExpr) -> None:
        """Remember `o` as a reference to the not-yet-defined name `o.name`."""
        self.undefined_refs.setdefault(o.name, set()).add(o)

    def pop_undefined_ref(self, name: str) -> set[NameExpr]:
        """Remove and return the references recorded for `name` (empty set if none)."""
        return self.undefined_refs.pop(name, set())
class DefinedVariableTracker:
    """Holds the scope stack and branch state used by PossiblyUndefinedVariableVisitor."""

    def __init__(self) -> None:
        # There is always at least one scope, and every scope always holds at
        # least one ("root") BranchStatement.
        self.scopes: list[Scope] = [Scope([BranchStatement()], ScopeType.Global)]
        # When set, `skip_branch` becomes a no-op, so a return/raise/etc. does not
        # mark the branch as skipped. This is useful in things like
        # try/except/finally statements.
        self.disable_branch_skip = False

    def copy(self) -> DefinedVariableTracker:
        """Return a duplicate tracker built from copies of every scope."""
        duplicate = DefinedVariableTracker()
        duplicate.scopes = [scope.copy() for scope in self.scopes]
        duplicate.disable_branch_skip = self.disable_branch_skip
        return duplicate

    def _scope(self) -> Scope:
        """Return the innermost scope."""
        assert len(self.scopes) > 0
        return self.scopes[-1]

    def enter_scope(self, scope_type: ScopeType) -> None:
        """Push a new scope of the given type."""
        assert len(self._scope().branch_stmts) > 0
        # Generators are special because they inherit the outer scope; every
        # other scope type starts from an empty state.
        initial_state = (
            self._scope().branch_stmts[-1].branches[-1]
            if scope_type == ScopeType.Generator
            else None
        )
        self.scopes.append(Scope([BranchStatement(initial_state)], scope_type))

    def exit_scope(self) -> None:
        """Pop the innermost scope."""
        self.scopes.pop()

    def in_scope(self, scope_type: ScopeType) -> bool:
        """Return True if the innermost scope has the given type."""
        return self._scope().scope_type == scope_type

    def start_branch_statement(self) -> None:
        """Open a nested branch statement seeded from the current branch."""
        open_stmts = self._scope().branch_stmts
        assert len(open_stmts) > 0
        open_stmts.append(BranchStatement(open_stmts[-1].branches[-1]))

    def next_branch(self) -> None:
        """Move on to the next branch of the innermost (non-root) branch statement."""
        open_stmts = self._scope().branch_stmts
        assert len(open_stmts) > 1
        open_stmts[-1].next_branch()

    def end_branch_statement(self) -> None:
        """Close the innermost branch statement and merge its result into its parent."""
        open_stmts = self._scope().branch_stmts
        assert len(open_stmts) > 1
        merged = open_stmts.pop().done()
        open_stmts[-1].record_nested_branch(merged)

    def skip_branch(self) -> None:
        """Mark the current branch as skipped, unless skipping is disabled or we are at the root."""
        if self.disable_branch_skip:
            return
        open_stmts = self._scope().branch_stmts
        # Only skip branch if we're outside of "root" branch statement.
        if len(open_stmts) > 1:
            open_stmts[-1].skip_branch()

    def record_definition(self, name: str) -> None:
        """Record `name` as defined in the current branch of the innermost scope."""
        assert len(self.scopes) > 0
        assert len(self.scopes[-1].branch_stmts) > 0
        self._scope().branch_stmts[-1].record_definition(name)

    def delete_var(self, name: str) -> None:
        """Drop `name` from the current branch of the innermost scope."""
        assert len(self.scopes) > 0
        assert len(self.scopes[-1].branch_stmts) > 0
        self._scope().branch_stmts[-1].delete_var(name)

    def record_undefined_ref(self, o: NameExpr) -> None:
        """Records an undefined reference. These can later be retrieved via `pop_undefined_ref`."""
        assert len(self.scopes) > 0
        self._scope().record_undefined_ref(o)

    def pop_undefined_ref(self, name: str) -> set[NameExpr]:
        """Remove and return the references to `name` recorded while it was undefined."""
        assert len(self.scopes) > 0
        return self._scope().pop_undefined_ref(name)

    def is_possibly_undefined(self, name: str) -> bool:
        """Return True if `name` is only possibly defined in the current branch."""
        open_stmts = self._scope().branch_stmts
        assert len(open_stmts) > 0
        # A variable is possibly undefined if it's in the set of `may_be_defined`
        # of the current branch.
        return open_stmts[-1].is_possibly_undefined(name)

    def is_defined_in_different_branch(self, name: str) -> bool:
        """This will return true if a variable is defined in a branch that's not the current branch."""
        open_stmts = self._scope().branch_stmts
        assert len(open_stmts) > 0
        if not open_stmts[-1].is_undefined(name):
            return False
        # Unknown in the current branch: check every open statement of this scope.
        return any(candidate.is_defined_in_a_branch(name) for candidate in open_stmts)

    def is_undefined(self, name: str) -> bool:
        """Return True if the current branch knows nothing about `name`."""
        open_stmts = self._scope().branch_stmts
        assert len(open_stmts) > 0
        return open_stmts[-1].is_undefined(name)
class Loop:
    """Per-loop bookkeeping: whether a `break` has been recorded for the loop."""

    def __init__(self) -> None:
        # Starts out False; set to True from outside once a `break` is seen.
        self.has_break = False
class PossiblyUndefinedVariableVisitor(ExtendedTraverserVisitor):
    """Detects the following cases:

    - A variable that's defined only part of the time.
    - A variable that's used before its definition.

    An example of a partial definition:

        if foo():
            x = 1
        print(x)  # Error: "x" may be undefined.

    An example of a use before definition:

        x = y
        y: int = 2

    Note that this code does not detect variables not defined in any of the branches -- that is
    handled by the semantic analyzer.
    """

    def __init__(
        self,
        msg: MessageBuilder,
        type_map: dict[Expression, Type],
        options: Options,
        names: SymbolTable,
    ) -> None:
        self.msg = msg
        self.type_map = type_map
        self.options = options
        # Fall back to an empty table when the module has no `__builtins__` entry.
        self.builtins = SymbolTable()
        builtins_mod = names.get("__builtins__", None)
        if builtins_mod:
            assert isinstance(builtins_mod.node, MypyFile)
            self.builtins = builtins_mod.node.names
        # Stack of the loops currently being visited (innermost last).
        self.loops: list[Loop] = []
        # Number of `try` statements currently being visited.
        self.try_depth = 0
        self.tracker = DefinedVariableTracker()
        # Implicit module attributes are treated as defined from the start.
        for attr_name in implicit_module_attrs:
            self.tracker.record_definition(attr_name)

    def var_used_before_def(self, name: str, context: Context) -> None:
        """Report a used-before-definition error if that error code is enabled."""
        if not self.msg.errors.is_error_code_enabled(errorcodes.USED_BEFORE_DEF):
            return
        self.msg.var_used_before_def(name, context)

    def variable_may_be_undefined(self, name: str, context: Context) -> None:
        """Report a possibly-undefined error if that error code is enabled."""
        if not self.msg.errors.is_error_code_enabled(errorcodes.POSSIBLY_UNDEFINED):
            return
        self.msg.variable_may_be_undefined(name, context)

    def process_definition(self, name: str) -> None:
        """Record a definition of `name`, reporting any earlier references to it."""
        # Was this name previously used? If yes, it's a used-before-definition error.
        # Errors in class scopes are caught by the semantic analyzer, so class
        # scopes are not checked here.
        if not self.tracker.in_scope(ScopeType.Class):
            # Inside a loop the earlier reference is reported as "may be undefined".
            report = self.variable_may_be_undefined if self.loops else self.var_used_before_def
            for ref in self.tracker.pop_undefined_ref(name):
                report(name, ref)
        self.tracker.record_definition(name)

    def visit_global_decl(self, o: GlobalDecl) -> None:
        for declared in o.names:
            self.process_definition(declared)
        super().visit_global_decl(o)

    def visit_nonlocal_decl(self, o: NonlocalDecl) -> None:
        for declared in o.names:
            self.process_definition(declared)
        super().visit_nonlocal_decl(o)

    def process_lvalue(self, lvalue: Lvalue | None) -> None:
        """Record definitions for every plain name found in an assignment target."""
        if isinstance(lvalue, NameExpr):
            self.process_definition(lvalue.name)
        elif isinstance(lvalue, StarExpr):
            self.process_lvalue(lvalue.expr)
        elif isinstance(lvalue, (ListExpr, TupleExpr)):
            for element in lvalue.items:
                self.process_lvalue(element)
        # Any other kind of target is not handled here.

    def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
        for target in o.lvalues:
            self.process_lvalue(target)
        super().visit_assignment_stmt(o)

    def visit_assignment_expr(self, o: AssignmentExpr) -> None:
        # Visit the value first, then record the target as defined.
        o.value.accept(self)
        self.process_lvalue(o.target)

    def visit_if_stmt(self, o: IfStmt) -> None:
        for condition in o.expr:
            condition.accept(self)
        self.tracker.start_branch_statement()
        for body in o.body:
            if b_unreachable := body.is_unreachable:
                # Unreachable bodies are not visited and do not open a new branch.
                continue
            body.accept(self)
            self.tracker.next_branch()
        if o.else_body:
            if not o.else_body.is_unreachable:
                o.else_body.accept(self)
            else:
                self.tracker.skip_branch()
        self.tracker.end_branch_statement()

    def visit_match_stmt(self, o: MatchStmt) -> None:
        o.subject.accept(self)
        self.tracker.start_branch_statement()
        for i, pattern in enumerate(o.patterns):
            pattern.accept(self)
            guard = o.guards[i]
            if guard is not None:
                guard.accept(self)
            body = o.bodies[i]
            if body.is_unreachable:
                self.tracker.skip_branch()
            else:
                body.accept(self)
            # A new branch is opened after every pattern that is not inferred
            # to always match.
            is_catchall = infer_pattern_value(pattern) == ALWAYS_TRUE
            if not is_catchall:
                self.tracker.next_branch()
        self.tracker.end_branch_statement()

    def visit_func_def(self, o: FuncDef) -> None:
        self.process_definition(o.name)
        super().visit_func_def(o)

    def visit_func(self, o: FuncItem) -> None:
        if o.is_dynamic() and not self.options.check_untyped_defs:
            return
        args = o.arguments or []
        # Process initializers (defaults) outside the function scope.
        for arg in args:
            default = arg.initializer
            if default is not None:
                default.accept(self)
        self.tracker.enter_scope(ScopeType.Func)
        # Arguments count as defined inside the function scope.
        for arg in args:
            self.process_definition(arg.variable.name)
            super().visit_var(arg.variable)
        o.body.accept(self)
        self.tracker.exit_scope()

    def visit_generator_expr(self, o: GeneratorExpr) -> None:
        self.tracker.enter_scope(ScopeType.Generator)
        for index in o.indices:
            self.process_lvalue(index)
        super().visit_generator_expr(o)
        self.tracker.exit_scope()

    def visit_dictionary_comprehension(self, o: DictionaryComprehension) -> None:
        self.tracker.enter_scope(ScopeType.Generator)
        for index in o.indices:
            self.process_lvalue(index)
        super().visit_dictionary_comprehension(o)
        self.tracker.exit_scope()

    def visit_for_stmt(self, o: ForStmt) -> None:
        o.expr.accept(self)
        self.process_lvalue(o.index)
        o.index.accept(self)
        self.tracker.start_branch_statement()
        loop = Loop()
        self.loops.append(loop)
        o.body.accept(self)
        # Opening a second, empty branch makes everything defined in the body conditional.
        self.tracker.next_branch()
        self.tracker.end_branch_statement()
        if o.else_body is not None:
            # If the loop has a `break` inside, `else` is executed conditionally.
            # If the loop doesn't have a `break` either the function will return or
            # execute the `else`.
            else_is_conditional = loop.has_break
            if else_is_conditional:
                self.tracker.start_branch_statement()
                self.tracker.next_branch()
            o.else_body.accept(self)
            if else_is_conditional:
                self.tracker.end_branch_statement()
        self.loops.pop()

    def visit_return_stmt(self, o: ReturnStmt) -> None:
        super().visit_return_stmt(o)
        self.tracker.skip_branch()

    def visit_lambda_expr(self, o: LambdaExpr) -> None:
        self.tracker.enter_scope(ScopeType.Func)
        super().visit_lambda_expr(o)
        self.tracker.exit_scope()

    def visit_assert_stmt(self, o: AssertStmt) -> None:
        super().visit_assert_stmt(o)
        # An assert on a false literal is treated like the other jump statements.
        if checker.is_false_literal(o.expr):
            self.tracker.skip_branch()

    def visit_raise_stmt(self, o: RaiseStmt) -> None:
        super().visit_raise_stmt(o)
        self.tracker.skip_branch()

    def visit_continue_stmt(self, o: ContinueStmt) -> None:
        super().visit_continue_stmt(o)
        self.tracker.skip_branch()

    def visit_break_stmt(self, o: BreakStmt) -> None:
        super().visit_break_stmt(o)
        if self.loops:
            self.loops[-1].has_break = True
        self.tracker.skip_branch()

    def visit_expression_stmt(self, o: ExpressionStmt) -> None:
        # An expression whose recorded type is UninhabitedType is treated like
        # a jump statement.
        expr_type = self.type_map.get(o.expr, None)
        if isinstance(expr_type, UninhabitedType):
            self.tracker.skip_branch()
        super().visit_expression_stmt(o)

    def visit_try_stmt(self, o: TryStmt) -> None:
        """
        Note that finding undefined vars in `finally` requires different handling from
        the rest of the code. In particular, we want to disallow skipping branches due to jump
        statements in except/else clauses for finally but not for other cases. Imagine a case like:

            def f() -> int:
                try:
                    x = 1
                except:
                    # This jump statement needs to be handled differently depending on whether or
                    # not we're trying to process `finally` or not.
                    return 0
                finally:
                    # `x` may be undefined here.
                    pass
                # `x` is always defined here.
                return x
        """
        self.try_depth += 1
        if o.finally_body is not None:
            # In order to find undefined vars in `finally`, we need to
            # process try/except with branch skipping disabled. However, for the rest of the code
            # after finally, we need to process try/except with branch skipping enabled.
            # Therefore, we need to process try/finally twice.
            # Because processing is not idempotent, we should make a copy of the tracker.
            saved_tracker = self.tracker.copy()
            self.tracker.disable_branch_skip = True
            self.process_try_stmt(o)
            self.tracker = saved_tracker
        self.process_try_stmt(o)
        self.try_depth -= 1

    def process_try_stmt(self, o: TryStmt) -> None:
        """
        Processes try statement decomposing it into the following:

            if ...:
                body
                else_body
            elif ...:
                except 1
            elif ...:
                except 2
            else:
                except n
            finally
        """
        self.tracker.start_branch_statement()
        o.body.accept(self)
        if o.else_body is not None:
            o.else_body.accept(self)
        if o.handlers:
            assert len(o.handlers) == len(o.vars) == len(o.types)
            for i, handler in enumerate(o.handlers):
                self.tracker.next_branch()
                exc_type = o.types[i]
                if exc_type is not None:
                    exc_type.accept(self)
                var = o.vars[i]
                if var is not None:
                    self.process_definition(var.name)
                    var.accept(self)
                handler.accept(self)
                # The handler's variable is removed again once the handler has been visited.
                if var is not None:
                    self.tracker.delete_var(var.name)
        self.tracker.end_branch_statement()
        if o.finally_body is not None:
            o.finally_body.accept(self)

    def visit_while_stmt(self, o: WhileStmt) -> None:
        o.expr.accept(self)
        self.tracker.start_branch_statement()
        loop = Loop()
        self.loops.append(loop)
        o.body.accept(self)
        else_is_conditional = loop.has_break
        if not checker.is_true_literal(o.expr):
            # If this is a loop like `while True`, we can consider the body to be
            # a single branch statement (we're guaranteed that the body is executed at least once).
            # If not, call next_branch() to make all variables defined there conditional.
            self.tracker.next_branch()
        self.tracker.end_branch_statement()
        if o.else_body is not None:
            # If the loop has a `break` inside, `else` is executed conditionally.
            # If the loop doesn't have a `break` either the function will return or
            # execute the `else`.
            if else_is_conditional:
                self.tracker.start_branch_statement()
                self.tracker.next_branch()
            if o.else_body:
                o.else_body.accept(self)
            if else_is_conditional:
                self.tracker.end_branch_statement()
        self.loops.pop()

    def visit_as_pattern(self, o: AsPattern) -> None:
        if o.name is not None:
            self.process_lvalue(o.name)
        super().visit_as_pattern(o)

    def visit_starred_pattern(self, o: StarredPattern) -> None:
        if o.capture is not None:
            self.process_lvalue(o.capture)
        super().visit_starred_pattern(o)

    def visit_name_expr(self, o: NameExpr) -> None:
        name = o.name
        if name in self.builtins and self.tracker.in_scope(ScopeType.Global):
            return
        if self.tracker.is_possibly_undefined(name):
            # A variable is only defined in some branches.
            self.variable_may_be_undefined(name, o)
            # We don't want to report the error on the same variable multiple times.
            self.tracker.record_definition(name)
        elif self.tracker.is_defined_in_different_branch(name):
            # A variable is defined in one branch but used in a different branch.
            if self.loops or self.try_depth > 0:
                # If we're in a loop or in a try, we can't be sure that this variable
                # is undefined. Report it as "may be undefined".
                self.variable_may_be_undefined(name, o)
            else:
                self.var_used_before_def(name, o)
        elif self.tracker.is_undefined(name):
            # A variable is undefined. It could be due to two things:
            # 1. A variable is just totally undefined
            # 2. The variable is defined later in the code.
            # Case (1) will be caught by semantic analyzer. Case (2) is a forward ref that should
            # be caught by this visitor. Save the ref for later, so that if we see a definition,
            # we know it's a used-before-definition scenario.
            self.tracker.record_undefined_ref(o)
        super().visit_name_expr(o)

    def visit_with_stmt(self, o: WithStmt) -> None:
        for context_expr, target in zip(o.expr, o.target):
            context_expr.accept(self)
            self.process_lvalue(target)
        o.body.accept(self)

    def visit_class_def(self, o: ClassDef) -> None:
        # The class name is defined in the enclosing scope; the body gets its own scope.
        self.process_definition(o.name)
        self.tracker.enter_scope(ScopeType.Class)
        super().visit_class_def(o)
        self.tracker.exit_scope()

    def visit_import(self, o: Import) -> None:
        for mod, alias in o.ids:
            if alias is not None:
                self.tracker.record_definition(alias)
                continue
            # When you do `import x.y`, only `x` becomes defined.
            names = mod.split(".")
            if names:
                # `names` should always be nonempty, but we don't want mypy
                # to crash on invalid code.
                self.tracker.record_definition(names[0])
        super().visit_import(o)

    def visit_import_from(self, o: ImportFrom) -> None:
        for mod, alias in o.names:
            # The alias, when present, is the name that becomes defined.
            bound_name = mod if alias is None else alias
            self.tracker.record_definition(bound_name)
        super().visit_import_from(o)
|