default.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502
  1. from __future__ import annotations
  2. from functools import partial
  3. from typing import Callable
  4. import mypy.errorcodes as codes
  5. from mypy import message_registry
  6. from mypy.nodes import DictExpr, IntExpr, StrExpr, UnaryExpr
  7. from mypy.plugin import (
  8. AttributeContext,
  9. ClassDefContext,
  10. FunctionContext,
  11. FunctionSigContext,
  12. MethodContext,
  13. MethodSigContext,
  14. Plugin,
  15. )
  16. from mypy.plugins.common import try_getting_str_literals
  17. from mypy.subtypes import is_subtype
  18. from mypy.typeops import is_literal_type_like, make_simplified_union
  19. from mypy.types import (
  20. TPDICT_FB_NAMES,
  21. AnyType,
  22. CallableType,
  23. FunctionLike,
  24. Instance,
  25. LiteralType,
  26. NoneType,
  27. TupleType,
  28. Type,
  29. TypedDictType,
  30. TypeOfAny,
  31. TypeVarType,
  32. UnionType,
  33. get_proper_type,
  34. get_proper_types,
  35. )
  class DefaultPlugin(Plugin):
      """Plugin that mypy enables by default.

      Each ``get_*_hook`` method maps a fully qualified name to the callback that
      gives it special treatment, or returns ``None`` when the name needs no
      special handling. The helper plugin modules are imported inside each
      method rather than at module level (presumably to avoid import cycles --
      not verifiable from this file).
      """

      def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
          from mypy.plugins import ctypes, singledispatch

          if fullname == "_ctypes.Array":
              return ctypes.array_constructor_callback
          if fullname == "functools.singledispatch":
              return singledispatch.create_singledispatch_function_callback
          return None

      def get_function_signature_hook(
          self, fullname: str
      ) -> Callable[[FunctionSigContext], FunctionLike] | None:
          from mypy.plugins import attrs, dataclasses

          if fullname == "dataclasses.replace":
              return dataclasses.replace_function_sig_callback
          if fullname in ("attr.evolve", "attrs.evolve", "attr.assoc", "attrs.assoc"):
              return attrs.evolve_function_sig_callback
          return None

      def get_method_signature_hook(
          self, fullname: str
      ) -> Callable[[MethodSigContext], FunctionLike] | None:
          from mypy.plugins import ctypes, singledispatch

          if fullname == "typing.Mapping.get":
              return typed_dict_get_signature_callback

          # TypedDict fallback methods: split off the trailing method name and
          # check whether the owning class is one of the TypedDict fallbacks.
          owner, _, method = fullname.rpartition(".")
          if owner in TPDICT_FB_NAMES:
              if method == "setdefault":
                  return typed_dict_setdefault_signature_callback
              if method == "pop":
                  return typed_dict_pop_signature_callback
              if method == "update":
                  return typed_dict_update_signature_callback

          if fullname == "_ctypes.Array.__setitem__":
              return ctypes.array_setitem_callback
          if fullname == singledispatch.SINGLEDISPATCH_CALLABLE_CALL_METHOD:
              return singledispatch.call_singledispatch_function_callback
          return None

      def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None:
          from mypy.plugins import ctypes, singledispatch

          # Callbacks defined in this module, keyed by exact method name.
          exact = {
              "typing.Mapping.get": typed_dict_get_callback,
              "builtins.int.__pow__": int_pow_callback,
              "builtins.int.__neg__": int_neg_callback,
              "builtins.tuple.__mul__": tuple_mul_callback,
              "builtins.tuple.__rmul__": tuple_mul_callback,
          }
          if fullname in exact:
              return exact[fullname]

          owner, _, method = fullname.rpartition(".")
          if owner in TPDICT_FB_NAMES:
              by_method = {
                  "setdefault": typed_dict_setdefault_callback,
                  "pop": typed_dict_pop_callback,
                  "__delitem__": typed_dict_delitem_callback,
              }
              if method in by_method:
                  return by_method[method]

          if fullname == "_ctypes.Array.__getitem__":
              return ctypes.array_getitem_callback
          if fullname == "_ctypes.Array.__iter__":
              return ctypes.array_iter_callback
          if fullname == singledispatch.SINGLEDISPATCH_REGISTER_METHOD:
              return singledispatch.singledispatch_register_callback
          if fullname == singledispatch.REGISTER_CALLABLE_CALL_METHOD:
              return singledispatch.call_singledispatch_function_after_register_argument
          return None

      def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
          from mypy.plugins import ctypes, enums

          if fullname == "_ctypes.Array.value":
              return ctypes.array_value_callback
          if fullname == "_ctypes.Array.raw":
              return ctypes.array_raw_callback
          if fullname in enums.ENUM_NAME_ACCESS:
              return enums.enum_name_callback
          if fullname in enums.ENUM_VALUE_ACCESS:
              return enums.enum_value_callback
          return None

      def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
          from mypy.plugins import attrs, dataclasses

          # These hooks run in the main semantic analysis pass and only tag known
          # dataclasses/attrs classes, so that the second-stage hooks (in
          # get_class_decorator_hook_2) can detect such classes in the MRO.
          if fullname in dataclasses.dataclass_makers:
              return dataclasses.dataclass_tag_callback
          attrs_makers = (
              attrs.attr_class_makers,
              attrs.attr_dataclass_makers,
              attrs.attr_frozen_makers,
              attrs.attr_define_makers,
          )
          if any(fullname in makers for makers in attrs_makers):
              return attrs.attr_tag_callback
          return None

      def get_class_decorator_hook_2(
          self, fullname: str
      ) -> Callable[[ClassDefContext], bool] | None:
          from mypy.plugins import attrs, dataclasses, functools

          if fullname in dataclasses.dataclass_makers:
              return dataclasses.dataclass_class_maker_callback
          if fullname in functools.functools_total_ordering_makers:
              return functools.functools_total_ordering_maker_callback

          # All attrs decorator families share one callback and differ only in
          # the keyword defaults bound here.
          maker = attrs.attr_class_maker_callback
          if fullname in attrs.attr_class_makers:
              return maker
          if fullname in attrs.attr_dataclass_makers:
              return partial(maker, auto_attribs_default=True)
          if fullname in attrs.attr_frozen_makers:
              return partial(maker, auto_attribs_default=None, frozen_default=True)
          if fullname in attrs.attr_define_makers:
              return partial(maker, auto_attribs_default=None, slots_default=True)
          return None
  def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:
      """Refine the signature of ``TypedDict.get`` when the key is a string literal.

      The refined signature gives the second (default) argument a type context
      derived from the value type stored under the key, which helps inference of
      that argument. Calls that do not have the expected shape -- two single
      arguments, a ``StrExpr`` key, a two-parameter generic default signature --
      keep the default signature unchanged.
      """
      sig = ctx.default_signature
      if not (
          isinstance(ctx.type, TypedDictType)
          and len(ctx.args) == 2
          and len(ctx.args[0]) == 1
          and isinstance(ctx.args[0][0], StrExpr)
          and len(sig.arg_types) == 2
          and len(sig.variables) == 1
          and len(ctx.args[1]) == 1
      ):
          return sig

      item_type = get_proper_type(ctx.type.items.get(ctx.args[0][0].value))
      if not item_type:
          # Unknown key: nothing better to offer than the default signature.
          return sig

      default_expr = ctx.args[1][0]
      if (
          isinstance(item_type, TypedDictType)
          and isinstance(default_expr, DictExpr)
          and not default_expr.items
      ):
          # An empty ``{}`` default for a nested TypedDict: use that TypedDict
          # with no required keys as the context.
          item_type = item_type.copy_modified(required_keys=set())

      # Only the type context changes; the union with the signature's type
      # variable still accepts any argument.
      type_var = sig.variables[0]
      assert isinstance(type_var, TypeVarType)
      return sig.copy_modified(
          arg_types=[sig.arg_types[0], make_simplified_union([item_type, type_var])],
          ret_type=sig.ret_type,
      )
  def typed_dict_get_callback(ctx: MethodContext) -> Type:
      """Infer a precise return type for TypedDict.get with literal first argument.

      When the key argument resolves to string literal(s) via
      ``try_getting_str_literals``, the result is the union of:

      * the value types stored under those keys, and
      * ``None`` when no default is passed, or the default argument's type
        when one is.

      An empty ``{}`` default for a key whose value is itself a TypedDict is
      special-cased to that TypedDict with no required keys.

      Returns ``ctx.default_return_type`` whenever the call cannot be analysed
      precisely: a non-literal key, an unknown key, or an argument shape this
      callback does not recognise.
      """
      if (
          isinstance(ctx.type, TypedDictType)
          and len(ctx.arg_types) >= 1
          and len(ctx.arg_types[0]) == 1
      ):
          keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
          if keys is None:
              return ctx.default_return_type

          output_types: list[Type] = []
          for key in keys:
              value_type = get_proper_type(ctx.type.items.get(key))
              if value_type is None:
                  return ctx.default_return_type

              if len(ctx.arg_types) == 1:
                  output_types.append(value_type)
              elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
                  default_arg = ctx.args[1][0]
                  if (
                      isinstance(default_arg, DictExpr)
                      and len(default_arg.items) == 0
                      and isinstance(value_type, TypedDictType)
                  ):
                      # Special case '{}' as the default for a typed dict type.
                      output_types.append(value_type.copy_modified(required_keys=set()))
                  else:
                      output_types.append(value_type)
                      output_types.append(ctx.arg_types[1][0])

          if len(ctx.arg_types) == 1:
              # No default passed: a missing key yields None.
              output_types.append(NoneType())

          if not output_types:
              # No branch above produced a type (e.g. an unrecognised argument
              # shape). make_simplified_union([]) would collapse to an
              # uninhabited type, so keep the declared return type instead.
              return ctx.default_return_type

          return make_simplified_union(output_types)
      return ctx.default_return_type
  def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType:
      """Refine the signature of ``TypedDict.pop``.

      The key parameter is always retyped as ``builtins.str``. When the key is a
      string literal naming a known item, both the default parameter and the
      return type become the union of that item's value type and the
      signature's type variable. This gives the default argument a useful type
      context while still accepting any argument.
      """
      sig = ctx.default_signature
      str_type = ctx.api.named_generic_type("builtins.str", [])

      def keep_default_param() -> CallableType:
          # Only the key parameter is retyped.
          return sig.copy_modified(arg_types=[str_type, sig.arg_types[1]])

      if not (
          isinstance(ctx.type, TypedDictType)
          and len(ctx.args) == 2
          and len(ctx.args[0]) == 1
          and isinstance(ctx.args[0][0], StrExpr)
          and len(sig.arg_types) == 2
          and len(sig.variables) == 1
          and len(ctx.args[1]) == 1
      ):
          return keep_default_param()

      item_type = ctx.type.items.get(ctx.args[0][0].value)
      if not item_type:
          return keep_default_param()

      type_var = sig.variables[0]
      assert isinstance(type_var, TypeVarType)
      widened = make_simplified_union([item_type, type_var])
      return sig.copy_modified(arg_types=[str_type, widened], ret_type=widened)
  def typed_dict_pop_callback(ctx: MethodContext) -> Type:
      """Type check and infer a precise return type for TypedDict.pop.

      The key must resolve to string literal(s). Otherwise a LITERAL_REQ error is
      reported and an error ``Any`` is returned.

      * A required key is reported as not deletable, and checking continues.
      * An unknown key is reported and an error ``Any`` is returned.

      The result is the union of the popped value types. When a default argument
      is supplied, its type is added to that union.
      """
      if (
          isinstance(ctx.type, TypedDictType)
          and len(ctx.arg_types) >= 1
          and len(ctx.arg_types[0]) == 1
      ):
          keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
          if keys is None:
              ctx.api.fail(
                  message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
                  ctx.context,
                  code=codes.LITERAL_REQ,
              )
              return AnyType(TypeOfAny.from_error)

          value_types = []
          for key in keys:
              if key in ctx.type.required_keys:
                  ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)

              value_type = ctx.type.items.get(key)
              if value_type:
                  value_types.append(value_type)
              else:
                  ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
                  return AnyType(TypeOfAny.from_error)

          # The guard above only guarantees one formal, so make sure the
          # default's position exists before indexing it.
          if len(ctx.args) < 2 or len(ctx.args[1]) == 0:
              # No default given.
              return make_simplified_union(value_types)
          elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
              return make_simplified_union([*value_types, ctx.arg_types[1][0]])
      return ctx.default_return_type
  def typed_dict_setdefault_signature_callback(ctx: MethodSigContext) -> CallableType:
      """Refine the signature of ``TypedDict.setdefault``.

      The key parameter is always retyped as ``builtins.str``. When the key is a
      string literal naming a known item, the default parameter is typed as that
      item's value type, giving the default argument a precise type context.
      Otherwise the default parameter keeps its declared type.
      """
      sig = ctx.default_signature
      str_type = ctx.api.named_generic_type("builtins.str", [])

      item_type: Type | None = None
      if (
          isinstance(ctx.type, TypedDictType)
          and len(ctx.args) == 2
          and len(ctx.args[0]) == 1
          and isinstance(ctx.args[0][0], StrExpr)
          and len(sig.arg_types) == 2
          and len(ctx.args[1]) == 1
      ):
          item_type = ctx.type.items.get(ctx.args[0][0].value)

      default_param = item_type if item_type else sig.arg_types[1]
      return sig.copy_modified(arg_types=[str_type, default_param])
  def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
      """Type check TypedDict.setdefault and infer a precise return type.

      The key must resolve to string literal(s). Otherwise a LITERAL_REQ error is
      reported and an error ``Any`` is returned.

      For every key, it must exist in the TypedDict and the default's type must
      be a subtype of that key's value type. Either failure is reported and
      yields an error ``Any``.

      On success, the result is the union of the value types of all keys.
      """
      if not (
          isinstance(ctx.type, TypedDictType)
          and len(ctx.arg_types) == 2
          and len(ctx.arg_types[0]) == 1
          and len(ctx.arg_types[1]) == 1
      ):
          return ctx.default_return_type

      literal_keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
      if literal_keys is None:
          ctx.api.fail(
              message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
              ctx.context,
              code=codes.LITERAL_REQ,
          )
          return AnyType(TypeOfAny.from_error)

      default_type = ctx.arg_types[1][0]
      collected = []
      for name in literal_keys:
          item_type = ctx.type.items.get(name)
          if item_type is None:
              ctx.api.msg.typeddict_key_not_found(ctx.type, name, ctx.context)
              return AnyType(TypeOfAny.from_error)

          # The signature callback cannot always pick the right signature (e.g.
          # when the key is a variable whose type is a Literal str), so the
          # default is re-checked here against every key it may be stored under.
          if not is_subtype(default_type, item_type):
              ctx.api.msg.typeddict_setdefault_arguments_inconsistent(
                  default_type, item_type, ctx.context
              )
              return AnyType(TypeOfAny.from_error)

          collected.append(item_type)

      return make_simplified_union(collected)
  def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
      """Type check ``TypedDict.__delitem__``.

      The key must resolve to string literal(s). Otherwise a LITERAL_REQ error is
      reported and an error ``Any`` is returned. Each literal key that is
      required, or that the TypedDict does not define, is reported. The declared
      return type is otherwise left unchanged.
      """
      td = ctx.type
      if not (
          isinstance(td, TypedDictType)
          and len(ctx.arg_types) == 1
          and len(ctx.arg_types[0]) == 1
      ):
          return ctx.default_return_type

      literal_keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
      if literal_keys is None:
          ctx.api.fail(
              message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
              ctx.context,
              code=codes.LITERAL_REQ,
          )
          return AnyType(TypeOfAny.from_error)

      for name in literal_keys:
          if name in td.required_keys:
              ctx.api.msg.typeddict_key_cannot_be_deleted(td, name, ctx.context)
          elif name not in td.items:
              ctx.api.msg.typeddict_key_not_found(td, name, ctx.context)
      return ctx.default_return_type
  def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
      """Try to infer a better signature type for TypedDict.update.

      Applies only when the receiver is a TypedDict and the default signature
      has exactly one parameter, which is asserted to be a TypedDict. That
      parameter is replaced by an anonymous copy with no required keys.

      If an argument is present, it is inferred in that context with errors
      discarded. For each TypedDict found in the inferred type (directly or
      among a union's relevant items), a variant of the parameter type is built:

      * its required keys are those the inferred TypedDict requires and the
        parameter also declares;
      * unless ``extra_checks`` is enabled, the inferred TypedDict's key names
        are passed as ``item_names``.

      The union of these variants becomes the parameter type.
      """
      signature = ctx.default_signature
      if not (isinstance(ctx.type, TypedDictType) and len(signature.arg_types) == 1):
          return signature

      declared = get_proper_type(signature.arg_types[0])
      assert isinstance(declared, TypedDictType)
      base = declared.as_anonymous().copy_modified(required_keys=set())
      param_type: Type = base

      if ctx.args and ctx.args[0]:
          # Infer the argument against the relaxed parameter type, discarding
          # any errors reported while doing so.
          with ctx.api.msg.filter_errors():
              inferred = get_proper_type(
                  ctx.api.get_expression_type(ctx.args[0][0], type_context=base)
              )

          candidates: list[TypedDictType]
          if isinstance(inferred, TypedDictType):
              candidates = [inferred]
          elif isinstance(inferred, UnionType):
              candidates = [
                  member
                  for member in get_proper_types(inferred.relevant_items())
                  if isinstance(member, TypedDictType)
              ]
          else:
              candidates = []

          variants = []
          for candidate in candidates:
              variant = base.copy_modified(
                  required_keys=(base.required_keys | candidate.required_keys)
                  & base.items.keys()
              )
              if not ctx.api.options.extra_checks:
                  variant = variant.copy_modified(item_names=list(candidate.items))
              variants.append(variant)

          if variants:
              param_type = make_simplified_union(variants)

      return signature.copy_modified(arg_types=[param_type])
  def int_pow_callback(ctx: MethodContext) -> Type:
      """Infer a more precise return type for int.__pow__.

      Applies when no modulo argument is passed and the exponent is an integer
      literal, optionally negated with unary minus. The result is then
      ``builtins.int`` for a non-negative exponent and ``builtins.float`` for a
      negative one. In every other case the declared return type is kept.
      """
      # int.__pow__ has an optional modulo argument, so two argument positions
      # are expected; the second (modulo) must be empty.
      if not (
          len(ctx.arg_types) == 2 and len(ctx.arg_types[0]) == 1 and len(ctx.arg_types[1]) == 0
      ):
          return ctx.default_return_type

      operand = ctx.args[0][0]
      if isinstance(operand, IntExpr):
          power = operand.value
      elif (
          isinstance(operand, UnaryExpr)
          and operand.op == "-"
          and isinstance(operand.expr, IntExpr)
      ):
          power = -operand.expr.value
      else:
          # Neither an int literal nor a negated one -- give up.
          return ctx.default_return_type

      result_name = "builtins.int" if power >= 0 else "builtins.float"
      return ctx.api.named_generic_type(result_name, [])
  def int_neg_callback(ctx: MethodContext) -> Type:
      """Infer a more precise return type for int.__neg__.

      If the operand's type records a known int value, the sign-flipped value is
      preserved in the result:

      * an ``Instance`` with an int ``last_known_value`` becomes a bare
        ``LiteralType`` when the current type context is literal-like, and
        otherwise a copy of the instance whose ``last_known_value`` is the
        negated literal;
      * a ``LiteralType`` with an int value becomes the negated ``LiteralType``.

      Anything else keeps the declared return type.
      """
      operand = ctx.type
      if isinstance(operand, Instance) and operand.last_known_value is not None:
          known = operand.last_known_value
          if isinstance(known.value, int):
              negated = -known.value
              if is_literal_type_like(ctx.api.type_context[-1]):
                  return LiteralType(value=negated, fallback=known.fallback)
              return operand.copy_modified(
                  last_known_value=LiteralType(
                      value=negated, fallback=operand, line=operand.line, column=operand.column
                  )
              )
      elif isinstance(operand, LiteralType) and isinstance(operand.value, int):
          return LiteralType(value=-operand.value, fallback=operand.fallback)
      return ctx.default_return_type
  def tuple_mul_callback(ctx: MethodContext) -> Type:
      """Infer a more precise return type for tuple.__mul__ and tuple.__rmul__.

      When the receiver is a ``TupleType`` and the multiplier's type carries a
      known int value, return a ``TupleType`` whose items are repeated that many
      times. The value may be carried either as an ``Instance`` with a
      ``last_known_value`` or as a ``LiteralType``. This mirrors runtime sequence
      repetition, where a count <= 0 gives an empty tuple.

      Any other case keeps the declared return type.
      """
      if not isinstance(ctx.type, TupleType):
          return ctx.default_return_type
      if not ctx.arg_types or len(ctx.arg_types[0]) != 1:
          # Unexpected call shape; avoid indexing a missing argument.
          return ctx.default_return_type

      arg_type = get_proper_type(ctx.arg_types[0][0])
      if isinstance(arg_type, Instance) and arg_type.last_known_value is not None:
          value = arg_type.last_known_value.value
      elif isinstance(arg_type, LiteralType):
          # ctx.type is already known to be a TupleType here, so the literal
          # check must be on the multiplier, not on the receiver.
          value = arg_type.value
      else:
          return ctx.default_return_type

      if isinstance(value, int):
          return ctx.type.copy_modified(items=ctx.type.items * value)
      return ctx.default_return_type