default.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. from __future__ import annotations
  2. from functools import partial
  3. from typing import Callable
  4. import mypy.errorcodes as codes
  5. from mypy import message_registry
  6. from mypy.nodes import DictExpr, IntExpr, StrExpr, UnaryExpr
  7. from mypy.plugin import (
  8. AttributeContext,
  9. ClassDefContext,
  10. FunctionContext,
  11. FunctionSigContext,
  12. MethodContext,
  13. MethodSigContext,
  14. Plugin,
  15. )
  16. from mypy.plugins.common import try_getting_str_literals
  17. from mypy.subtypes import is_subtype
  18. from mypy.typeops import is_literal_type_like, make_simplified_union
  19. from mypy.types import (
  20. TPDICT_FB_NAMES,
  21. AnyType,
  22. CallableType,
  23. FunctionLike,
  24. Instance,
  25. LiteralType,
  26. NoneType,
  27. TupleType,
  28. Type,
  29. TypedDictType,
  30. TypeOfAny,
  31. TypeVarType,
  32. get_proper_type,
  33. )
  34. class DefaultPlugin(Plugin):
  35. """Type checker plugin that is enabled by default."""
  36. def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
  37. from mypy.plugins import ctypes, singledispatch
  38. if fullname == "_ctypes.Array":
  39. return ctypes.array_constructor_callback
  40. elif fullname == "functools.singledispatch":
  41. return singledispatch.create_singledispatch_function_callback
  42. return None
  43. def get_function_signature_hook(
  44. self, fullname: str
  45. ) -> Callable[[FunctionSigContext], FunctionLike] | None:
  46. from mypy.plugins import attrs
  47. if fullname in ("attr.evolve", "attrs.evolve", "attr.assoc", "attrs.assoc"):
  48. return attrs.evolve_function_sig_callback
  49. return None
  50. def get_method_signature_hook(
  51. self, fullname: str
  52. ) -> Callable[[MethodSigContext], FunctionLike] | None:
  53. from mypy.plugins import ctypes, singledispatch
  54. if fullname == "typing.Mapping.get":
  55. return typed_dict_get_signature_callback
  56. elif fullname in {n + ".setdefault" for n in TPDICT_FB_NAMES}:
  57. return typed_dict_setdefault_signature_callback
  58. elif fullname in {n + ".pop" for n in TPDICT_FB_NAMES}:
  59. return typed_dict_pop_signature_callback
  60. elif fullname in {n + ".update" for n in TPDICT_FB_NAMES}:
  61. return typed_dict_update_signature_callback
  62. elif fullname == "_ctypes.Array.__setitem__":
  63. return ctypes.array_setitem_callback
  64. elif fullname == singledispatch.SINGLEDISPATCH_CALLABLE_CALL_METHOD:
  65. return singledispatch.call_singledispatch_function_callback
  66. return None
  67. def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None:
  68. from mypy.plugins import ctypes, singledispatch
  69. if fullname == "typing.Mapping.get":
  70. return typed_dict_get_callback
  71. elif fullname == "builtins.int.__pow__":
  72. return int_pow_callback
  73. elif fullname == "builtins.int.__neg__":
  74. return int_neg_callback
  75. elif fullname in ("builtins.tuple.__mul__", "builtins.tuple.__rmul__"):
  76. return tuple_mul_callback
  77. elif fullname in {n + ".setdefault" for n in TPDICT_FB_NAMES}:
  78. return typed_dict_setdefault_callback
  79. elif fullname in {n + ".pop" for n in TPDICT_FB_NAMES}:
  80. return typed_dict_pop_callback
  81. elif fullname in {n + ".__delitem__" for n in TPDICT_FB_NAMES}:
  82. return typed_dict_delitem_callback
  83. elif fullname == "_ctypes.Array.__getitem__":
  84. return ctypes.array_getitem_callback
  85. elif fullname == "_ctypes.Array.__iter__":
  86. return ctypes.array_iter_callback
  87. elif fullname == singledispatch.SINGLEDISPATCH_REGISTER_METHOD:
  88. return singledispatch.singledispatch_register_callback
  89. elif fullname == singledispatch.REGISTER_CALLABLE_CALL_METHOD:
  90. return singledispatch.call_singledispatch_function_after_register_argument
  91. return None
  92. def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
  93. from mypy.plugins import ctypes, enums
  94. if fullname == "_ctypes.Array.value":
  95. return ctypes.array_value_callback
  96. elif fullname == "_ctypes.Array.raw":
  97. return ctypes.array_raw_callback
  98. elif fullname in enums.ENUM_NAME_ACCESS:
  99. return enums.enum_name_callback
  100. elif fullname in enums.ENUM_VALUE_ACCESS:
  101. return enums.enum_value_callback
  102. return None
  103. def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
  104. from mypy.plugins import attrs, dataclasses
  105. # These dataclass and attrs hooks run in the main semantic analysis pass
  106. # and only tag known dataclasses/attrs classes, so that the second
  107. # hooks (in get_class_decorator_hook_2) can detect dataclasses/attrs classes
  108. # in the MRO.
  109. if fullname in dataclasses.dataclass_makers:
  110. return dataclasses.dataclass_tag_callback
  111. if (
  112. fullname in attrs.attr_class_makers
  113. or fullname in attrs.attr_dataclass_makers
  114. or fullname in attrs.attr_frozen_makers
  115. or fullname in attrs.attr_define_makers
  116. ):
  117. return attrs.attr_tag_callback
  118. return None
  119. def get_class_decorator_hook_2(
  120. self, fullname: str
  121. ) -> Callable[[ClassDefContext], bool] | None:
  122. from mypy.plugins import attrs, dataclasses, functools
  123. if fullname in dataclasses.dataclass_makers:
  124. return dataclasses.dataclass_class_maker_callback
  125. elif fullname in functools.functools_total_ordering_makers:
  126. return functools.functools_total_ordering_maker_callback
  127. elif fullname in attrs.attr_class_makers:
  128. return attrs.attr_class_maker_callback
  129. elif fullname in attrs.attr_dataclass_makers:
  130. return partial(attrs.attr_class_maker_callback, auto_attribs_default=True)
  131. elif fullname in attrs.attr_frozen_makers:
  132. return partial(
  133. attrs.attr_class_maker_callback, auto_attribs_default=None, frozen_default=True
  134. )
  135. elif fullname in attrs.attr_define_makers:
  136. return partial(attrs.attr_class_maker_callback, auto_attribs_default=None)
  137. return None
  138. def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:
  139. """Try to infer a better signature type for TypedDict.get.
  140. This is used to get better type context for the second argument that
  141. depends on a TypedDict value type.
  142. """
  143. signature = ctx.default_signature
  144. if (
  145. isinstance(ctx.type, TypedDictType)
  146. and len(ctx.args) == 2
  147. and len(ctx.args[0]) == 1
  148. and isinstance(ctx.args[0][0], StrExpr)
  149. and len(signature.arg_types) == 2
  150. and len(signature.variables) == 1
  151. and len(ctx.args[1]) == 1
  152. ):
  153. key = ctx.args[0][0].value
  154. value_type = get_proper_type(ctx.type.items.get(key))
  155. ret_type = signature.ret_type
  156. if value_type:
  157. default_arg = ctx.args[1][0]
  158. if (
  159. isinstance(value_type, TypedDictType)
  160. and isinstance(default_arg, DictExpr)
  161. and len(default_arg.items) == 0
  162. ):
  163. # Caller has empty dict {} as default for typed dict.
  164. value_type = value_type.copy_modified(required_keys=set())
  165. # Tweak the signature to include the value type as context. It's
  166. # only needed for type inference since there's a union with a type
  167. # variable that accepts everything.
  168. tv = signature.variables[0]
  169. assert isinstance(tv, TypeVarType)
  170. return signature.copy_modified(
  171. arg_types=[signature.arg_types[0], make_simplified_union([value_type, tv])],
  172. ret_type=ret_type,
  173. )
  174. return signature
  175. def typed_dict_get_callback(ctx: MethodContext) -> Type:
  176. """Infer a precise return type for TypedDict.get with literal first argument."""
  177. if (
  178. isinstance(ctx.type, TypedDictType)
  179. and len(ctx.arg_types) >= 1
  180. and len(ctx.arg_types[0]) == 1
  181. ):
  182. keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
  183. if keys is None:
  184. return ctx.default_return_type
  185. output_types: list[Type] = []
  186. for key in keys:
  187. value_type = get_proper_type(ctx.type.items.get(key))
  188. if value_type is None:
  189. return ctx.default_return_type
  190. if len(ctx.arg_types) == 1:
  191. output_types.append(value_type)
  192. elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
  193. default_arg = ctx.args[1][0]
  194. if (
  195. isinstance(default_arg, DictExpr)
  196. and len(default_arg.items) == 0
  197. and isinstance(value_type, TypedDictType)
  198. ):
  199. # Special case '{}' as the default for a typed dict type.
  200. output_types.append(value_type.copy_modified(required_keys=set()))
  201. else:
  202. output_types.append(value_type)
  203. output_types.append(ctx.arg_types[1][0])
  204. if len(ctx.arg_types) == 1:
  205. output_types.append(NoneType())
  206. return make_simplified_union(output_types)
  207. return ctx.default_return_type
  208. def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType:
  209. """Try to infer a better signature type for TypedDict.pop.
  210. This is used to get better type context for the second argument that
  211. depends on a TypedDict value type.
  212. """
  213. signature = ctx.default_signature
  214. str_type = ctx.api.named_generic_type("builtins.str", [])
  215. if (
  216. isinstance(ctx.type, TypedDictType)
  217. and len(ctx.args) == 2
  218. and len(ctx.args[0]) == 1
  219. and isinstance(ctx.args[0][0], StrExpr)
  220. and len(signature.arg_types) == 2
  221. and len(signature.variables) == 1
  222. and len(ctx.args[1]) == 1
  223. ):
  224. key = ctx.args[0][0].value
  225. value_type = ctx.type.items.get(key)
  226. if value_type:
  227. # Tweak the signature to include the value type as context. It's
  228. # only needed for type inference since there's a union with a type
  229. # variable that accepts everything.
  230. tv = signature.variables[0]
  231. assert isinstance(tv, TypeVarType)
  232. typ = make_simplified_union([value_type, tv])
  233. return signature.copy_modified(arg_types=[str_type, typ], ret_type=typ)
  234. return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]])
  235. def typed_dict_pop_callback(ctx: MethodContext) -> Type:
  236. """Type check and infer a precise return type for TypedDict.pop."""
  237. if (
  238. isinstance(ctx.type, TypedDictType)
  239. and len(ctx.arg_types) >= 1
  240. and len(ctx.arg_types[0]) == 1
  241. ):
  242. keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
  243. if keys is None:
  244. ctx.api.fail(
  245. message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
  246. ctx.context,
  247. code=codes.LITERAL_REQ,
  248. )
  249. return AnyType(TypeOfAny.from_error)
  250. value_types = []
  251. for key in keys:
  252. if key in ctx.type.required_keys:
  253. ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
  254. value_type = ctx.type.items.get(key)
  255. if value_type:
  256. value_types.append(value_type)
  257. else:
  258. ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
  259. return AnyType(TypeOfAny.from_error)
  260. if len(ctx.args[1]) == 0:
  261. return make_simplified_union(value_types)
  262. elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
  263. return make_simplified_union([*value_types, ctx.arg_types[1][0]])
  264. return ctx.default_return_type
  265. def typed_dict_setdefault_signature_callback(ctx: MethodSigContext) -> CallableType:
  266. """Try to infer a better signature type for TypedDict.setdefault.
  267. This is used to get better type context for the second argument that
  268. depends on a TypedDict value type.
  269. """
  270. signature = ctx.default_signature
  271. str_type = ctx.api.named_generic_type("builtins.str", [])
  272. if (
  273. isinstance(ctx.type, TypedDictType)
  274. and len(ctx.args) == 2
  275. and len(ctx.args[0]) == 1
  276. and isinstance(ctx.args[0][0], StrExpr)
  277. and len(signature.arg_types) == 2
  278. and len(ctx.args[1]) == 1
  279. ):
  280. key = ctx.args[0][0].value
  281. value_type = ctx.type.items.get(key)
  282. if value_type:
  283. return signature.copy_modified(arg_types=[str_type, value_type])
  284. return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]])
  285. def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
  286. """Type check TypedDict.setdefault and infer a precise return type."""
  287. if (
  288. isinstance(ctx.type, TypedDictType)
  289. and len(ctx.arg_types) == 2
  290. and len(ctx.arg_types[0]) == 1
  291. and len(ctx.arg_types[1]) == 1
  292. ):
  293. keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
  294. if keys is None:
  295. ctx.api.fail(
  296. message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
  297. ctx.context,
  298. code=codes.LITERAL_REQ,
  299. )
  300. return AnyType(TypeOfAny.from_error)
  301. default_type = ctx.arg_types[1][0]
  302. value_types = []
  303. for key in keys:
  304. value_type = ctx.type.items.get(key)
  305. if value_type is None:
  306. ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
  307. return AnyType(TypeOfAny.from_error)
  308. # The signature_callback above can't always infer the right signature
  309. # (e.g. when the expression is a variable that happens to be a Literal str)
  310. # so we need to handle the check ourselves here and make sure the provided
  311. # default can be assigned to all key-value pairs we're updating.
  312. if not is_subtype(default_type, value_type):
  313. ctx.api.msg.typeddict_setdefault_arguments_inconsistent(
  314. default_type, value_type, ctx.context
  315. )
  316. return AnyType(TypeOfAny.from_error)
  317. value_types.append(value_type)
  318. return make_simplified_union(value_types)
  319. return ctx.default_return_type
  320. def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
  321. """Type check TypedDict.__delitem__."""
  322. if (
  323. isinstance(ctx.type, TypedDictType)
  324. and len(ctx.arg_types) == 1
  325. and len(ctx.arg_types[0]) == 1
  326. ):
  327. keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
  328. if keys is None:
  329. ctx.api.fail(
  330. message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
  331. ctx.context,
  332. code=codes.LITERAL_REQ,
  333. )
  334. return AnyType(TypeOfAny.from_error)
  335. for key in keys:
  336. if key in ctx.type.required_keys:
  337. ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
  338. elif key not in ctx.type.items:
  339. ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
  340. return ctx.default_return_type
  341. def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
  342. """Try to infer a better signature type for TypedDict.update."""
  343. signature = ctx.default_signature
  344. if isinstance(ctx.type, TypedDictType) and len(signature.arg_types) == 1:
  345. arg_type = get_proper_type(signature.arg_types[0])
  346. assert isinstance(arg_type, TypedDictType)
  347. arg_type = arg_type.as_anonymous()
  348. arg_type = arg_type.copy_modified(required_keys=set())
  349. return signature.copy_modified(arg_types=[arg_type])
  350. return signature
  351. def int_pow_callback(ctx: MethodContext) -> Type:
  352. """Infer a more precise return type for int.__pow__."""
  353. # int.__pow__ has an optional modulo argument,
  354. # so we expect 2 argument positions
  355. if len(ctx.arg_types) == 2 and len(ctx.arg_types[0]) == 1 and len(ctx.arg_types[1]) == 0:
  356. arg = ctx.args[0][0]
  357. if isinstance(arg, IntExpr):
  358. exponent = arg.value
  359. elif isinstance(arg, UnaryExpr) and arg.op == "-" and isinstance(arg.expr, IntExpr):
  360. exponent = -arg.expr.value
  361. else:
  362. # Right operand not an int literal or a negated literal -- give up.
  363. return ctx.default_return_type
  364. if exponent >= 0:
  365. return ctx.api.named_generic_type("builtins.int", [])
  366. else:
  367. return ctx.api.named_generic_type("builtins.float", [])
  368. return ctx.default_return_type
  369. def int_neg_callback(ctx: MethodContext) -> Type:
  370. """Infer a more precise return type for int.__neg__.
  371. This is mainly used to infer the return type as LiteralType
  372. if the original underlying object is a LiteralType object
  373. """
  374. if isinstance(ctx.type, Instance) and ctx.type.last_known_value is not None:
  375. value = ctx.type.last_known_value.value
  376. fallback = ctx.type.last_known_value.fallback
  377. if isinstance(value, int):
  378. if is_literal_type_like(ctx.api.type_context[-1]):
  379. return LiteralType(value=-value, fallback=fallback)
  380. else:
  381. return ctx.type.copy_modified(
  382. last_known_value=LiteralType(
  383. value=-value, fallback=ctx.type, line=ctx.type.line, column=ctx.type.column
  384. )
  385. )
  386. elif isinstance(ctx.type, LiteralType):
  387. value = ctx.type.value
  388. fallback = ctx.type.fallback
  389. if isinstance(value, int):
  390. return LiteralType(value=-value, fallback=fallback)
  391. return ctx.default_return_type
  392. def tuple_mul_callback(ctx: MethodContext) -> Type:
  393. """Infer a more precise return type for tuple.__mul__ and tuple.__rmul__.
  394. This is used to return a specific sized tuple if multiplied by Literal int
  395. """
  396. if not isinstance(ctx.type, TupleType):
  397. return ctx.default_return_type
  398. arg_type = get_proper_type(ctx.arg_types[0][0])
  399. if isinstance(arg_type, Instance) and arg_type.last_known_value is not None:
  400. value = arg_type.last_known_value.value
  401. if isinstance(value, int):
  402. return ctx.type.copy_modified(items=ctx.type.items * value)
  403. elif isinstance(ctx.type, LiteralType):
  404. value = arg_type.value
  405. if isinstance(value, int):
  406. return ctx.type.copy_modified(items=ctx.type.items * value)
  407. return ctx.default_return_type