brain_namedtuple_enum.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639
  1. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  2. # For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
  3. # Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt
  4. """Astroid hooks for the Python standard library."""
  5. from __future__ import annotations
  6. import functools
  7. import keyword
  8. import sys
  9. from collections.abc import Iterator
  10. from textwrap import dedent
  11. import astroid
  12. from astroid import arguments, bases, inference_tip, nodes, util
  13. from astroid.builder import AstroidBuilder, _extract_single_node, extract_node
  14. from astroid.context import InferenceContext
  15. from astroid.exceptions import (
  16. AstroidTypeError,
  17. AstroidValueError,
  18. InferenceError,
  19. MroError,
  20. UseInferenceDefault,
  21. )
  22. from astroid.manager import AstroidManager
  23. if sys.version_info >= (3, 8):
  24. from typing import Final
  25. else:
  26. from typing_extensions import Final
  27. ENUM_BASE_NAMES = {
  28. "Enum",
  29. "IntEnum",
  30. "enum.Enum",
  31. "enum.IntEnum",
  32. "IntFlag",
  33. "enum.IntFlag",
  34. }
  35. ENUM_QNAME: Final[str] = "enum.Enum"
  36. TYPING_NAMEDTUPLE_QUALIFIED: Final = {
  37. "typing.NamedTuple",
  38. "typing_extensions.NamedTuple",
  39. }
  40. TYPING_NAMEDTUPLE_BASENAMES: Final = {
  41. "NamedTuple",
  42. "typing.NamedTuple",
  43. "typing_extensions.NamedTuple",
  44. }
  45. def _infer_first(node, context):
  46. if isinstance(node, util.UninferableBase):
  47. raise UseInferenceDefault
  48. try:
  49. value = next(node.infer(context=context))
  50. except StopIteration as exc:
  51. raise InferenceError from exc
  52. if isinstance(value, util.UninferableBase):
  53. raise UseInferenceDefault()
  54. return value
  55. def _find_func_form_arguments(node, context):
  56. def _extract_namedtuple_arg_or_keyword( # pylint: disable=inconsistent-return-statements
  57. position, key_name=None
  58. ):
  59. if len(args) > position:
  60. return _infer_first(args[position], context)
  61. if key_name and key_name in found_keywords:
  62. return _infer_first(found_keywords[key_name], context)
  63. args = node.args
  64. keywords = node.keywords
  65. found_keywords = (
  66. {keyword.arg: keyword.value for keyword in keywords} if keywords else {}
  67. )
  68. name = _extract_namedtuple_arg_or_keyword(position=0, key_name="typename")
  69. names = _extract_namedtuple_arg_or_keyword(position=1, key_name="field_names")
  70. if name and names:
  71. return name.value, names
  72. raise UseInferenceDefault()
  73. def infer_func_form(
  74. node: nodes.Call,
  75. base_type: list[nodes.NodeNG],
  76. context: InferenceContext | None = None,
  77. enum: bool = False,
  78. ) -> tuple[nodes.ClassDef, str, list[str]]:
  79. """Specific inference function for namedtuple or Python 3 enum."""
  80. # node is a Call node, class name as first argument and generated class
  81. # attributes as second argument
  82. # namedtuple or enums list of attributes can be a list of strings or a
  83. # whitespace-separate string
  84. try:
  85. name, names = _find_func_form_arguments(node, context)
  86. try:
  87. attributes: list[str] = names.value.replace(",", " ").split()
  88. except AttributeError as exc:
  89. # Handle attributes of NamedTuples
  90. if not enum:
  91. attributes = []
  92. fields = _get_namedtuple_fields(node)
  93. if fields:
  94. fields_node = extract_node(fields)
  95. attributes = [
  96. _infer_first(const, context).value for const in fields_node.elts
  97. ]
  98. # Handle attributes of Enums
  99. else:
  100. # Enums supports either iterator of (name, value) pairs
  101. # or mappings.
  102. if hasattr(names, "items") and isinstance(names.items, list):
  103. attributes = [
  104. _infer_first(const[0], context).value
  105. for const in names.items
  106. if isinstance(const[0], nodes.Const)
  107. ]
  108. elif hasattr(names, "elts"):
  109. # Enums can support either ["a", "b", "c"]
  110. # or [("a", 1), ("b", 2), ...], but they can't
  111. # be mixed.
  112. if all(isinstance(const, nodes.Tuple) for const in names.elts):
  113. attributes = [
  114. _infer_first(const.elts[0], context).value
  115. for const in names.elts
  116. if isinstance(const, nodes.Tuple)
  117. ]
  118. else:
  119. attributes = [
  120. _infer_first(const, context).value for const in names.elts
  121. ]
  122. else:
  123. raise AttributeError from exc
  124. if not attributes:
  125. raise AttributeError from exc
  126. except (AttributeError, InferenceError) as exc:
  127. raise UseInferenceDefault from exc
  128. if not enum:
  129. # namedtuple maps sys.intern(str()) over over field_names
  130. attributes = [str(attr) for attr in attributes]
  131. # XXX this should succeed *unless* __str__/__repr__ is incorrect or throws
  132. # in which case we should not have inferred these values and raised earlier
  133. attributes = [attr for attr in attributes if " " not in attr]
  134. # If we can't infer the name of the class, don't crash, up to this point
  135. # we know it is a namedtuple anyway.
  136. name = name or "Uninferable"
  137. # we want to return a Class node instance with proper attributes set
  138. class_node = nodes.ClassDef(name)
  139. # A typical ClassDef automatically adds its name to the parent scope,
  140. # but doing so causes problems, so defer setting parent until after init
  141. # see: https://github.com/PyCQA/pylint/issues/5982
  142. class_node.parent = node.parent
  143. class_node.postinit(
  144. # set base class=tuple
  145. bases=base_type,
  146. body=[],
  147. decorators=None,
  148. )
  149. # XXX add __init__(*attributes) method
  150. for attr in attributes:
  151. fake_node = nodes.EmptyNode()
  152. fake_node.parent = class_node
  153. fake_node.attrname = attr
  154. class_node.instance_attrs[attr] = [fake_node]
  155. return class_node, name, attributes
  156. def _has_namedtuple_base(node):
  157. """Predicate for class inference tip.
  158. :type node: ClassDef
  159. :rtype: bool
  160. """
  161. return set(node.basenames) & TYPING_NAMEDTUPLE_BASENAMES
  162. def _looks_like(node, name) -> bool:
  163. func = node.func
  164. if isinstance(func, nodes.Attribute):
  165. return func.attrname == name
  166. if isinstance(func, nodes.Name):
  167. return func.name == name
  168. return False
# Syntactic predicates for ``register_transform``: they match Call nodes whose
# callee is spelled ``namedtuple``, ``Enum`` or ``NamedTuple`` (bare or as an
# attribute). ``infer_enum`` and ``infer_typing_namedtuple`` additionally check
# the qualified name the callee infers to before doing any work.
_looks_like_namedtuple = functools.partial(_looks_like, name="namedtuple")
_looks_like_enum = functools.partial(_looks_like, name="Enum")
_looks_like_typing_namedtuple = functools.partial(_looks_like, name="NamedTuple")
def infer_named_tuple(
    node: nodes.Call, context: InferenceContext | None = None
) -> Iterator[nodes.ClassDef]:
    """Specific inference function for namedtuple Call node.

    Steps:

    * build the class via ``infer_func_form``, with ``tuple`` as its base;
    * validate (and, with ``rename``, rename) the field names the way
      ``collections.namedtuple`` does;
    * copy ``_asdict``, ``_make``, ``_replace``, ``_fields`` and one
      property per field from a synthesized class onto the result.

    :raises UseInferenceDefault: if the type or field names are rejected by
        ``_check_namedtuple_attributes``, or ``infer_func_form`` gives up.
    :raises InferenceError: if ``collections.namedtuple`` itself infers to
        nothing.
    """
    tuple_base_name: list[nodes.NodeNG] = [nodes.Name(name="tuple", parent=node.root())]
    class_node, name, attributes = infer_func_form(
        node, tuple_base_name, context=context
    )
    call_site = arguments.CallSite.from_call(node, context=context)
    # Infer the real ``collections.namedtuple`` so that its signature can be
    # used to locate this call's ``rename`` argument.
    node = extract_node("import collections; collections.namedtuple")
    try:
        func = next(node.infer())
    except StopIteration as e:
        raise InferenceError(node=node) from e
    try:
        rename = next(call_site.infer_argument(func, "rename", context)).bool_value()
    except (InferenceError, StopIteration):
        # ``rename`` absent or not inferable: treat it as False.
        rename = False

    try:
        attributes = _check_namedtuple_attributes(name, attributes, rename)
    except AstroidTypeError as exc:
        raise UseInferenceDefault("TypeError: " + str(exc)) from exc
    except AstroidValueError as exc:
        raise UseInferenceDefault("ValueError: " + str(exc)) from exc

    # ``_replace`` gets one ``field=None`` keyword per field.
    replace_args = ", ".join(f"{arg}=None" for arg in attributes)
    # One read-only property per field, returning the matching tuple item.
    field_def = (
        "    {name} = property(lambda self: self[{index:d}], "
        "doc='Alias for field number {index:d}')"
    )
    field_defs = "\n".join(
        field_def.format(name=name, index=index)
        for index, name in enumerate(attributes)
    )
    # Synthesize a stand-in class in source form. The class statement starts
    # at column 0 and its body (including ``field_defs``) is indented by
    # four spaces.
    fake = AstroidBuilder(AstroidManager()).string_build(
        f"""
class {name}(tuple):
    __slots__ = ()
    _fields = {attributes!r}
    def _asdict(self):
        return self.__dict__
    @classmethod
    def _make(cls, iterable, new=tuple.__new__, len=len):
        return new(cls, iterable)
    def _replace(self, {replace_args}):
        return self
    def __getnewargs__(self):
        return tuple(self)
{field_defs}
    """
    )
    # Graft the synthesized members onto the class built by infer_func_form.
    class_node.locals["_asdict"] = fake.body[0].locals["_asdict"]
    class_node.locals["_make"] = fake.body[0].locals["_make"]
    class_node.locals["_replace"] = fake.body[0].locals["_replace"]
    class_node.locals["_fields"] = fake.body[0].locals["_fields"]
    for attr in attributes:
        class_node.locals[attr] = fake.body[0].locals[attr]
    # we use UseInferenceDefault, we can't be a generator so return an iterator
    return iter([class_node])
  230. def _get_renamed_namedtuple_attributes(field_names):
  231. names = list(field_names)
  232. seen = set()
  233. for i, name in enumerate(field_names):
  234. if (
  235. not all(c.isalnum() or c == "_" for c in name)
  236. or keyword.iskeyword(name)
  237. or not name
  238. or name[0].isdigit()
  239. or name.startswith("_")
  240. or name in seen
  241. ):
  242. names[i] = "_%d" % i
  243. seen.add(name)
  244. return tuple(names)
  245. def _check_namedtuple_attributes(typename, attributes, rename=False):
  246. attributes = tuple(attributes)
  247. if rename:
  248. attributes = _get_renamed_namedtuple_attributes(attributes)
  249. # The following snippet is derived from the CPython Lib/collections/__init__.py sources
  250. # <snippet>
  251. for name in (typename,) + attributes:
  252. if not isinstance(name, str):
  253. raise AstroidTypeError("Type names and field names must be strings")
  254. if not name.isidentifier():
  255. raise AstroidValueError(
  256. "Type names and field names must be valid" + f"identifiers: {name!r}"
  257. )
  258. if keyword.iskeyword(name):
  259. raise AstroidValueError(
  260. f"Type names and field names cannot be a keyword: {name!r}"
  261. )
  262. seen = set()
  263. for name in attributes:
  264. if name.startswith("_") and not rename:
  265. raise AstroidValueError(
  266. f"Field names cannot start with an underscore: {name!r}"
  267. )
  268. if name in seen:
  269. raise AstroidValueError(f"Encountered duplicate field name: {name!r}")
  270. seen.add(name)
  271. # </snippet>
  272. return attributes
  273. def infer_enum(
  274. node: nodes.Call, context: InferenceContext | None = None
  275. ) -> Iterator[bases.Instance]:
  276. """Specific inference function for enum Call node."""
  277. # Raise `UseInferenceDefault` if `node` is a call to a a user-defined Enum.
  278. try:
  279. inferred = node.func.infer(context)
  280. except (InferenceError, StopIteration) as exc:
  281. raise UseInferenceDefault from exc
  282. if not any(
  283. isinstance(item, nodes.ClassDef) and item.qname() == ENUM_QNAME
  284. for item in inferred
  285. ):
  286. raise UseInferenceDefault
  287. enum_meta = _extract_single_node(
  288. """
  289. class EnumMeta(object):
  290. 'docstring'
  291. def __call__(self, node):
  292. class EnumAttribute(object):
  293. name = ''
  294. value = 0
  295. return EnumAttribute()
  296. def __iter__(self):
  297. class EnumAttribute(object):
  298. name = ''
  299. value = 0
  300. return [EnumAttribute()]
  301. def __reversed__(self):
  302. class EnumAttribute(object):
  303. name = ''
  304. value = 0
  305. return (EnumAttribute, )
  306. def __next__(self):
  307. return next(iter(self))
  308. def __getitem__(self, attr):
  309. class Value(object):
  310. @property
  311. def name(self):
  312. return ''
  313. @property
  314. def value(self):
  315. return attr
  316. return Value()
  317. __members__ = ['']
  318. """
  319. )
  320. class_node = infer_func_form(node, [enum_meta], context=context, enum=True)[0]
  321. return iter([class_node.instantiate_class()])
# Source of extra operator methods that ``infer_enum_class`` appends to the
# mocked member class when the enum's base name contains "IntFlag". Each
# method returns a new instance of the member class built from the operands'
# ``value``. ``{name}`` is filled in with the member name via ``str.format``.
# NOTE(review): ``__div__`` is the Python 2 spelling; Python 3 dispatches ``/``
# to ``__truediv__``, so division is presumably not covered — confirm intent.
INT_FLAG_ADDITION_METHODS = """
    def __or__(self, other):
        return {name}(self.value | other.value)
    def __and__(self, other):
        return {name}(self.value & other.value)
    def __xor__(self, other):
        return {name}(self.value ^ other.value)
    def __add__(self, other):
        return {name}(self.value + other.value)
    def __div__(self, other):
        return {name}(self.value / other.value)
    def __invert__(self):
        return {name}(~self.value)
    def __mul__(self, other):
        return {name}(self.value * other.value)
"""
  338. def infer_enum_class(node: nodes.ClassDef) -> nodes.ClassDef:
  339. """Specific inference for enums."""
  340. for basename in (b for cls in node.mro() for b in cls.basenames):
  341. if node.root().name == "enum":
  342. # Skip if the class is directly from enum module.
  343. break
  344. dunder_members = {}
  345. target_names = set()
  346. for local, values in node.locals.items():
  347. if any(not isinstance(value, nodes.AssignName) for value in values):
  348. continue
  349. stmt = values[0].statement(future=True)
  350. if isinstance(stmt, nodes.Assign):
  351. if isinstance(stmt.targets[0], nodes.Tuple):
  352. targets = stmt.targets[0].itered()
  353. else:
  354. targets = stmt.targets
  355. elif isinstance(stmt, nodes.AnnAssign):
  356. targets = [stmt.target]
  357. else:
  358. continue
  359. inferred_return_value = None
  360. if stmt.value is not None:
  361. if isinstance(stmt.value, nodes.Const):
  362. if isinstance(stmt.value.value, str):
  363. inferred_return_value = repr(stmt.value.value)
  364. else:
  365. inferred_return_value = stmt.value.value
  366. else:
  367. inferred_return_value = stmt.value.as_string()
  368. new_targets = []
  369. for target in targets:
  370. if isinstance(target, nodes.Starred):
  371. continue
  372. target_names.add(target.name)
  373. # Replace all the assignments with our mocked class.
  374. classdef = dedent(
  375. """
  376. class {name}({types}):
  377. @property
  378. def value(self):
  379. return {return_value}
  380. @property
  381. def name(self):
  382. return "{name}"
  383. """.format(
  384. name=target.name,
  385. types=", ".join(node.basenames),
  386. return_value=inferred_return_value,
  387. )
  388. )
  389. if "IntFlag" in basename:
  390. # Alright, we need to add some additional methods.
  391. # Unfortunately we still can't infer the resulting objects as
  392. # Enum members, but once we'll be able to do that, the following
  393. # should result in some nice symbolic execution
  394. classdef += INT_FLAG_ADDITION_METHODS.format(name=target.name)
  395. fake = AstroidBuilder(
  396. AstroidManager(), apply_transforms=False
  397. ).string_build(classdef)[target.name]
  398. fake.parent = target.parent
  399. for method in node.mymethods():
  400. fake.locals[method.name] = [method]
  401. new_targets.append(fake.instantiate_class())
  402. dunder_members[local] = fake
  403. node.locals[local] = new_targets
  404. # The undocumented `_value2member_map_` member:
  405. node.locals["_value2member_map_"] = [nodes.Dict(parent=node)]
  406. members = nodes.Dict(parent=node)
  407. members.postinit(
  408. [
  409. (nodes.Const(k, parent=members), nodes.Name(v.name, parent=members))
  410. for k, v in dunder_members.items()
  411. ]
  412. )
  413. node.locals["__members__"] = [members]
  414. # The enum.Enum class itself defines two @DynamicClassAttribute data-descriptors
  415. # "name" and "value" (which we override in the mocked class for each enum member
  416. # above). When dealing with inference of an arbitrary instance of the enum
  417. # class, e.g. in a method defined in the class body like:
  418. # class SomeEnum(enum.Enum):
  419. # def method(self):
  420. # self.name # <- here
  421. # In the absence of an enum member called "name" or "value", these attributes
  422. # should resolve to the descriptor on that particular instance, i.e. enum member.
  423. # For "value", we have no idea what that should be, but for "name", we at least
  424. # know that it should be a string, so infer that as a guess.
  425. if "name" not in target_names:
  426. code = dedent(
  427. """
  428. @property
  429. def name(self):
  430. return ''
  431. """
  432. )
  433. name_dynamicclassattr = AstroidBuilder(AstroidManager()).string_build(code)[
  434. "name"
  435. ]
  436. node.locals["name"] = [name_dynamicclassattr]
  437. break
  438. return node
  439. def infer_typing_namedtuple_class(class_node, context: InferenceContext | None = None):
  440. """Infer a subclass of typing.NamedTuple."""
  441. # Check if it has the corresponding bases
  442. annassigns_fields = [
  443. annassign.target.name
  444. for annassign in class_node.body
  445. if isinstance(annassign, nodes.AnnAssign)
  446. ]
  447. code = dedent(
  448. """
  449. from collections import namedtuple
  450. namedtuple({typename!r}, {fields!r})
  451. """
  452. ).format(typename=class_node.name, fields=",".join(annassigns_fields))
  453. node = extract_node(code)
  454. try:
  455. generated_class_node = next(infer_named_tuple(node, context))
  456. except StopIteration as e:
  457. raise InferenceError(node=node, context=context) from e
  458. for method in class_node.mymethods():
  459. generated_class_node.locals[method.name] = [method]
  460. for body_node in class_node.body:
  461. if isinstance(body_node, nodes.Assign):
  462. for target in body_node.targets:
  463. attr = target.name
  464. generated_class_node.locals[attr] = class_node.locals[attr]
  465. elif isinstance(body_node, nodes.ClassDef):
  466. generated_class_node.locals[body_node.name] = [body_node]
  467. return iter((generated_class_node,))
  468. def infer_typing_namedtuple_function(node, context: InferenceContext | None = None):
  469. """
  470. Starting with python3.9, NamedTuple is a function of the typing module.
  471. The class NamedTuple is build dynamically through a call to `type` during
  472. initialization of the `_NamedTuple` variable.
  473. """
  474. klass = extract_node(
  475. """
  476. from typing import _NamedTuple
  477. _NamedTuple
  478. """
  479. )
  480. return klass.infer(context)
  481. def infer_typing_namedtuple(
  482. node: nodes.Call, context: InferenceContext | None = None
  483. ) -> Iterator[nodes.ClassDef]:
  484. """Infer a typing.NamedTuple(...) call."""
  485. # This is essentially a namedtuple with different arguments
  486. # so we extract the args and infer a named tuple.
  487. try:
  488. func = next(node.func.infer())
  489. except (InferenceError, StopIteration) as exc:
  490. raise UseInferenceDefault from exc
  491. if func.qname() not in TYPING_NAMEDTUPLE_QUALIFIED:
  492. raise UseInferenceDefault
  493. if len(node.args) != 2:
  494. raise UseInferenceDefault
  495. if not isinstance(node.args[1], (nodes.List, nodes.Tuple)):
  496. raise UseInferenceDefault
  497. return infer_named_tuple(node, context)
  498. def _get_namedtuple_fields(node: nodes.Call) -> str:
  499. """Get and return fields of a NamedTuple in code-as-a-string.
  500. Because the fields are represented in their code form we can
  501. extract a node from them later on.
  502. """
  503. names = []
  504. container = None
  505. try:
  506. container = next(node.args[1].infer())
  507. except (InferenceError, StopIteration) as exc:
  508. raise UseInferenceDefault from exc
  509. # We pass on IndexError as we'll try to infer 'field_names' from the keywords
  510. except IndexError:
  511. pass
  512. if not container:
  513. for keyword_node in node.keywords:
  514. if keyword_node.arg == "field_names":
  515. try:
  516. container = next(keyword_node.value.infer())
  517. except (InferenceError, StopIteration) as exc:
  518. raise UseInferenceDefault from exc
  519. break
  520. if not isinstance(container, nodes.BaseContainer):
  521. raise UseInferenceDefault
  522. for elt in container.elts:
  523. if isinstance(elt, nodes.Const):
  524. names.append(elt.as_string())
  525. continue
  526. if not isinstance(elt, (nodes.List, nodes.Tuple)):
  527. raise UseInferenceDefault
  528. if len(elt.elts) != 2:
  529. raise UseInferenceDefault
  530. names.append(elt.elts[0].as_string())
  531. if names:
  532. field_names = f"({','.join(names)},)"
  533. else:
  534. field_names = ""
  535. return field_names
  536. def _is_enum_subclass(cls: astroid.ClassDef) -> bool:
  537. """Return whether cls is a subclass of an Enum."""
  538. try:
  539. return any(
  540. klass.name in ENUM_BASE_NAMES
  541. and getattr(klass.root(), "name", None) == "enum"
  542. for klass in cls.mro()
  543. )
  544. except MroError:
  545. return False
# Register the hooks defined above. Each call names the node type, the
# function to apply, and the predicate selecting which nodes it applies to.
# ``namedtuple(...)`` calls.
AstroidManager().register_transform(
    nodes.Call, inference_tip(infer_named_tuple), _looks_like_namedtuple
)
# ``Enum(...)`` functional-API calls.
AstroidManager().register_transform(
    nodes.Call, inference_tip(infer_enum), _looks_like_enum
)
# Enum subclasses: infer_enum_class is registered directly (not wrapped in
# inference_tip), and it rewrites the ClassDef's locals in place.
AstroidManager().register_transform(
    nodes.ClassDef, infer_enum_class, predicate=_is_enum_subclass
)
# ``class X(NamedTuple): ...`` statements.
AstroidManager().register_transform(
    nodes.ClassDef, inference_tip(infer_typing_namedtuple_class), _has_namedtuple_base
)
# The ``NamedTuple`` function defined in the ``typing`` module itself.
AstroidManager().register_transform(
    nodes.FunctionDef,
    inference_tip(infer_typing_namedtuple_function),
    lambda node: node.name == "NamedTuple"
    and getattr(node.root(), "name", None) == "typing",
)
# ``NamedTuple(...)`` calls.
AstroidManager().register_transform(
    nodes.Call, inference_tip(infer_typing_namedtuple), _looks_like_typing_namedtuple
)
  564. AstroidManager().register_transform(
  565. nodes.Call, inference_tip(infer_typing_namedtuple), _looks_like_typing_namedtuple
  566. )