| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343 |
- from __future__ import annotations
- from mypy.argmap import map_actuals_to_formals
- from mypy.fixup import TypeFixer
- from mypy.nodes import (
- ARG_POS,
- MDEF,
- SYMBOL_FUNCBASE_TYPES,
- Argument,
- Block,
- CallExpr,
- ClassDef,
- Decorator,
- Expression,
- FuncDef,
- JsonDict,
- NameExpr,
- Node,
- PassStmt,
- RefExpr,
- SymbolTableNode,
- Var,
- )
- from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface
- from mypy.semanal_shared import (
- ALLOW_INCOMPATIBLE_OVERRIDE,
- parse_bool,
- require_bool_literal_argument,
- set_callable_name,
- )
- from mypy.typeops import try_getting_str_literals as try_getting_str_literals
- from mypy.types import (
- AnyType,
- CallableType,
- Instance,
- LiteralType,
- NoneType,
- Overloaded,
- Type,
- TypeOfAny,
- TypeType,
- TypeVarType,
- deserialize_type,
- get_proper_type,
- )
- from mypy.types_utils import is_optional
- from mypy.typevars import fill_typevars
- from mypy.util import get_unique_redefinition_name
- def _get_decorator_bool_argument(ctx: ClassDefContext, name: str, default: bool) -> bool:
- """Return the bool argument for the decorator.
- This handles both @decorator(...) and @decorator.
- """
- if isinstance(ctx.reason, CallExpr):
- return _get_bool_argument(ctx, ctx.reason, name, default)
- else:
- return default
- def _get_bool_argument(ctx: ClassDefContext, expr: CallExpr, name: str, default: bool) -> bool:
- """Return the boolean value for an argument to a call or the
- default if it's not found.
- """
- attr_value = _get_argument(expr, name)
- if attr_value:
- return require_bool_literal_argument(ctx.api, attr_value, name, default)
- return default
- def _get_argument(call: CallExpr, name: str) -> Expression | None:
- """Return the expression for the specific argument."""
- # To do this we use the CallableType of the callee to find the FormalArgument,
- # then walk the actual CallExpr looking for the appropriate argument.
- #
- # Note: I'm not hard-coding the index so that in the future we can support other
- # attrib and class makers.
- callee_type = _get_callee_type(call)
- if not callee_type:
- return None
- argument = callee_type.argument_by_name(name)
- if not argument:
- return None
- assert argument.name
- for i, (attr_name, attr_value) in enumerate(zip(call.arg_names, call.args)):
- if argument.pos is not None and not attr_name and i == argument.pos:
- return attr_value
- if attr_name == argument.name:
- return attr_value
- return None
- def find_shallow_matching_overload_item(overload: Overloaded, call: CallExpr) -> CallableType:
- """Perform limited lookup of a matching overload item.
- Full overload resolution is only supported during type checking, but plugins
- sometimes need to resolve overloads. This can be used in some such use cases.
- Resolve overloads based on these things only:
- * Match using argument kinds and names
- * If formal argument has type None, only accept the "None" expression in the callee
- * If formal argument has type Literal[True] or Literal[False], only accept the
- relevant bool literal
- Return the first matching overload item, or the last one if nothing matches.
- """
- for item in overload.items[:-1]:
- ok = True
- mapped = map_actuals_to_formals(
- call.arg_kinds,
- call.arg_names,
- item.arg_kinds,
- item.arg_names,
- lambda i: AnyType(TypeOfAny.special_form),
- )
- # Look for extra actuals
- matched_actuals = set()
- for actuals in mapped:
- matched_actuals.update(actuals)
- if any(i not in matched_actuals for i in range(len(call.args))):
- ok = False
- for arg_type, kind, actuals in zip(item.arg_types, item.arg_kinds, mapped):
- if kind.is_required() and not actuals:
- # Missing required argument
- ok = False
- break
- elif actuals:
- args = [call.args[i] for i in actuals]
- arg_type = get_proper_type(arg_type)
- arg_none = any(isinstance(arg, NameExpr) and arg.name == "None" for arg in args)
- if isinstance(arg_type, NoneType):
- if not arg_none:
- ok = False
- break
- elif (
- arg_none
- and not is_optional(arg_type)
- and not (
- isinstance(arg_type, Instance)
- and arg_type.type.fullname == "builtins.object"
- )
- and not isinstance(arg_type, AnyType)
- ):
- ok = False
- break
- elif isinstance(arg_type, LiteralType) and type(arg_type.value) is bool:
- if not any(parse_bool(arg) == arg_type.value for arg in args):
- ok = False
- break
- if ok:
- return item
- return overload.items[-1]
- def _get_callee_type(call: CallExpr) -> CallableType | None:
- """Return the type of the callee, regardless of its syntatic form."""
- callee_node: Node | None = call.callee
- if isinstance(callee_node, RefExpr):
- callee_node = callee_node.node
- # Some decorators may be using typing.dataclass_transform, which is itself a decorator, so we
- # need to unwrap them to get at the true callee
- if isinstance(callee_node, Decorator):
- callee_node = callee_node.func
- if isinstance(callee_node, (Var, SYMBOL_FUNCBASE_TYPES)) and callee_node.type:
- callee_node_type = get_proper_type(callee_node.type)
- if isinstance(callee_node_type, Overloaded):
- return find_shallow_matching_overload_item(callee_node_type, call)
- elif isinstance(callee_node_type, CallableType):
- return callee_node_type
- return None
- def add_method(
- ctx: ClassDefContext,
- name: str,
- args: list[Argument],
- return_type: Type,
- self_type: Type | None = None,
- tvar_def: TypeVarType | None = None,
- is_classmethod: bool = False,
- is_staticmethod: bool = False,
- ) -> None:
- """
- Adds a new method to a class.
- Deprecated, use add_method_to_class() instead.
- """
- add_method_to_class(
- ctx.api,
- ctx.cls,
- name=name,
- args=args,
- return_type=return_type,
- self_type=self_type,
- tvar_def=tvar_def,
- is_classmethod=is_classmethod,
- is_staticmethod=is_staticmethod,
- )
- def add_method_to_class(
- api: SemanticAnalyzerPluginInterface | CheckerPluginInterface,
- cls: ClassDef,
- name: str,
- args: list[Argument],
- return_type: Type,
- self_type: Type | None = None,
- tvar_def: TypeVarType | None = None,
- is_classmethod: bool = False,
- is_staticmethod: bool = False,
- ) -> None:
- """Adds a new method to a class definition."""
- assert not (
- is_classmethod is True and is_staticmethod is True
- ), "Can't add a new method that's both staticmethod and classmethod."
- info = cls.info
- # First remove any previously generated methods with the same name
- # to avoid clashes and problems in the semantic analyzer.
- if name in info.names:
- sym = info.names[name]
- if sym.plugin_generated and isinstance(sym.node, FuncDef):
- cls.defs.body.remove(sym.node)
- if isinstance(api, SemanticAnalyzerPluginInterface):
- function_type = api.named_type("builtins.function")
- else:
- function_type = api.named_generic_type("builtins.function", [])
- if is_classmethod:
- self_type = self_type or TypeType(fill_typevars(info))
- first = [Argument(Var("_cls"), self_type, None, ARG_POS, True)]
- elif is_staticmethod:
- first = []
- else:
- self_type = self_type or fill_typevars(info)
- first = [Argument(Var("self"), self_type, None, ARG_POS)]
- args = first + args
- arg_types, arg_names, arg_kinds = [], [], []
- for arg in args:
- assert arg.type_annotation, "All arguments must be fully typed."
- arg_types.append(arg.type_annotation)
- arg_names.append(arg.variable.name)
- arg_kinds.append(arg.kind)
- signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type)
- if tvar_def:
- signature.variables = [tvar_def]
- func = FuncDef(name, args, Block([PassStmt()]))
- func.info = info
- func.type = set_callable_name(signature, func)
- func.is_class = is_classmethod
- func.is_static = is_staticmethod
- func._fullname = info.fullname + "." + name
- func.line = info.line
- # NOTE: we would like the plugin generated node to dominate, but we still
- # need to keep any existing definitions so they get semantically analyzed.
- if name in info.names:
- # Get a nice unique name instead.
- r_name = get_unique_redefinition_name(name, info.names)
- info.names[r_name] = info.names[name]
- # Add decorator for is_staticmethod. It's unnecessary for is_classmethod.
- if is_staticmethod:
- func.is_decorated = True
- v = Var(name, func.type)
- v.info = info
- v._fullname = func._fullname
- v.is_staticmethod = True
- dec = Decorator(func, [], v)
- dec.line = info.line
- sym = SymbolTableNode(MDEF, dec)
- else:
- sym = SymbolTableNode(MDEF, func)
- sym.plugin_generated = True
- info.names[name] = sym
- info.defn.defs.body.append(func)
- def add_attribute_to_class(
- api: SemanticAnalyzerPluginInterface,
- cls: ClassDef,
- name: str,
- typ: Type,
- final: bool = False,
- no_serialize: bool = False,
- override_allow_incompatible: bool = False,
- fullname: str | None = None,
- is_classvar: bool = False,
- ) -> None:
- """
- Adds a new attribute to a class definition.
- This currently only generates the symbol table entry and no corresponding AssignmentStatement
- """
- info = cls.info
- # NOTE: we would like the plugin generated node to dominate, but we still
- # need to keep any existing definitions so they get semantically analyzed.
- if name in info.names:
- # Get a nice unique name instead.
- r_name = get_unique_redefinition_name(name, info.names)
- info.names[r_name] = info.names[name]
- node = Var(name, typ)
- node.info = info
- node.is_final = final
- node.is_classvar = is_classvar
- if name in ALLOW_INCOMPATIBLE_OVERRIDE:
- node.allow_incompatible_override = True
- else:
- node.allow_incompatible_override = override_allow_incompatible
- if fullname:
- node._fullname = fullname
- else:
- node._fullname = info.fullname + "." + name
- info.names[name] = SymbolTableNode(
- MDEF, node, plugin_generated=True, no_serialize=no_serialize
- )
- def deserialize_and_fixup_type(data: str | JsonDict, api: SemanticAnalyzerPluginInterface) -> Type:
- typ = deserialize_type(data)
- typ.accept(TypeFixer(api.modules, allow_missing=False))
- return typ
|