brain_namedtuple_enum.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623
  1. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  2. # For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
  3. # Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt
  4. """Astroid hooks for the Python standard library."""
  5. from __future__ import annotations
  6. import functools
  7. import keyword
  8. import sys
  9. from collections.abc import Iterator
  10. from textwrap import dedent
  11. import astroid
  12. from astroid import arguments, bases, inference_tip, nodes, util
  13. from astroid.builder import AstroidBuilder, _extract_single_node, extract_node
  14. from astroid.context import InferenceContext
  15. from astroid.exceptions import (
  16. AstroidTypeError,
  17. AstroidValueError,
  18. InferenceError,
  19. UseInferenceDefault,
  20. )
  21. from astroid.manager import AstroidManager
  22. if sys.version_info >= (3, 8):
  23. from typing import Final
  24. else:
  25. from typing_extensions import Final
  26. ENUM_QNAME: Final[str] = "enum.Enum"
  27. TYPING_NAMEDTUPLE_QUALIFIED: Final = {
  28. "typing.NamedTuple",
  29. "typing_extensions.NamedTuple",
  30. }
  31. TYPING_NAMEDTUPLE_BASENAMES: Final = {
  32. "NamedTuple",
  33. "typing.NamedTuple",
  34. "typing_extensions.NamedTuple",
  35. }
  36. def _infer_first(node, context):
  37. if isinstance(node, util.UninferableBase):
  38. raise UseInferenceDefault
  39. try:
  40. value = next(node.infer(context=context))
  41. except StopIteration as exc:
  42. raise InferenceError from exc
  43. if isinstance(value, util.UninferableBase):
  44. raise UseInferenceDefault()
  45. return value
  46. def _find_func_form_arguments(node, context):
  47. def _extract_namedtuple_arg_or_keyword( # pylint: disable=inconsistent-return-statements
  48. position, key_name=None
  49. ):
  50. if len(args) > position:
  51. return _infer_first(args[position], context)
  52. if key_name and key_name in found_keywords:
  53. return _infer_first(found_keywords[key_name], context)
  54. args = node.args
  55. keywords = node.keywords
  56. found_keywords = (
  57. {keyword.arg: keyword.value for keyword in keywords} if keywords else {}
  58. )
  59. name = _extract_namedtuple_arg_or_keyword(position=0, key_name="typename")
  60. names = _extract_namedtuple_arg_or_keyword(position=1, key_name="field_names")
  61. if name and names:
  62. return name.value, names
  63. raise UseInferenceDefault()
def infer_func_form(
    node: nodes.Call,
    base_type: list[nodes.NodeNG],
    context: InferenceContext | None = None,
    enum: bool = False,
) -> tuple[nodes.ClassDef, str, list[str]]:
    """Specific inference function for namedtuple or Python 3 enum.

    Builds a synthetic ``ClassDef`` for a functional-form call such as
    ``namedtuple("Name", "a b")`` or ``Enum("Name", "a b")``.

    :param node: the call being inferred.
    :param base_type: nodes used as the bases of the generated class.
    :param context: inference context forwarded to the helpers.
    :param enum: select the Enum rules (``True``) or the namedtuple rules
        (``False``) when the field names are not given as a single string.
    :returns: ``(class_node, name, attributes)`` where ``attributes`` is the
        list of field/member names that were registered on ``class_node``.
    :raises UseInferenceDefault: if the arguments cannot be resolved
        (any ``AttributeError`` or ``InferenceError`` raised while doing so).
    """
    # node is a Call node, class name as first argument and generated class
    # attributes as second argument

    # namedtuple or enums list of attributes can be a list of strings or a
    # whitespace-separate string
    try:
        name, names = _find_func_form_arguments(node, context)
        try:
            # String form: commas are treated like whitespace separators.
            attributes: list[str] = names.value.replace(",", " ").split()
        except AttributeError as exc:
            # `names` has no usable `.value` string, so it is not the
            # single-string form.
            # Handle attributes of NamedTuples
            if not enum:
                attributes = []
                fields = _get_namedtuple_fields(node)
                if fields:
                    # `fields` is source text; turn it back into a node and
                    # infer each element.
                    fields_node = extract_node(fields)
                    attributes = [
                        _infer_first(const, context).value for const in fields_node.elts
                    ]

            # Handle attributes of Enums
            else:
                # Enums supports either iterator of (name, value) pairs
                # or mappings.
                if hasattr(names, "items") and isinstance(names.items, list):
                    attributes = [
                        _infer_first(const[0], context).value
                        for const in names.items
                        if isinstance(const[0], nodes.Const)
                    ]
                elif hasattr(names, "elts"):
                    # Enums can support either ["a", "b", "c"]
                    # or [("a", 1), ("b", 2), ...], but they can't
                    # be mixed.
                    if all(isinstance(const, nodes.Tuple) for const in names.elts):
                        attributes = [
                            _infer_first(const.elts[0], context).value
                            for const in names.elts
                            if isinstance(const, nodes.Tuple)
                        ]
                    else:
                        attributes = [
                            _infer_first(const, context).value for const in names.elts
                        ]
                else:
                    raise AttributeError from exc
                # An Enum without any member name is treated as not inferable.
                if not attributes:
                    raise AttributeError from exc
    except (AttributeError, InferenceError) as exc:
        raise UseInferenceDefault from exc

    if not enum:
        # namedtuple maps sys.intern(str()) over field_names
        attributes = [str(attr) for attr in attributes]
        # XXX this should succeed *unless* __str__/__repr__ is incorrect or throws
        # in which case we should not have inferred these values and raised earlier
    # Names that contain a space are dropped.
    attributes = [attr for attr in attributes if " " not in attr]

    # If we can't infer the name of the class, don't crash, up to this point
    # we know it is a namedtuple anyway.
    name = name or "Uninferable"
    # we want to return a Class node instance with proper attributes set
    class_node = nodes.ClassDef(name)
    # A typical ClassDef automatically adds its name to the parent scope,
    # but doing so causes problems, so defer setting parent until after init
    # see: https://github.com/PyCQA/pylint/issues/5982
    class_node.parent = node.parent
    class_node.postinit(
        # set base class=tuple
        bases=base_type,
        body=[],
        decorators=None,
    )
    # XXX add __init__(*attributes) method
    for attr in attributes:
        # Each attribute is represented by a placeholder EmptyNode stored in
        # the class's instance attributes.
        fake_node = nodes.EmptyNode()
        fake_node.parent = class_node
        fake_node.attrname = attr
        class_node.instance_attrs[attr] = [fake_node]
    return class_node, name, attributes
  147. def _has_namedtuple_base(node):
  148. """Predicate for class inference tip.
  149. :type node: ClassDef
  150. :rtype: bool
  151. """
  152. return set(node.basenames) & TYPING_NAMEDTUPLE_BASENAMES
  153. def _looks_like(node, name) -> bool:
  154. func = node.func
  155. if isinstance(func, nodes.Attribute):
  156. return func.attrname == name
  157. if isinstance(func, nodes.Name):
  158. return func.name == name
  159. return False
# Call-node predicates: each one is `_looks_like` with the callee name
# pre-bound, so it only needs the node to test.
_looks_like_namedtuple = functools.partial(_looks_like, name="namedtuple")
_looks_like_enum = functools.partial(_looks_like, name="Enum")
_looks_like_typing_namedtuple = functools.partial(_looks_like, name="NamedTuple")
def infer_named_tuple(
    node: nodes.Call, context: InferenceContext | None = None
) -> Iterator[nodes.ClassDef]:
    """Specific inference function for namedtuple Call node.

    :param node: the ``namedtuple(...)`` call being inferred.
    :param context: inference context forwarded to the helpers.
    :returns: an iterator yielding the single generated class node.
    :raises UseInferenceDefault: if the type name or field names are rejected
        by ``_check_namedtuple_attributes``.
    :raises InferenceError: if ``collections.namedtuple`` itself cannot be
        inferred.
    """
    tuple_base_name: list[nodes.NodeNG] = [nodes.Name(name="tuple", parent=node.root())]
    class_node, name, attributes = infer_func_form(
        node, tuple_base_name, context=context
    )
    # The call site is captured before `node` is rebound below.
    call_site = arguments.CallSite.from_call(node, context=context)
    # From here on `node` refers to the `collections.namedtuple` reference,
    # not to the original call.
    node = extract_node("import collections; collections.namedtuple")
    try:
        func = next(node.infer())
    except StopIteration as e:
        raise InferenceError(node=node) from e
    try:
        rename = next(call_site.infer_argument(func, "rename", context)).bool_value()
    except (InferenceError, StopIteration):
        # `rename` could not be inferred: treat it as not requested.
        rename = False

    try:
        attributes = _check_namedtuple_attributes(name, attributes, rename)
    except AstroidTypeError as exc:
        raise UseInferenceDefault("TypeError: " + str(exc)) from exc
    except AstroidValueError as exc:
        raise UseInferenceDefault("ValueError: " + str(exc)) from exc

    replace_args = ", ".join(f"{arg}=None" for arg in attributes)
    # One property per field; indented so it sits inside the class body of
    # the template below.
    field_def = (
        "    {name} = property(lambda self: self[{index:d}], "
        "doc='Alias for field number {index:d}')"
    )
    field_defs = "\n".join(
        field_def.format(name=name, index=index)
        for index, name in enumerate(attributes)
    )
    # Build a throwaway class from source text and borrow its members.
    fake = AstroidBuilder(AstroidManager()).string_build(
        f"""
class {name}(tuple):
    __slots__ = ()
    _fields = {attributes!r}
    def _asdict(self):
        return self.__dict__
    @classmethod
    def _make(cls, iterable, new=tuple.__new__, len=len):
        return new(cls, iterable)
    def _replace(self, {replace_args}):
        return self
    def __getnewargs__(self):
        return tuple(self)
{field_defs}
    """
    )
    class_node.locals["_asdict"] = fake.body[0].locals["_asdict"]
    class_node.locals["_make"] = fake.body[0].locals["_make"]
    class_node.locals["_replace"] = fake.body[0].locals["_replace"]
    class_node.locals["_fields"] = fake.body[0].locals["_fields"]
    for attr in attributes:
        class_node.locals[attr] = fake.body[0].locals[attr]
    # we use UseInferenceDefault, we can't be a generator so return an iterator
    return iter([class_node])
  221. def _get_renamed_namedtuple_attributes(field_names):
  222. names = list(field_names)
  223. seen = set()
  224. for i, name in enumerate(field_names):
  225. if (
  226. not all(c.isalnum() or c == "_" for c in name)
  227. or keyword.iskeyword(name)
  228. or not name
  229. or name[0].isdigit()
  230. or name.startswith("_")
  231. or name in seen
  232. ):
  233. names[i] = "_%d" % i
  234. seen.add(name)
  235. return tuple(names)
  236. def _check_namedtuple_attributes(typename, attributes, rename=False):
  237. attributes = tuple(attributes)
  238. if rename:
  239. attributes = _get_renamed_namedtuple_attributes(attributes)
  240. # The following snippet is derived from the CPython Lib/collections/__init__.py sources
  241. # <snippet>
  242. for name in (typename,) + attributes:
  243. if not isinstance(name, str):
  244. raise AstroidTypeError("Type names and field names must be strings")
  245. if not name.isidentifier():
  246. raise AstroidValueError(
  247. "Type names and field names must be valid" + f"identifiers: {name!r}"
  248. )
  249. if keyword.iskeyword(name):
  250. raise AstroidValueError(
  251. f"Type names and field names cannot be a keyword: {name!r}"
  252. )
  253. seen = set()
  254. for name in attributes:
  255. if name.startswith("_") and not rename:
  256. raise AstroidValueError(
  257. f"Field names cannot start with an underscore: {name!r}"
  258. )
  259. if name in seen:
  260. raise AstroidValueError(f"Encountered duplicate field name: {name!r}")
  261. seen.add(name)
  262. # </snippet>
  263. return attributes
def infer_enum(
    node: nodes.Call, context: InferenceContext | None = None
) -> Iterator[bases.Instance]:
    """Specific inference function for enum Call node.

    :param node: the ``Enum(...)`` call being inferred.
    :param context: inference context forwarded to the helpers.
    :returns: an iterator yielding one instance of the generated class.
    :raises UseInferenceDefault: if the callee does not infer to a class whose
        qualified name is ``ENUM_QNAME``, or if the call's arguments cannot be
        resolved by ``infer_func_form``.
    """
    # Raise `UseInferenceDefault` if `node` is a call to a user-defined Enum.
    # NOTE(review): if `infer()` is lazy, inference errors would surface while
    # `any()` iterates below rather than inside this `try` — confirm.
    try:
        inferred = node.func.infer(context)
    except (InferenceError, StopIteration) as exc:
        raise UseInferenceDefault from exc

    if not any(
        isinstance(item, nodes.ClassDef) and item.qname() == ENUM_QNAME
        for item in inferred
    ):
        raise UseInferenceDefault

    # Stand-in class used as the base of the generated enum class.
    enum_meta = _extract_single_node(
        """
    class EnumMeta(object):
        'docstring'
        def __call__(self, node):
            class EnumAttribute(object):
                name = ''
                value = 0
            return EnumAttribute()
        def __iter__(self):
            class EnumAttribute(object):
                name = ''
                value = 0
            return [EnumAttribute()]
        def __reversed__(self):
            class EnumAttribute(object):
                name = ''
                value = 0
            return (EnumAttribute, )
        def __next__(self):
            return next(iter(self))
        def __getitem__(self, attr):
            class Value(object):
                @property
                def name(self):
                    return ''
                @property
                def value(self):
                    return attr

            return Value()
        __members__ = ['']
    """
    )
    # Only the class node of the returned triple is needed here.
    class_node = infer_func_form(node, [enum_meta], context=context, enum=True)[0]
    return iter([class_node.instantiate_class()])
# Source template of extra operator methods. It is formatted with the member
# name as `{name}` and appended to the mocked member class when the inspected
# base name contains "IntFlag". The methods are indented so that they land
# inside that class body.
INT_FLAG_ADDITION_METHODS = """
    def __or__(self, other):
        return {name}(self.value | other.value)
    def __and__(self, other):
        return {name}(self.value & other.value)
    def __xor__(self, other):
        return {name}(self.value ^ other.value)
    def __add__(self, other):
        return {name}(self.value + other.value)
    def __div__(self, other):
        return {name}(self.value / other.value)
    def __invert__(self):
        return {name}(~self.value)
    def __mul__(self, other):
        return {name}(self.value * other.value)
"""
def infer_enum_class(node: nodes.ClassDef) -> nodes.ClassDef:
    """Specific inference for enums.

    Rewrites the locals of an enum class in place: every plain assignment in
    the class is replaced by an instance of a mocked member class exposing
    ``name`` and ``value`` properties, and ``_value2member_map_``,
    ``__members__`` and (when no member is called ``name``) a ``name``
    property are added to the class locals.

    :param node: the enum class to transform.
    :returns: the same ``node``, mutated.
    """
    # The loop body always ends with `break`, so it runs at most once, with
    # `basename` being the first base name found while walking the MRO. A
    # class with no base names anywhere in its MRO is returned untouched.
    for basename in (b for cls in node.mro() for b in cls.basenames):
        if node.root().name == "enum":
            # Skip if the class is directly from enum module.
            break
        dunder_members = {}
        target_names = set()
        for local, values in node.locals.items():
            # Only names bound exclusively through assignments are members.
            if any(not isinstance(value, nodes.AssignName) for value in values):
                continue

            stmt = values[0].statement(future=True)
            if isinstance(stmt, nodes.Assign):
                if isinstance(stmt.targets[0], nodes.Tuple):
                    targets = stmt.targets[0].itered()
                else:
                    targets = stmt.targets
            elif isinstance(stmt, nodes.AnnAssign):
                targets = [stmt.target]
            else:
                continue

            # Source text (or constant) substituted into the mocked `value`
            # property below; stays None for an annotation without a value.
            inferred_return_value = None
            if stmt.value is not None:
                if isinstance(stmt.value, nodes.Const):
                    if isinstance(stmt.value.value, str):
                        # repr() keeps the quotes once pasted into source.
                        inferred_return_value = repr(stmt.value.value)
                    else:
                        inferred_return_value = stmt.value.value
                else:
                    inferred_return_value = stmt.value.as_string()

            new_targets = []
            for target in targets:
                if isinstance(target, nodes.Starred):
                    continue
                target_names.add(target.name)
                # Replace all the assignments with our mocked class.
                classdef = dedent(
                    """
                class {name}({types}):
                    @property
                    def value(self):
                        return {return_value}
                    @property
                    def name(self):
                        return "{name}"
                """.format(
                        name=target.name,
                        types=", ".join(node.basenames),
                        return_value=inferred_return_value,
                    )
                )
                # NOTE(review): only the single `basename` seen by this loop
                # iteration is tested here — confirm that is intended.
                if "IntFlag" in basename:
                    # Alright, we need to add some additional methods.
                    # Unfortunately we still can't infer the resulting objects as
                    # Enum members, but once we'll be able to do that, the following
                    # should result in some nice symbolic execution
                    classdef += INT_FLAG_ADDITION_METHODS.format(name=target.name)

                fake = AstroidBuilder(
                    AstroidManager(), apply_transforms=False
                ).string_build(classdef)[target.name]
                fake.parent = target.parent
                # Methods defined on the enum class are exposed on the mocked
                # member class as well.
                for method in node.mymethods():
                    fake.locals[method.name] = [method]
                new_targets.append(fake.instantiate_class())
                dunder_members[local] = fake
            # Rebinding an existing key: the dict size does not change while
            # `node.locals.items()` is being iterated.
            node.locals[local] = new_targets

        # The undocumented `_value2member_map_` member:
        node.locals["_value2member_map_"] = [nodes.Dict(parent=node)]

        members = nodes.Dict(parent=node)
        members.postinit(
            [
                (nodes.Const(k, parent=members), nodes.Name(v.name, parent=members))
                for k, v in dunder_members.items()
            ]
        )
        node.locals["__members__"] = [members]
        # The enum.Enum class itself defines two @DynamicClassAttribute data-descriptors
        # "name" and "value" (which we override in the mocked class for each enum member
        # above). When dealing with inference of an arbitrary instance of the enum
        # class, e.g. in a method defined in the class body like:
        #     class SomeEnum(enum.Enum):
        #         def method(self):
        #             self.name  # <- here
        # In the absence of an enum member called "name" or "value", these attributes
        # should resolve to the descriptor on that particular instance, i.e. enum member.
        # For "value", we have no idea what that should be, but for "name", we at least
        # know that it should be a string, so infer that as a guess.
        if "name" not in target_names:
            code = dedent(
                """
            @property
            def name(self):
                return ''
            """
            )
            name_dynamicclassattr = AstroidBuilder(AstroidManager()).string_build(code)[
                "name"
            ]
            node.locals["name"] = [name_dynamicclassattr]
        break
    return node
  430. def infer_typing_namedtuple_class(class_node, context: InferenceContext | None = None):
  431. """Infer a subclass of typing.NamedTuple."""
  432. # Check if it has the corresponding bases
  433. annassigns_fields = [
  434. annassign.target.name
  435. for annassign in class_node.body
  436. if isinstance(annassign, nodes.AnnAssign)
  437. ]
  438. code = dedent(
  439. """
  440. from collections import namedtuple
  441. namedtuple({typename!r}, {fields!r})
  442. """
  443. ).format(typename=class_node.name, fields=",".join(annassigns_fields))
  444. node = extract_node(code)
  445. try:
  446. generated_class_node = next(infer_named_tuple(node, context))
  447. except StopIteration as e:
  448. raise InferenceError(node=node, context=context) from e
  449. for method in class_node.mymethods():
  450. generated_class_node.locals[method.name] = [method]
  451. for body_node in class_node.body:
  452. if isinstance(body_node, nodes.Assign):
  453. for target in body_node.targets:
  454. attr = target.name
  455. generated_class_node.locals[attr] = class_node.locals[attr]
  456. elif isinstance(body_node, nodes.ClassDef):
  457. generated_class_node.locals[body_node.name] = [body_node]
  458. return iter((generated_class_node,))
  459. def infer_typing_namedtuple_function(node, context: InferenceContext | None = None):
  460. """
  461. Starting with python3.9, NamedTuple is a function of the typing module.
  462. The class NamedTuple is build dynamically through a call to `type` during
  463. initialization of the `_NamedTuple` variable.
  464. """
  465. klass = extract_node(
  466. """
  467. from typing import _NamedTuple
  468. _NamedTuple
  469. """
  470. )
  471. return klass.infer(context)
  472. def infer_typing_namedtuple(
  473. node: nodes.Call, context: InferenceContext | None = None
  474. ) -> Iterator[nodes.ClassDef]:
  475. """Infer a typing.NamedTuple(...) call."""
  476. # This is essentially a namedtuple with different arguments
  477. # so we extract the args and infer a named tuple.
  478. try:
  479. func = next(node.func.infer())
  480. except (InferenceError, StopIteration) as exc:
  481. raise UseInferenceDefault from exc
  482. if func.qname() not in TYPING_NAMEDTUPLE_QUALIFIED:
  483. raise UseInferenceDefault
  484. if len(node.args) != 2:
  485. raise UseInferenceDefault
  486. if not isinstance(node.args[1], (nodes.List, nodes.Tuple)):
  487. raise UseInferenceDefault
  488. return infer_named_tuple(node, context)
  489. def _get_namedtuple_fields(node: nodes.Call) -> str:
  490. """Get and return fields of a NamedTuple in code-as-a-string.
  491. Because the fields are represented in their code form we can
  492. extract a node from them later on.
  493. """
  494. names = []
  495. container = None
  496. try:
  497. container = next(node.args[1].infer())
  498. except (InferenceError, StopIteration) as exc:
  499. raise UseInferenceDefault from exc
  500. # We pass on IndexError as we'll try to infer 'field_names' from the keywords
  501. except IndexError:
  502. pass
  503. if not container:
  504. for keyword_node in node.keywords:
  505. if keyword_node.arg == "field_names":
  506. try:
  507. container = next(keyword_node.value.infer())
  508. except (InferenceError, StopIteration) as exc:
  509. raise UseInferenceDefault from exc
  510. break
  511. if not isinstance(container, nodes.BaseContainer):
  512. raise UseInferenceDefault
  513. for elt in container.elts:
  514. if isinstance(elt, nodes.Const):
  515. names.append(elt.as_string())
  516. continue
  517. if not isinstance(elt, (nodes.List, nodes.Tuple)):
  518. raise UseInferenceDefault
  519. if len(elt.elts) != 2:
  520. raise UseInferenceDefault
  521. names.append(elt.elts[0].as_string())
  522. if names:
  523. field_names = f"({','.join(names)},)"
  524. else:
  525. field_names = ""
  526. return field_names
  527. def _is_enum_subclass(cls: astroid.ClassDef) -> bool:
  528. """Return whether cls is a subclass of an Enum."""
  529. return cls.is_subtype_of("enum.Enum")
# Module-level registration of the hooks defined above. Each call pairs a
# node class with a transform (wrapped in `inference_tip` where the hook
# provides inference results) and the predicate selecting the nodes it
# applies to.

# Calls whose callee is spelled `namedtuple`.
AstroidManager().register_transform(
    nodes.Call, inference_tip(infer_named_tuple), _looks_like_namedtuple
)
# Calls whose callee is spelled `Enum`.
AstroidManager().register_transform(
    nodes.Call, inference_tip(infer_enum), _looks_like_enum
)
# Class definitions deriving from `enum.Enum`; this one is a plain transform
# that rewrites the class node, not an inference tip.
AstroidManager().register_transform(
    nodes.ClassDef, infer_enum_class, predicate=_is_enum_subclass
)
# Class definitions with a `NamedTuple` base.
AstroidManager().register_transform(
    nodes.ClassDef, inference_tip(infer_typing_namedtuple_class), _has_namedtuple_base
)
# The function named `NamedTuple` defined in the module named `typing`.
AstroidManager().register_transform(
    nodes.FunctionDef,
    inference_tip(infer_typing_namedtuple_function),
    lambda node: node.name == "NamedTuple"
    and getattr(node.root(), "name", None) == "typing",
)
# Calls whose callee is spelled `NamedTuple`.
AstroidManager().register_transform(
    nodes.Call, inference_tip(infer_typing_namedtuple), _looks_like_typing_namedtuple
)