brain_typing.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  2. # For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
  3. # Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt
  4. """Astroid hooks for typing.py support."""
  5. from __future__ import annotations
  6. import sys
  7. import typing
  8. from collections.abc import Iterator
  9. from functools import partial
  10. from astroid import context, extract_node, inference_tip
  11. from astroid.builder import _extract_single_node
  12. from astroid.const import PY38_PLUS, PY39_PLUS
  13. from astroid.exceptions import (
  14. AttributeInferenceError,
  15. InferenceError,
  16. UseInferenceDefault,
  17. )
  18. from astroid.manager import AstroidManager
  19. from astroid.nodes.node_classes import (
  20. Assign,
  21. AssignName,
  22. Attribute,
  23. Call,
  24. Const,
  25. JoinedStr,
  26. Name,
  27. NodeNG,
  28. Subscript,
  29. Tuple,
  30. )
  31. from astroid.nodes.scoped_nodes import ClassDef, FunctionDef
  32. if sys.version_info >= (3, 8):
  33. from typing import Final
  34. else:
  35. from typing_extensions import Final
  36. TYPING_TYPEVARS = {"TypeVar", "NewType"}
  37. TYPING_TYPEVARS_QUALIFIED: Final = {
  38. "typing.TypeVar",
  39. "typing.NewType",
  40. "typing_extensions.TypeVar",
  41. }
  42. TYPING_TYPEDDICT_QUALIFIED: Final = {"typing.TypedDict", "typing_extensions.TypedDict"}
  43. TYPING_TYPE_TEMPLATE = """
  44. class Meta(type):
  45. def __getitem__(self, item):
  46. return self
  47. @property
  48. def __args__(self):
  49. return ()
  50. class {0}(metaclass=Meta):
  51. pass
  52. """
  53. TYPING_MEMBERS = set(getattr(typing, "__all__", []))
  54. TYPING_ALIAS = frozenset(
  55. (
  56. "typing.Hashable",
  57. "typing.Awaitable",
  58. "typing.Coroutine",
  59. "typing.AsyncIterable",
  60. "typing.AsyncIterator",
  61. "typing.Iterable",
  62. "typing.Iterator",
  63. "typing.Reversible",
  64. "typing.Sized",
  65. "typing.Container",
  66. "typing.Collection",
  67. "typing.Callable",
  68. "typing.AbstractSet",
  69. "typing.MutableSet",
  70. "typing.Mapping",
  71. "typing.MutableMapping",
  72. "typing.Sequence",
  73. "typing.MutableSequence",
  74. "typing.ByteString",
  75. "typing.Tuple",
  76. "typing.List",
  77. "typing.Deque",
  78. "typing.Set",
  79. "typing.FrozenSet",
  80. "typing.MappingView",
  81. "typing.KeysView",
  82. "typing.ItemsView",
  83. "typing.ValuesView",
  84. "typing.ContextManager",
  85. "typing.AsyncContextManager",
  86. "typing.Dict",
  87. "typing.DefaultDict",
  88. "typing.OrderedDict",
  89. "typing.Counter",
  90. "typing.ChainMap",
  91. "typing.Generator",
  92. "typing.AsyncGenerator",
  93. "typing.Type",
  94. "typing.Pattern",
  95. "typing.Match",
  96. )
  97. )
  98. CLASS_GETITEM_TEMPLATE = """
  99. @classmethod
  100. def __class_getitem__(cls, item):
  101. return cls
  102. """
  103. def looks_like_typing_typevar_or_newtype(node) -> bool:
  104. func = node.func
  105. if isinstance(func, Attribute):
  106. return func.attrname in TYPING_TYPEVARS
  107. if isinstance(func, Name):
  108. return func.name in TYPING_TYPEVARS
  109. return False
  110. def infer_typing_typevar_or_newtype(
  111. node: Call, context_itton: context.InferenceContext | None = None
  112. ) -> Iterator[ClassDef]:
  113. """Infer a typing.TypeVar(...) or typing.NewType(...) call."""
  114. try:
  115. func = next(node.func.infer(context=context_itton))
  116. except (InferenceError, StopIteration) as exc:
  117. raise UseInferenceDefault from exc
  118. if func.qname() not in TYPING_TYPEVARS_QUALIFIED:
  119. raise UseInferenceDefault
  120. if not node.args:
  121. raise UseInferenceDefault
  122. # Cannot infer from a dynamic class name (f-string)
  123. if isinstance(node.args[0], JoinedStr):
  124. raise UseInferenceDefault
  125. typename = node.args[0].as_string().strip("'")
  126. node = ClassDef(
  127. name=typename,
  128. lineno=node.lineno,
  129. col_offset=node.col_offset,
  130. parent=node.parent,
  131. end_lineno=node.end_lineno,
  132. end_col_offset=node.end_col_offset,
  133. )
  134. return node.infer(context=context_itton)
  135. def _looks_like_typing_subscript(node) -> bool:
  136. """Try to figure out if a Subscript node *might* be a typing-related subscript."""
  137. if isinstance(node, Name):
  138. return node.name in TYPING_MEMBERS
  139. if isinstance(node, Attribute):
  140. return node.attrname in TYPING_MEMBERS
  141. if isinstance(node, Subscript):
  142. return _looks_like_typing_subscript(node.value)
  143. return False
  144. def infer_typing_attr(
  145. node: Subscript, ctx: context.InferenceContext | None = None
  146. ) -> Iterator[ClassDef]:
  147. """Infer a typing.X[...] subscript."""
  148. try:
  149. value = next(node.value.infer()) # type: ignore[union-attr] # value shouldn't be None for Subscript.
  150. except (InferenceError, StopIteration) as exc:
  151. raise UseInferenceDefault from exc
  152. if not value.qname().startswith("typing.") or value.qname() in TYPING_ALIAS:
  153. # If typing subscript belongs to an alias handle it separately.
  154. raise UseInferenceDefault
  155. if isinstance(value, ClassDef) and value.qname() in {
  156. "typing.Generic",
  157. "typing.Annotated",
  158. "typing_extensions.Annotated",
  159. }:
  160. # typing.Generic and typing.Annotated (PY39) are subscriptable
  161. # through __class_getitem__. Since astroid can't easily
  162. # infer the native methods, replace them for an easy inference tip
  163. func_to_add = _extract_single_node(CLASS_GETITEM_TEMPLATE)
  164. value.locals["__class_getitem__"] = [func_to_add]
  165. if (
  166. isinstance(node.parent, ClassDef)
  167. and node in node.parent.bases
  168. and getattr(node.parent, "__cache", None)
  169. ):
  170. # node.parent.slots is evaluated and cached before the inference tip
  171. # is first applied. Remove the last result to allow a recalculation of slots
  172. cache = node.parent.__cache # type: ignore[attr-defined] # Unrecognized getattr
  173. if cache.get(node.parent.slots) is not None:
  174. del cache[node.parent.slots]
  175. return iter([value])
  176. node = extract_node(TYPING_TYPE_TEMPLATE.format(value.qname().split(".")[-1]))
  177. return node.infer(context=ctx)
  178. def _looks_like_typedDict( # pylint: disable=invalid-name
  179. node: FunctionDef | ClassDef,
  180. ) -> bool:
  181. """Check if node is TypedDict FunctionDef."""
  182. return node.qname() in TYPING_TYPEDDICT_QUALIFIED
  183. def infer_old_typedDict( # pylint: disable=invalid-name
  184. node: ClassDef, ctx: context.InferenceContext | None = None
  185. ) -> Iterator[ClassDef]:
  186. func_to_add = _extract_single_node("dict")
  187. node.locals["__call__"] = [func_to_add]
  188. return iter([node])
  189. def infer_typedDict( # pylint: disable=invalid-name
  190. node: FunctionDef, ctx: context.InferenceContext | None = None
  191. ) -> Iterator[ClassDef]:
  192. """Replace TypedDict FunctionDef with ClassDef."""
  193. class_def = ClassDef(
  194. name="TypedDict",
  195. lineno=node.lineno,
  196. col_offset=node.col_offset,
  197. parent=node.parent,
  198. )
  199. class_def.postinit(bases=[extract_node("dict")], body=[], decorators=None)
  200. func_to_add = _extract_single_node("dict")
  201. class_def.locals["__call__"] = [func_to_add]
  202. return iter([class_def])
  203. def _looks_like_typing_alias(node: Call) -> bool:
  204. """
  205. Returns True if the node corresponds to a call to _alias function.
  206. For example :
  207. MutableSet = _alias(collections.abc.MutableSet, T)
  208. :param node: call node
  209. """
  210. return (
  211. isinstance(node.func, Name)
  212. and node.func.name == "_alias"
  213. and (
  214. # _alias function works also for builtins object such as list and dict
  215. isinstance(node.args[0], (Attribute, Name))
  216. )
  217. )
  218. def _forbid_class_getitem_access(node: ClassDef) -> None:
  219. """Disable the access to __class_getitem__ method for the node in parameters."""
  220. def full_raiser(origin_func, attr, *args, **kwargs):
  221. """
  222. Raises an AttributeInferenceError in case of access to __class_getitem__ method.
  223. Otherwise, just call origin_func.
  224. """
  225. if attr == "__class_getitem__":
  226. raise AttributeInferenceError("__class_getitem__ access is not allowed")
  227. return origin_func(attr, *args, **kwargs)
  228. try:
  229. node.getattr("__class_getitem__")
  230. # If we are here, then we are sure to modify an object that does have
  231. # __class_getitem__ method (which origin is the protocol defined in
  232. # collections module) whereas the typing module considers it should not.
  233. # We do not want __class_getitem__ to be found in the classdef
  234. partial_raiser = partial(full_raiser, node.getattr)
  235. node.getattr = partial_raiser
  236. except AttributeInferenceError:
  237. pass
  238. def infer_typing_alias(
  239. node: Call, ctx: context.InferenceContext | None = None
  240. ) -> Iterator[ClassDef]:
  241. """
  242. Infers the call to _alias function
  243. Insert ClassDef, with same name as aliased class,
  244. in mro to simulate _GenericAlias.
  245. :param node: call node
  246. :param context: inference context
  247. """
  248. if (
  249. not isinstance(node.parent, Assign)
  250. or not len(node.parent.targets) == 1
  251. or not isinstance(node.parent.targets[0], AssignName)
  252. ):
  253. raise UseInferenceDefault
  254. try:
  255. res = next(node.args[0].infer(context=ctx))
  256. except StopIteration as e:
  257. raise InferenceError(node=node.args[0], context=ctx) from e
  258. assign_name = node.parent.targets[0]
  259. class_def = ClassDef(
  260. name=assign_name.name,
  261. lineno=assign_name.lineno,
  262. col_offset=assign_name.col_offset,
  263. parent=node.parent,
  264. )
  265. if isinstance(res, ClassDef):
  266. # Only add `res` as base if it's a `ClassDef`
  267. # This isn't the case for `typing.Pattern` and `typing.Match`
  268. class_def.postinit(bases=[res], body=[], decorators=None)
  269. maybe_type_var = node.args[1]
  270. if (
  271. not PY39_PLUS
  272. and not (isinstance(maybe_type_var, Tuple) and not maybe_type_var.elts)
  273. or PY39_PLUS
  274. and isinstance(maybe_type_var, Const)
  275. and maybe_type_var.value > 0
  276. ):
  277. # If typing alias is subscriptable, add `__class_getitem__` to ClassDef
  278. func_to_add = _extract_single_node(CLASS_GETITEM_TEMPLATE)
  279. class_def.locals["__class_getitem__"] = [func_to_add]
  280. else:
  281. # If not, make sure that `__class_getitem__` access is forbidden.
  282. # This is an issue in cases where the aliased class implements it,
  283. # but the typing alias isn't subscriptable. E.g., `typing.ByteString` for PY39+
  284. _forbid_class_getitem_access(class_def)
  285. return iter([class_def])
  286. def _looks_like_special_alias(node: Call) -> bool:
  287. """Return True if call is for Tuple or Callable alias.
  288. In PY37 and PY38 the call is to '_VariadicGenericAlias' with 'tuple' as
  289. first argument. In PY39+ it is replaced by a call to '_TupleType'.
  290. PY37: Tuple = _VariadicGenericAlias(tuple, (), inst=False, special=True)
  291. PY39: Tuple = _TupleType(tuple, -1, inst=False, name='Tuple')
  292. PY37: Callable = _VariadicGenericAlias(collections.abc.Callable, (), special=True)
  293. PY39: Callable = _CallableType(collections.abc.Callable, 2)
  294. """
  295. return isinstance(node.func, Name) and (
  296. not PY39_PLUS
  297. and node.func.name == "_VariadicGenericAlias"
  298. and (
  299. isinstance(node.args[0], Name)
  300. and node.args[0].name == "tuple"
  301. or isinstance(node.args[0], Attribute)
  302. and node.args[0].as_string() == "collections.abc.Callable"
  303. )
  304. or PY39_PLUS
  305. and (
  306. node.func.name == "_TupleType"
  307. and isinstance(node.args[0], Name)
  308. and node.args[0].name == "tuple"
  309. or node.func.name == "_CallableType"
  310. and isinstance(node.args[0], Attribute)
  311. and node.args[0].as_string() == "collections.abc.Callable"
  312. )
  313. )
  314. def infer_special_alias(
  315. node: Call, ctx: context.InferenceContext | None = None
  316. ) -> Iterator[ClassDef]:
  317. """Infer call to tuple alias as new subscriptable class typing.Tuple."""
  318. if not (
  319. isinstance(node.parent, Assign)
  320. and len(node.parent.targets) == 1
  321. and isinstance(node.parent.targets[0], AssignName)
  322. ):
  323. raise UseInferenceDefault
  324. try:
  325. res = next(node.args[0].infer(context=ctx))
  326. except StopIteration as e:
  327. raise InferenceError(node=node.args[0], context=ctx) from e
  328. assign_name = node.parent.targets[0]
  329. class_def = ClassDef(
  330. name=assign_name.name,
  331. parent=node.parent,
  332. )
  333. class_def.postinit(bases=[res], body=[], decorators=None)
  334. func_to_add = _extract_single_node(CLASS_GETITEM_TEMPLATE)
  335. class_def.locals["__class_getitem__"] = [func_to_add]
  336. return iter([class_def])
  337. def _looks_like_typing_cast(node: Call) -> bool:
  338. return isinstance(node, Call) and (
  339. isinstance(node.func, Name)
  340. and node.func.name == "cast"
  341. or isinstance(node.func, Attribute)
  342. and node.func.attrname == "cast"
  343. )
  344. def infer_typing_cast(
  345. node: Call, ctx: context.InferenceContext | None = None
  346. ) -> Iterator[NodeNG]:
  347. """Infer call to cast() returning same type as casted-from var."""
  348. if not isinstance(node.func, (Name, Attribute)):
  349. raise UseInferenceDefault
  350. try:
  351. func = next(node.func.infer(context=ctx))
  352. except (InferenceError, StopIteration) as exc:
  353. raise UseInferenceDefault from exc
  354. if (
  355. not isinstance(func, FunctionDef)
  356. or func.qname() != "typing.cast"
  357. or len(node.args) != 2
  358. ):
  359. raise UseInferenceDefault
  360. return node.args[1].infer(context=ctx)
  361. AstroidManager().register_transform(
  362. Call,
  363. inference_tip(infer_typing_typevar_or_newtype),
  364. looks_like_typing_typevar_or_newtype,
  365. )
  366. AstroidManager().register_transform(
  367. Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript
  368. )
  369. AstroidManager().register_transform(
  370. Call, inference_tip(infer_typing_cast), _looks_like_typing_cast
  371. )
  372. if PY39_PLUS:
  373. AstroidManager().register_transform(
  374. FunctionDef, inference_tip(infer_typedDict), _looks_like_typedDict
  375. )
  376. elif PY38_PLUS:
  377. AstroidManager().register_transform(
  378. ClassDef, inference_tip(infer_old_typedDict), _looks_like_typedDict
  379. )
  380. AstroidManager().register_transform(
  381. Call, inference_tip(infer_typing_alias), _looks_like_typing_alias
  382. )
  383. AstroidManager().register_transform(
  384. Call, inference_tip(infer_special_alias), _looks_like_special_alias
  385. )