| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502 |
- from __future__ import annotations
- from functools import partial
- from typing import Callable
- import mypy.errorcodes as codes
- from mypy import message_registry
- from mypy.nodes import DictExpr, IntExpr, StrExpr, UnaryExpr
- from mypy.plugin import (
- AttributeContext,
- ClassDefContext,
- FunctionContext,
- FunctionSigContext,
- MethodContext,
- MethodSigContext,
- Plugin,
- )
- from mypy.plugins.common import try_getting_str_literals
- from mypy.subtypes import is_subtype
- from mypy.typeops import is_literal_type_like, make_simplified_union
- from mypy.types import (
- TPDICT_FB_NAMES,
- AnyType,
- CallableType,
- FunctionLike,
- Instance,
- LiteralType,
- NoneType,
- TupleType,
- Type,
- TypedDictType,
- TypeOfAny,
- TypeVarType,
- UnionType,
- get_proper_type,
- get_proper_types,
- )
class DefaultPlugin(Plugin):
    """Type checker plugin that is enabled by default."""

    # Fully qualified names of the TypedDict fallback methods hooked below.
    # They depend only on the module-level TPDICT_FB_NAMES constant, so they are
    # built once at class creation instead of being rebuilt on every hook lookup.
    _TPDICT_SETDEFAULT_NAMES = frozenset(f"{n}.setdefault" for n in TPDICT_FB_NAMES)
    _TPDICT_POP_NAMES = frozenset(f"{n}.pop" for n in TPDICT_FB_NAMES)
    _TPDICT_UPDATE_NAMES = frozenset(f"{n}.update" for n in TPDICT_FB_NAMES)
    _TPDICT_DELITEM_NAMES = frozenset(f"{n}.__delitem__" for n in TPDICT_FB_NAMES)

    def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
        """Return the call-type callback for the function ``fullname``, or None."""
        # Plugin submodules are imported inside each hook rather than at module
        # level (presumably to avoid import cycles -- confirm before moving them).
        from mypy.plugins import ctypes, singledispatch

        if fullname == "_ctypes.Array":
            return ctypes.array_constructor_callback
        elif fullname == "functools.singledispatch":
            return singledispatch.create_singledispatch_function_callback
        return None

    def get_function_signature_hook(
        self, fullname: str
    ) -> Callable[[FunctionSigContext], FunctionLike] | None:
        """Return the signature callback for the function ``fullname``, or None."""
        from mypy.plugins import attrs, dataclasses

        if fullname in ("attr.evolve", "attrs.evolve", "attr.assoc", "attrs.assoc"):
            return attrs.evolve_function_sig_callback
        elif fullname == "dataclasses.replace":
            return dataclasses.replace_function_sig_callback
        return None

    def get_method_signature_hook(
        self, fullname: str
    ) -> Callable[[MethodSigContext], FunctionLike] | None:
        """Return the signature callback for the method ``fullname``, or None."""
        from mypy.plugins import ctypes, singledispatch

        if fullname == "typing.Mapping.get":
            return typed_dict_get_signature_callback
        elif fullname in self._TPDICT_SETDEFAULT_NAMES:
            return typed_dict_setdefault_signature_callback
        elif fullname in self._TPDICT_POP_NAMES:
            return typed_dict_pop_signature_callback
        elif fullname in self._TPDICT_UPDATE_NAMES:
            return typed_dict_update_signature_callback
        elif fullname == "_ctypes.Array.__setitem__":
            return ctypes.array_setitem_callback
        elif fullname == singledispatch.SINGLEDISPATCH_CALLABLE_CALL_METHOD:
            return singledispatch.call_singledispatch_function_callback
        return None

    def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None:
        """Return the call-type callback for the method ``fullname``, or None."""
        from mypy.plugins import ctypes, singledispatch

        if fullname == "typing.Mapping.get":
            return typed_dict_get_callback
        elif fullname == "builtins.int.__pow__":
            return int_pow_callback
        elif fullname == "builtins.int.__neg__":
            return int_neg_callback
        elif fullname in ("builtins.tuple.__mul__", "builtins.tuple.__rmul__"):
            return tuple_mul_callback
        elif fullname in self._TPDICT_SETDEFAULT_NAMES:
            return typed_dict_setdefault_callback
        elif fullname in self._TPDICT_POP_NAMES:
            return typed_dict_pop_callback
        elif fullname in self._TPDICT_DELITEM_NAMES:
            return typed_dict_delitem_callback
        elif fullname == "_ctypes.Array.__getitem__":
            return ctypes.array_getitem_callback
        elif fullname == "_ctypes.Array.__iter__":
            return ctypes.array_iter_callback
        elif fullname == singledispatch.SINGLEDISPATCH_REGISTER_METHOD:
            return singledispatch.singledispatch_register_callback
        elif fullname == singledispatch.REGISTER_CALLABLE_CALL_METHOD:
            return singledispatch.call_singledispatch_function_after_register_argument
        return None

    def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
        """Return the attribute-type callback for the attribute ``fullname``, or None."""
        from mypy.plugins import ctypes, enums

        if fullname == "_ctypes.Array.value":
            return ctypes.array_value_callback
        elif fullname == "_ctypes.Array.raw":
            return ctypes.array_raw_callback
        elif fullname in enums.ENUM_NAME_ACCESS:
            return enums.enum_name_callback
        elif fullname in enums.ENUM_VALUE_ACCESS:
            return enums.enum_value_callback
        return None

    def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
        """Return the first-pass class decorator callback for ``fullname``, or None."""
        from mypy.plugins import attrs, dataclasses

        # These dataclass and attrs hooks run in the main semantic analysis pass
        # and only tag known dataclasses/attrs classes, so that the second
        # hooks (in get_class_decorator_hook_2) can detect dataclasses/attrs classes
        # in the MRO.
        if fullname in dataclasses.dataclass_makers:
            return dataclasses.dataclass_tag_callback
        if (
            fullname in attrs.attr_class_makers
            or fullname in attrs.attr_dataclass_makers
            or fullname in attrs.attr_frozen_makers
            or fullname in attrs.attr_define_makers
        ):
            return attrs.attr_tag_callback
        return None

    def get_class_decorator_hook_2(
        self, fullname: str
    ) -> Callable[[ClassDefContext], bool] | None:
        """Return the second-pass class decorator callback for ``fullname``, or None."""
        from mypy.plugins import attrs, dataclasses, functools

        if fullname in dataclasses.dataclass_makers:
            return dataclasses.dataclass_class_maker_callback
        elif fullname in functools.functools_total_ordering_makers:
            return functools.functools_total_ordering_maker_callback
        elif fullname in attrs.attr_class_makers:
            return attrs.attr_class_maker_callback
        elif fullname in attrs.attr_dataclass_makers:
            return partial(attrs.attr_class_maker_callback, auto_attribs_default=True)
        elif fullname in attrs.attr_frozen_makers:
            return partial(
                attrs.attr_class_maker_callback, auto_attribs_default=None, frozen_default=True
            )
        elif fullname in attrs.attr_define_makers:
            return partial(
                attrs.attr_class_maker_callback, auto_attribs_default=None, slots_default=True
            )
        return None
def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Refine the signature of ``TypedDict.get(key, default)`` for a literal key.

    When the key is a string literal naming an item of the TypedDict, the
    ``default`` parameter is retyped as the union of that item's type and the
    signature's single type variable. The type variable still accepts any
    default, so the item type only serves as inference context for the
    default argument. In every other case the default signature is returned.
    """
    sig = ctx.default_signature
    if not (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.args) == 2
        and len(ctx.args[0]) == 1
        and isinstance(ctx.args[0][0], StrExpr)
        and len(sig.arg_types) == 2
        and len(sig.variables) == 1
        and len(ctx.args[1]) == 1
    ):
        return sig

    item_type = get_proper_type(ctx.type.items.get(ctx.args[0][0].value))
    if not item_type:
        return sig

    default_expr = ctx.args[1][0]
    if (
        isinstance(item_type, TypedDictType)
        and isinstance(default_expr, DictExpr)
        and not default_expr.items
    ):
        # The caller passed an empty ``{}`` default for a TypedDict-valued item:
        # drop all required keys so that ``{}`` is compatible with it.
        item_type = item_type.copy_modified(required_keys=set())

    type_var = sig.variables[0]
    assert isinstance(type_var, TypeVarType)
    return sig.copy_modified(
        arg_types=[sig.arg_types[0], make_simplified_union([item_type, type_var])],
        ret_type=sig.ret_type,
    )
def typed_dict_get_callback(ctx: MethodContext) -> Type:
    """Infer a precise return type for ``TypedDict.get`` with literal key(s).

    Each key that the first argument may take contributes its item type:

    * With no default, ``None`` is also added to the result.
    * With a default, the default's type is also added. The exception is an
      empty ``{}`` default for a TypedDict-valued item, which contributes that
      TypedDict with every key optional instead.

    The default return type is used when the keys cannot be determined or a
    key is not an item of the TypedDict.
    """
    if not (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.arg_types) >= 1
        and len(ctx.arg_types[0]) == 1
    ):
        return ctx.default_return_type

    keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
    if keys is None:
        return ctx.default_return_type

    result_items: list[Type] = []
    for key in keys:
        item_type = get_proper_type(ctx.type.items.get(key))
        if item_type is None:
            return ctx.default_return_type

        if len(ctx.arg_types) == 1:
            result_items.append(item_type)
            continue
        if not (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1):
            continue

        default_expr = ctx.args[1][0]
        if (
            isinstance(default_expr, DictExpr)
            and not default_expr.items
            and isinstance(item_type, TypedDictType)
        ):
            # Special case '{}' as the default for a typed dict type.
            result_items.append(item_type.copy_modified(required_keys=set()))
        else:
            result_items.extend((item_type, ctx.arg_types[1][0]))

    if len(ctx.arg_types) == 1:
        # No default argument: the lookup may produce None.
        result_items.append(NoneType())
    return make_simplified_union(result_items)
def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Refine the signature of ``TypedDict.pop(key, default)``.

    The key parameter is always retyped as ``builtins.str``. When the key is a
    string literal naming an item, the ``default`` parameter and the return
    type both become the union of that item's type and the signature's single
    type variable. That gives the default argument type context while still
    accepting anything. Otherwise the declared default parameter type is kept.
    """
    sig = ctx.default_signature
    str_type = ctx.api.named_generic_type("builtins.str", [])

    item_type: Type | None = None
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.args) == 2
        and len(ctx.args[0]) == 1
        and isinstance(ctx.args[0][0], StrExpr)
        and len(sig.arg_types) == 2
        and len(sig.variables) == 1
        and len(ctx.args[1]) == 1
    ):
        item_type = ctx.type.items.get(ctx.args[0][0].value)

    if not item_type:
        return sig.copy_modified(arg_types=[str_type, sig.arg_types[1]])

    type_var = sig.variables[0]
    assert isinstance(type_var, TypeVarType)
    item_or_default = make_simplified_union([item_type, type_var])
    return sig.copy_modified(arg_types=[str_type, item_or_default], ret_type=item_or_default)
def typed_dict_pop_callback(ctx: MethodContext) -> Type:
    """Type check and infer a precise return type for TypedDict.pop.

    The key must be a string literal (or literal-typed expression); otherwise
    a LITERAL_REQ error is reported and ``Any`` is returned.

    * Popping a required key is reported, but inference continues.
    * Popping an unknown key is reported and yields ``Any``.

    The result is the union of the popped items' types, plus the default's
    type when a single default argument is passed.
    """
    if not (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.arg_types) >= 1
        and len(ctx.arg_types[0]) == 1
    ):
        return ctx.default_return_type

    keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
    if keys is None:
        ctx.api.fail(
            message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
            ctx.context,
            code=codes.LITERAL_REQ,
        )
        return AnyType(TypeOfAny.from_error)

    value_types: list[Type] = []
    for key in keys:
        if key in ctx.type.required_keys:
            # Report, but keep going so the return type is still inferred.
            ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)

        value_type = ctx.type.items.get(key)
        if value_type:
            value_types.append(value_type)
        else:
            ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
            return AnyType(TypeOfAny.from_error)

    # The guard above only guarantees the key's argument list. The matched
    # signature may have no ``default`` formal at all, so check before indexing.
    if len(ctx.args) < 2 or len(ctx.args[1]) == 0:
        # No default supplied: the result is the popped item's value type.
        return make_simplified_union(value_types)
    elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
        return make_simplified_union([*value_types, ctx.arg_types[1][0]])
    return ctx.default_return_type
def typed_dict_setdefault_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Refine the signature of ``TypedDict.setdefault(key, default)``.

    The key parameter is always retyped as ``builtins.str``. When the key is a
    string literal naming an item, the ``default`` parameter takes that item's
    type, which gives the default argument precise type context. Otherwise the
    declared default parameter type is kept.
    """
    sig = ctx.default_signature
    str_type = ctx.api.named_generic_type("builtins.str", [])

    item_type: Type | None = None
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.args) == 2
        and len(ctx.args[0]) == 1
        and isinstance(ctx.args[0][0], StrExpr)
        and len(sig.arg_types) == 2
        and len(ctx.args[1]) == 1
    ):
        item_type = ctx.type.items.get(ctx.args[0][0].value)

    default_param_type = item_type if item_type else sig.arg_types[1]
    return sig.copy_modified(arg_types=[str_type, default_param_type])
def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
    """Type check ``TypedDict.setdefault`` and infer a precise return type.

    The key must be a string literal (or literal-typed expression); otherwise
    a LITERAL_REQ error is reported and ``Any`` is returned. Each possible key
    must be an item of the TypedDict, and the default's type must be a subtype
    of that item's type. A failure in either check is reported and yields
    ``Any``. On success the result is the union of the items' types.
    """
    if not (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.arg_types) == 2
        and len(ctx.arg_types[0]) == 1
        and len(ctx.arg_types[1]) == 1
    ):
        return ctx.default_return_type

    keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
    if keys is None:
        ctx.api.fail(
            message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
            ctx.context,
            code=codes.LITERAL_REQ,
        )
        return AnyType(TypeOfAny.from_error)

    default_type = ctx.arg_types[1][0]
    item_types: list[Type] = []
    for key in keys:
        item_type = ctx.type.items.get(key)
        if item_type is None:
            ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
            return AnyType(TypeOfAny.from_error)

        # The signature callback cannot always pin down the key (e.g. a
        # variable whose type happens to be a Literal str). The default must
        # therefore be checked here against every item it might be stored into.
        if not is_subtype(default_type, item_type):
            ctx.api.msg.typeddict_setdefault_arguments_inconsistent(
                default_type, item_type, ctx.context
            )
            return AnyType(TypeOfAny.from_error)

        item_types.append(item_type)

    return make_simplified_union(item_types)
def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
    """Type check ``del td[key]`` on a TypedDict.

    The key must be a string literal (or literal-typed expression); otherwise
    a LITERAL_REQ error is reported and ``Any`` is returned. For each possible
    key, deleting a required key or an unknown key is reported. The default
    return type is kept in every non-error case.
    """
    if not (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.arg_types) == 1
        and len(ctx.arg_types[0]) == 1
    ):
        return ctx.default_return_type

    keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
    if keys is None:
        ctx.api.fail(
            message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
            ctx.context,
            code=codes.LITERAL_REQ,
        )
        return AnyType(TypeOfAny.from_error)

    td = ctx.type
    for key in keys:
        if key in td.required_keys:
            ctx.api.msg.typeddict_key_cannot_be_deleted(td, key, ctx.context)
        elif key not in td.items:
            ctx.api.msg.typeddict_key_not_found(td, key, ctx.context)
    return ctx.default_return_type
def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Try to infer a better signature type for TypedDict.update.

    The declared parameter type, when it is a TypedDict, is turned into an
    anonymous TypedDict with no required keys. If an argument is present, its
    type is inferred against that context with errors filtered out. For every
    TypedDict found (directly, or as a member of a union), a variant is built:

    * keys the argument requires are made required again;
    * unless ``extra_checks`` is enabled, ``item_names`` is set to the
      argument's keys.

    The resulting type (a union if several variants) replaces the parameter
    type. If the declared parameter is not a TypedDict, the signature is
    returned unchanged.
    """
    signature = ctx.default_signature
    if not (isinstance(ctx.type, TypedDictType) and len(signature.arg_types) == 1):
        return signature

    arg_type = get_proper_type(signature.arg_types[0])
    if not isinstance(arg_type, TypedDictType):
        # Nothing to refine. Bail out instead of asserting: an assert here would
        # crash the checker (or be skipped under ``python -O`` and fail later).
        return signature

    arg_type = arg_type.as_anonymous()
    arg_type = arg_type.copy_modified(required_keys=set())
    if ctx.args and ctx.args[0]:
        # Speculative inference only; errors from it are suppressed (presumably
        # the regular argument check reports any real problem).
        with ctx.api.msg.filter_errors():
            inferred = get_proper_type(
                ctx.api.get_expression_type(ctx.args[0][0], type_context=arg_type)
            )
        possible_tds: list[TypedDictType] = []
        if isinstance(inferred, TypedDictType):
            possible_tds = [inferred]
        elif isinstance(inferred, UnionType):
            possible_tds = [
                t
                for t in get_proper_types(inferred.relevant_items())
                if isinstance(t, TypedDictType)
            ]
        items: list[TypedDictType] = []
        for td in possible_tds:
            # Restore required-ness for keys the argument requires, limited to
            # keys the declared parameter actually has.
            item = arg_type.copy_modified(
                required_keys=(arg_type.required_keys | td.required_keys)
                & arg_type.items.keys()
            )
            if not ctx.api.options.extra_checks:
                item = item.copy_modified(item_names=list(td.items))
            items.append(item)
        if items:
            arg_type = make_simplified_union(items)
    return signature.copy_modified(arg_types=[arg_type])
def int_pow_callback(ctx: MethodContext) -> Type:
    """Infer a more precise return type for ``int.__pow__``.

    For ``x ** N`` the exponent must be an int literal, optionally negated,
    and no modulo argument may be given. In that case the result is
    ``builtins.int`` when ``N >= 0`` and ``builtins.float`` when ``N < 0``.
    Anything else keeps the default return type.
    """
    # int.__pow__ has an optional modulo argument, so two argument positions
    # are expected: exactly one exponent and an empty modulo slot.
    has_only_exponent = (
        len(ctx.arg_types) == 2 and len(ctx.arg_types[0]) == 1 and len(ctx.arg_types[1]) == 0
    )
    if not has_only_exponent:
        return ctx.default_return_type

    exponent_expr = ctx.args[0][0]
    if isinstance(exponent_expr, IntExpr):
        exponent = exponent_expr.value
    elif (
        isinstance(exponent_expr, UnaryExpr)
        and exponent_expr.op == "-"
        and isinstance(exponent_expr.expr, IntExpr)
    ):
        exponent = -exponent_expr.expr.value
    else:
        # Right operand not an int literal or a negated literal -- give up.
        return ctx.default_return_type

    result_name = "builtins.int" if exponent >= 0 else "builtins.float"
    return ctx.api.named_generic_type(result_name, [])
def int_neg_callback(ctx: MethodContext) -> Type:
    """Infer a more precise return type for int.__neg__.

    This is mainly used to infer the return type as LiteralType
    if the original underlying object is a LiteralType object.

    * An ``Instance`` with an int ``last_known_value``: in a literal-like type
      context the negated ``LiteralType`` is returned directly. Otherwise the
      instance is returned with its ``last_known_value`` replaced by the
      negated literal.
    * A ``LiteralType`` with an int value: the negated literal is returned.

    Anything else keeps the default return type.
    """
    if isinstance(ctx.type, Instance) and ctx.type.last_known_value is not None:
        value = ctx.type.last_known_value.value
        fallback = ctx.type.last_known_value.fallback
        if isinstance(value, int):
            if is_literal_type_like(ctx.api.type_context[-1]):
                return LiteralType(value=-value, fallback=fallback)
            else:
                # Use the literal's own fallback, as the branch above does.
                # ``ctx.type`` itself still carries the un-negated
                # ``last_known_value``, so it must not become the new literal's
                # fallback.
                return ctx.type.copy_modified(
                    last_known_value=LiteralType(
                        value=-value,
                        fallback=fallback,
                        line=ctx.type.line,
                        column=ctx.type.column,
                    )
                )
    elif isinstance(ctx.type, LiteralType):
        value = ctx.type.value
        fallback = ctx.type.fallback
        if isinstance(value, int):
            return LiteralType(value=-value, fallback=fallback)
    return ctx.default_return_type
def tuple_mul_callback(ctx: MethodContext) -> Type:
    """Infer a more precise return type for tuple.__mul__ and tuple.__rmul__.

    This is used to return a specific sized tuple if multiplied by Literal int.
    The multiplier's int value is read from the argument type in one of two
    ways:

    * from an ``Instance``'s ``last_known_value``;
    * from a ``LiteralType``.

    The tuple's item types are then repeated that many times using Python
    sequence repetition, so a value <= 0 yields an empty tuple type. Anything
    else keeps the default return type.
    """
    if not isinstance(ctx.type, TupleType):
        return ctx.default_return_type
    if not ctx.arg_types or not ctx.arg_types[0]:
        # No multiplier argument to inspect.
        return ctx.default_return_type

    arg_type = get_proper_type(ctx.arg_types[0][0])
    if isinstance(arg_type, Instance) and arg_type.last_known_value is not None:
        value = arg_type.last_known_value.value
    elif isinstance(arg_type, LiteralType):
        # Test the multiplier (``arg_type``), not ``ctx.type``. ``ctx.type`` is
        # already known to be a TupleType, so a check on it would never match.
        value = arg_type.value
    else:
        return ctx.default_return_type

    if isinstance(value, int):
        return ctx.type.copy_modified(items=ctx.type.items * value)
    return ctx.default_return_type
|