brain_dataclasses.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636
  1. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  2. # For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
  3. # Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt
  4. """
  5. Astroid hook for the dataclasses library.
  6. Support built-in dataclasses, pydantic.dataclasses, and marshmallow_dataclass-annotated
  7. dataclasses. References:
  8. - https://docs.python.org/3/library/dataclasses.html
  9. - https://pydantic-docs.helpmanual.io/usage/dataclasses/
  10. - https://lovasoa.github.io/marshmallow_dataclass/
  11. """
  12. from __future__ import annotations
  13. import sys
  14. from collections.abc import Iterator
  15. from typing import Tuple, Union
  16. from astroid import bases, context, helpers, nodes
  17. from astroid.builder import parse
  18. from astroid.const import PY39_PLUS, PY310_PLUS
  19. from astroid.exceptions import AstroidSyntaxError, InferenceError, UseInferenceDefault
  20. from astroid.inference_tip import inference_tip
  21. from astroid.manager import AstroidManager
  22. from astroid.typing import InferenceResult
  23. from astroid.util import Uninferable, UninferableBase
  24. if sys.version_info >= (3, 8):
  25. from typing import Literal
  26. else:
  27. from typing_extensions import Literal
# Return type of `_get_field_default`: either None, or a pair made of the
# keyword that supplied the default and the node describing that default.
_FieldDefaultReturn = Union[
    None,
    Tuple[Literal["default"], nodes.NodeNG],
    Tuple[Literal["default_factory"], nodes.Call],
]

# Decorator names accepted by default as marking a dataclass.
DATACLASSES_DECORATORS = frozenset(("dataclass",))
# Name of the function whose calls are treated as field declarations.
FIELD_NAME = "field"
# Root module names a decorator / field function must be defined in.
DATACLASS_MODULES = frozenset(
    ("dataclasses", "marshmallow_dataclass", "pydantic.dataclasses")
)
# Name used in generated `__init__` signatures as the default of parameters
# whose field declares a `default_factory`.
DEFAULT_FACTORY = "_HAS_DEFAULT_FACTORY"  # based on typing.py
  39. def is_decorated_with_dataclass(
  40. node: nodes.ClassDef, decorator_names: frozenset[str] = DATACLASSES_DECORATORS
  41. ) -> bool:
  42. """Return True if a decorated node has a `dataclass` decorator applied."""
  43. if not isinstance(node, nodes.ClassDef) or not node.decorators:
  44. return False
  45. return any(
  46. _looks_like_dataclass_decorator(decorator_attribute, decorator_names)
  47. for decorator_attribute in node.decorators.nodes
  48. )
  49. def dataclass_transform(node: nodes.ClassDef) -> None:
  50. """Rewrite a dataclass to be easily understood by pylint."""
  51. node.is_dataclass = True
  52. for assign_node in _get_dataclass_attributes(node):
  53. name = assign_node.target.name
  54. rhs_node = nodes.Unknown(
  55. lineno=assign_node.lineno,
  56. col_offset=assign_node.col_offset,
  57. parent=assign_node,
  58. )
  59. rhs_node = AstroidManager().visit_transforms(rhs_node)
  60. node.instance_attrs[name] = [rhs_node]
  61. if not _check_generate_dataclass_init(node):
  62. return
  63. kw_only_decorated = False
  64. if PY310_PLUS and node.decorators.nodes:
  65. for decorator in node.decorators.nodes:
  66. if not isinstance(decorator, nodes.Call):
  67. kw_only_decorated = False
  68. break
  69. for keyword in decorator.keywords:
  70. if keyword.arg == "kw_only":
  71. kw_only_decorated = keyword.value.bool_value()
  72. init_str = _generate_dataclass_init(
  73. node,
  74. list(_get_dataclass_attributes(node, init=True)),
  75. kw_only_decorated,
  76. )
  77. try:
  78. init_node = parse(init_str)["__init__"]
  79. except AstroidSyntaxError:
  80. pass
  81. else:
  82. init_node.parent = node
  83. init_node.lineno, init_node.col_offset = None, None
  84. node.locals["__init__"] = [init_node]
  85. root = node.root()
  86. if DEFAULT_FACTORY not in root.locals:
  87. new_assign = parse(f"{DEFAULT_FACTORY} = object()").body[0]
  88. new_assign.parent = root
  89. root.locals[DEFAULT_FACTORY] = [new_assign.targets[0]]
  90. def _get_dataclass_attributes(
  91. node: nodes.ClassDef, init: bool = False
  92. ) -> Iterator[nodes.AnnAssign]:
  93. """Yield the AnnAssign nodes of dataclass attributes for the node.
  94. If init is True, also include InitVars.
  95. """
  96. for assign_node in node.body:
  97. if not isinstance(assign_node, nodes.AnnAssign) or not isinstance(
  98. assign_node.target, nodes.AssignName
  99. ):
  100. continue
  101. # Annotation is never None
  102. if _is_class_var(assign_node.annotation): # type: ignore[arg-type]
  103. continue
  104. if _is_keyword_only_sentinel(assign_node.annotation):
  105. continue
  106. # Annotation is never None
  107. if not init and _is_init_var(assign_node.annotation): # type: ignore[arg-type]
  108. continue
  109. yield assign_node
  110. def _check_generate_dataclass_init(node: nodes.ClassDef) -> bool:
  111. """Return True if we should generate an __init__ method for node.
  112. This is True when:
  113. - node doesn't define its own __init__ method
  114. - the dataclass decorator was called *without* the keyword argument init=False
  115. """
  116. if "__init__" in node.locals:
  117. return False
  118. found = None
  119. for decorator_attribute in node.decorators.nodes:
  120. if not isinstance(decorator_attribute, nodes.Call):
  121. continue
  122. if _looks_like_dataclass_decorator(decorator_attribute):
  123. found = decorator_attribute
  124. if found is None:
  125. return True
  126. # Check for keyword arguments of the form init=False
  127. return not any(
  128. keyword.arg == "init"
  129. and not keyword.value.bool_value() # type: ignore[union-attr] # value is never None
  130. for keyword in found.keywords
  131. )
  132. def _find_arguments_from_base_classes(
  133. node: nodes.ClassDef,
  134. ) -> tuple[
  135. dict[str, tuple[str | None, str | None]], dict[str, tuple[str | None, str | None]]
  136. ]:
  137. """Iterate through all bases and get their typing and defaults."""
  138. pos_only_store: dict[str, tuple[str | None, str | None]] = {}
  139. kw_only_store: dict[str, tuple[str | None, str | None]] = {}
  140. # See TODO down below
  141. # all_have_defaults = True
  142. for base in reversed(node.mro()):
  143. if not base.is_dataclass:
  144. continue
  145. try:
  146. base_init: nodes.FunctionDef = base.locals["__init__"][0]
  147. except KeyError:
  148. continue
  149. pos_only, kw_only = base_init.args._get_arguments_data()
  150. for posarg, data in pos_only.items():
  151. # if data[1] is None:
  152. # if all_have_defaults and pos_only_store:
  153. # # TODO: This should return an Uninferable as this would raise
  154. # # a TypeError at runtime. However, transforms can't return
  155. # # Uninferables currently.
  156. # pass
  157. # all_have_defaults = False
  158. pos_only_store[posarg] = data
  159. for kwarg, data in kw_only.items():
  160. kw_only_store[kwarg] = data
  161. return pos_only_store, kw_only_store
  162. def _parse_arguments_into_strings(
  163. pos_only_store: dict[str, tuple[str | None, str | None]],
  164. kw_only_store: dict[str, tuple[str | None, str | None]],
  165. ) -> tuple[str, str]:
  166. """Parse positional and keyword arguments into strings for an __init__ method."""
  167. pos_only, kw_only = "", ""
  168. for pos_arg, data in pos_only_store.items():
  169. pos_only += pos_arg
  170. if data[0]:
  171. pos_only += ": " + data[0]
  172. if data[1]:
  173. pos_only += " = " + data[1]
  174. pos_only += ", "
  175. for kw_arg, data in kw_only_store.items():
  176. kw_only += kw_arg
  177. if data[0]:
  178. kw_only += ": " + data[0]
  179. if data[1]:
  180. kw_only += " = " + data[1]
  181. kw_only += ", "
  182. return pos_only, kw_only
  183. def _get_previous_field_default(node: nodes.ClassDef, name: str) -> nodes.NodeNG | None:
  184. """Get the default value of a previously defined field."""
  185. for base in reversed(node.mro()):
  186. if not base.is_dataclass:
  187. continue
  188. if name in base.locals:
  189. for assign in base.locals[name]:
  190. if (
  191. isinstance(assign.parent, nodes.AnnAssign)
  192. and assign.parent.value
  193. and isinstance(assign.parent.value, nodes.Call)
  194. and _looks_like_dataclass_field_call(assign.parent.value)
  195. ):
  196. default = _get_field_default(assign.parent.value)
  197. if default:
  198. return default[1]
  199. return None
def _generate_dataclass_init(  # pylint: disable=too-many-locals
    node: nodes.ClassDef, assigns: list[nodes.AnnAssign], kw_only_decorated: bool
) -> str:
    """Return the source code of an ``__init__`` method for a dataclass.

    :param node: the dataclass the method is generated for.
    :param assigns: the annotated assignments to turn into parameters.
    :param kw_only_decorated: when true, fields that do not carry their own
        ``kw_only`` keyword are emitted as keyword-only parameters.
    :returns: the text of a ``def __init__(...) -> None:`` function whose body
        is the collected assignments, or ``pass`` when there are none.
    """
    params: list[str] = []
    kw_only_params: list[str] = []
    assignments: list[str] = []

    # Arguments taken from the dataclass bases, keyed by name. Entries are
    # overwritten or removed below when this class declares the same name.
    prev_pos_only_store, prev_kw_only_store = _find_arguments_from_base_classes(node)

    for assign in assigns:
        name, annotation, value = assign.target.name, assign.annotation, assign.value

        # Check whether this assign is overridden by a property assignment
        property_node: nodes.FunctionDef | None = None
        for additional_assign in node.locals[name]:
            if not isinstance(additional_assign, nodes.FunctionDef):
                continue
            if not additional_assign.decorators:
                continue
            if "builtins.property" in additional_assign.decoratornames():
                property_node = additional_assign
                break

        is_field = isinstance(value, nodes.Call) and _looks_like_dataclass_field_call(
            value, check_scope=False
        )

        if is_field:
            # Skip any fields that have `init=False`
            if any(
                keyword.arg == "init" and not keyword.value.bool_value()
                for keyword in value.keywords  # type: ignore[union-attr] # value is never None
            ):
                # Also remove the name from the previous arguments to be inserted later
                prev_pos_only_store.pop(name, None)
                prev_kw_only_store.pop(name, None)
                continue

        if _is_init_var(annotation):  # type: ignore[arg-type] # annotation is never None
            init_var = True
            if isinstance(annotation, nodes.Subscript):
                # InitVar[T]: use T as the annotation of the parameter.
                annotation = annotation.slice
            else:
                # Cannot determine type annotation for parameter from InitVar
                annotation = None
            assignment_str = ""
        else:
            init_var = False
            assignment_str = f"self.{name} = {name}"

        ann_str, default_str = None, None
        if annotation is not None:
            ann_str = annotation.as_string()

        if value:
            if is_field:
                result = _get_field_default(value)  # type: ignore[arg-type]
                if result:
                    default_type, default_node = result
                    if default_type == "default":
                        default_str = default_node.as_string()
                    elif default_type == "default_factory":
                        # The parameter defaults to the DEFAULT_FACTORY name;
                        # the body calls the factory when that name is received.
                        default_str = DEFAULT_FACTORY
                        assignment_str = (
                            f"self.{name} = {default_node.as_string()} "
                            f"if {name} is {DEFAULT_FACTORY} else {name}"
                        )
            else:
                default_str = value.as_string()
        elif property_node:
            # We set the result of the property call as default
            # This hides the fact that this would normally be a 'property object'
            # But we can't represent those as string
            try:
                # Call str to make sure also Uninferable gets stringified
                default_str = str(next(property_node.infer_call_result()).as_string())
            except (InferenceError, StopIteration):
                pass
        else:
            # Even with `init=False` the default value still can be propagated to
            # later assignments. Creating weird signatures like:
            # (self, a: str = 1) -> None
            previous_default = _get_previous_field_default(node, name)
            if previous_default:
                default_str = previous_default.as_string()

        # Construct the param string to add to the init if necessary
        param_str = name
        if ann_str is not None:
            param_str += f": {ann_str}"
        if default_str is not None:
            param_str += f" = {default_str}"

        # If the field is a kw_only field, we need to add it to the kw_only_params
        # This overwrites whether or not the class is kw_only decorated
        if is_field:
            kw_only = [k for k in value.keywords if k.arg == "kw_only"]  # type: ignore[union-attr]
            if kw_only:
                if kw_only[0].value.bool_value():
                    kw_only_params.append(param_str)
                else:
                    params.append(param_str)
                # NOTE(review): this `continue` also skips the
                # `assignments.append(...)` at the end of the loop body, so a
                # field with an explicit `kw_only=` keyword gets a parameter
                # but no assignment line - confirm this is intended.
                continue

        # If kw_only decorated, we need to add all parameters to the kw_only_params
        if kw_only_decorated:
            if name in prev_kw_only_store:
                prev_kw_only_store[name] = (ann_str, default_str)
            else:
                kw_only_params.append(param_str)
        else:
            # If the name was previously seen, overwrite that data
            # pylint: disable-next=else-if-used
            if name in prev_pos_only_store:
                prev_pos_only_store[name] = (ann_str, default_str)
            elif name in prev_kw_only_store:
                # NOTE(review): only the bare name is prepended here, without
                # its annotation or default (`param_str`) - confirm intended.
                params = [name] + params
                prev_kw_only_store.pop(name)
            else:
                params.append(param_str)

        if not init_var:
            assignments.append(assignment_str)

    prev_pos_only, prev_kw_only = _parse_arguments_into_strings(
        prev_pos_only_store, prev_kw_only_store
    )

    # Construct the new init method parameter string
    # First we do the positional only parameters, making sure to add
    # the self parameter and the comma to allow adding keyword only parameters
    # NOTE(review): this is a substring test on the rendered parameter text,
    # not a lookup of the "self" key in prev_pos_only_store.
    params_string = "" if "self" in prev_pos_only else "self, "
    params_string += prev_pos_only + ", ".join(params)
    if not params_string.endswith(", "):
        params_string += ", "

    # Then we add the keyword only parameters
    if prev_kw_only or kw_only_params:
        params_string += "*, "
    params_string += f"{prev_kw_only}{', '.join(kw_only_params)}"

    assignments_string = "\n ".join(assignments) if assignments else "pass"
    return f"def __init__({params_string}) -> None:\n {assignments_string}"
  328. def infer_dataclass_attribute(
  329. node: nodes.Unknown, ctx: context.InferenceContext | None = None
  330. ) -> Iterator[InferenceResult]:
  331. """Inference tip for an Unknown node that was dynamically generated to
  332. represent a dataclass attribute.
  333. In the case that a default value is provided, that is inferred first.
  334. Then, an Instance of the annotated class is yielded.
  335. """
  336. assign = node.parent
  337. if not isinstance(assign, nodes.AnnAssign):
  338. yield Uninferable
  339. return
  340. annotation, value = assign.annotation, assign.value
  341. if value is not None:
  342. yield from value.infer(context=ctx)
  343. if annotation is not None:
  344. yield from _infer_instance_from_annotation(annotation, ctx=ctx)
  345. else:
  346. yield Uninferable
  347. def infer_dataclass_field_call(
  348. node: nodes.Call, ctx: context.InferenceContext | None = None
  349. ) -> Iterator[InferenceResult]:
  350. """Inference tip for dataclass field calls."""
  351. if not isinstance(node.parent, (nodes.AnnAssign, nodes.Assign)):
  352. raise UseInferenceDefault
  353. result = _get_field_default(node)
  354. if not result:
  355. yield Uninferable
  356. else:
  357. default_type, default = result
  358. if default_type == "default":
  359. yield from default.infer(context=ctx)
  360. else:
  361. new_call = parse(default.as_string()).body[0].value
  362. new_call.parent = node.parent
  363. yield from new_call.infer(context=ctx)
  364. def _looks_like_dataclass_decorator(
  365. node: nodes.NodeNG, decorator_names: frozenset[str] = DATACLASSES_DECORATORS
  366. ) -> bool:
  367. """Return True if node looks like a dataclass decorator.
  368. Uses inference to lookup the value of the node, and if that fails,
  369. matches against specific names.
  370. """
  371. if isinstance(node, nodes.Call): # decorator with arguments
  372. node = node.func
  373. try:
  374. inferred = next(node.infer())
  375. except (InferenceError, StopIteration):
  376. inferred = Uninferable
  377. if isinstance(inferred, UninferableBase):
  378. if isinstance(node, nodes.Name):
  379. return node.name in decorator_names
  380. if isinstance(node, nodes.Attribute):
  381. return node.attrname in decorator_names
  382. return False
  383. return (
  384. isinstance(inferred, nodes.FunctionDef)
  385. and inferred.name in decorator_names
  386. and inferred.root().name in DATACLASS_MODULES
  387. )
  388. def _looks_like_dataclass_attribute(node: nodes.Unknown) -> bool:
  389. """Return True if node was dynamically generated as the child of an AnnAssign
  390. statement.
  391. """
  392. parent = node.parent
  393. if not parent:
  394. return False
  395. scope = parent.scope()
  396. return (
  397. isinstance(parent, nodes.AnnAssign)
  398. and isinstance(scope, nodes.ClassDef)
  399. and is_decorated_with_dataclass(scope)
  400. )
  401. def _looks_like_dataclass_field_call(
  402. node: nodes.Call, check_scope: bool = True
  403. ) -> bool:
  404. """Return True if node is calling dataclasses field or Field
  405. from an AnnAssign statement directly in the body of a ClassDef.
  406. If check_scope is False, skips checking the statement and body.
  407. """
  408. if check_scope:
  409. stmt = node.statement(future=True)
  410. scope = stmt.scope()
  411. if not (
  412. isinstance(stmt, nodes.AnnAssign)
  413. and stmt.value is not None
  414. and isinstance(scope, nodes.ClassDef)
  415. and is_decorated_with_dataclass(scope)
  416. ):
  417. return False
  418. try:
  419. inferred = next(node.func.infer())
  420. except (InferenceError, StopIteration):
  421. return False
  422. if not isinstance(inferred, nodes.FunctionDef):
  423. return False
  424. return inferred.name == FIELD_NAME and inferred.root().name in DATACLASS_MODULES
  425. def _get_field_default(field_call: nodes.Call) -> _FieldDefaultReturn:
  426. """Return a the default value of a field call, and the corresponding keyword
  427. argument name.
  428. field(default=...) results in the ... node
  429. field(default_factory=...) results in a Call node with func ... and no arguments
  430. If neither or both arguments are present, return ("", None) instead,
  431. indicating that there is not a valid default value.
  432. """
  433. default, default_factory = None, None
  434. for keyword in field_call.keywords:
  435. if keyword.arg == "default":
  436. default = keyword.value
  437. elif keyword.arg == "default_factory":
  438. default_factory = keyword.value
  439. if default is not None and default_factory is None:
  440. return "default", default
  441. if default is None and default_factory is not None:
  442. new_call = nodes.Call(
  443. lineno=field_call.lineno,
  444. col_offset=field_call.col_offset,
  445. parent=field_call.parent,
  446. )
  447. new_call.postinit(func=default_factory)
  448. return "default_factory", new_call
  449. return None
  450. def _is_class_var(node: nodes.NodeNG) -> bool:
  451. """Return True if node is a ClassVar, with or without subscripting."""
  452. if PY39_PLUS:
  453. try:
  454. inferred = next(node.infer())
  455. except (InferenceError, StopIteration):
  456. return False
  457. return getattr(inferred, "name", "") == "ClassVar"
  458. # Before Python 3.9, inference returns typing._SpecialForm instead of ClassVar.
  459. # Our backup is to inspect the node's structure.
  460. return isinstance(node, nodes.Subscript) and (
  461. isinstance(node.value, nodes.Name)
  462. and node.value.name == "ClassVar"
  463. or isinstance(node.value, nodes.Attribute)
  464. and node.value.attrname == "ClassVar"
  465. )
  466. def _is_keyword_only_sentinel(node: nodes.NodeNG) -> bool:
  467. """Return True if node is the KW_ONLY sentinel."""
  468. if not PY310_PLUS:
  469. return False
  470. inferred = helpers.safe_infer(node)
  471. return (
  472. isinstance(inferred, bases.Instance)
  473. and inferred.qname() == "dataclasses._KW_ONLY_TYPE"
  474. )
  475. def _is_init_var(node: nodes.NodeNG) -> bool:
  476. """Return True if node is an InitVar, with or without subscripting."""
  477. try:
  478. inferred = next(node.infer())
  479. except (InferenceError, StopIteration):
  480. return False
  481. return getattr(inferred, "name", "") == "InitVar"
# Allowed typing classes for which we support inferring instances.
# Consulted by `_infer_instance_from_annotation` for classes whose root module
# is "typing", "_collections_abc" or "".
_INFERABLE_TYPING_TYPES = frozenset(
    (
        "Dict",
        "FrozenSet",
        "List",
        "Set",
        "Tuple",
    )
)
  492. def _infer_instance_from_annotation(
  493. node: nodes.NodeNG, ctx: context.InferenceContext | None = None
  494. ) -> Iterator[UninferableBase | bases.Instance]:
  495. """Infer an instance corresponding to the type annotation represented by node.
  496. Currently has limited support for the typing module.
  497. """
  498. klass = None
  499. try:
  500. klass = next(node.infer(context=ctx))
  501. except (InferenceError, StopIteration):
  502. yield Uninferable
  503. if not isinstance(klass, nodes.ClassDef):
  504. yield Uninferable
  505. elif klass.root().name in {
  506. "typing",
  507. "_collections_abc",
  508. "",
  509. }: # "" because of synthetic nodes in brain_typing.py
  510. if klass.name in _INFERABLE_TYPING_TYPES:
  511. yield klass.instantiate_class()
  512. else:
  513. yield Uninferable
  514. else:
  515. yield klass.instantiate_class()
# Module-level registration: importing this module installs the hooks below.

# Rewrite every class for which `is_decorated_with_dataclass` holds.
AstroidManager().register_transform(
    nodes.ClassDef, dataclass_transform, is_decorated_with_dataclass
)
# Attach `infer_dataclass_field_call` as the inference tip of calls accepted
# by `_looks_like_dataclass_field_call`.
AstroidManager().register_transform(
    nodes.Call,
    inference_tip(infer_dataclass_field_call, raise_on_overwrite=True),
    _looks_like_dataclass_field_call,
)
# Attach `infer_dataclass_attribute` as the inference tip of the Unknown
# placeholders accepted by `_looks_like_dataclass_attribute`.
AstroidManager().register_transform(
    nodes.Unknown,
    inference_tip(infer_dataclass_attribute, raise_on_overwrite=True),
    _looks_like_dataclass_attribute,
)