constraint.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  2. # For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
  3. # Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt
  4. """Classes representing different types of constraints on inference values."""
  5. from __future__ import annotations
  6. import sys
  7. from abc import ABC, abstractmethod
  8. from collections.abc import Iterator
  9. from typing import Union
  10. from astroid import bases, nodes, util
  11. from astroid.typing import InferenceResult
  12. if sys.version_info >= (3, 11):
  13. from typing import Self
  14. else:
  15. from typing_extensions import Self
  # Node types accepted as the constrained expression by ``Constraint.match``
  # and ``get_constraints``: plain names and attribute accesses, in both their
  # ordinary and ``Assign*`` forms. This alias is a runtime expression, so it
  # uses ``Union`` (``|`` between classes requires Python 3.10+).
  _NameNodes = Union[
      nodes.AssignAttr,
      nodes.Attribute,
      nodes.AssignName,
      nodes.Name,
  ]
  class Constraint(ABC):
      """Represents a single constraint on a variable.

      Concrete subclasses build themselves from an expression via ``match`` and
      decide via ``satisfied_by`` whether an inferred value is compatible.
      """

      def __init__(self, node: nodes.NodeNG, negate: bool) -> None:
          self.node = node
          """The node that this constraint applies to."""
          self.negate = negate
          """True if this constraint is negated. E.g., "is not" instead of "is"."""

      def __repr__(self) -> str:
          # The default object repr only shows an address, which makes the
          # constraint sets built by get_constraints unreadable when debugging.
          return f"{type(self).__name__}(node={self.node!r}, negate={self.negate!r})"

      @classmethod
      @abstractmethod
      def match(
          cls: type[Self], node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
      ) -> Self | None:
          """Return a new constraint for node matched from expr, if expr matches
          the constraint pattern.

          If negate is True, negate the constraint. Return None when expr does
          not match the pattern.
          """

      @abstractmethod
      def satisfied_by(self, inferred: InferenceResult) -> bool:
          """Return True if this constraint is satisfied by the given inferred value."""
  class NoneConstraint(Constraint):
      """Represents an "is None" or "is not None" constraint."""

      # Shared ``None`` constant node that comparison operands and inferred
      # values are checked against through ``_matches``.
      CONST_NONE: nodes.Const = nodes.Const(None)

      @classmethod
      def match(
          cls: type[Self], node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
      ) -> Self | None:
          """Return a new constraint for node matched from expr, if expr matches
          the constraint pattern.

          Recognized patterns are single (non-chained) identity comparisons
          between node and ``None``, in either operand order: ``node is None``,
          ``None is node`` and the corresponding ``is not`` forms. An ``is not``
          comparison yields a negated constraint, and negate=True flips the
          result once more (get_constraints passes True for the ``else`` branch
          of an ``if``). Return None when expr does not match.
          """
          if not isinstance(expr, nodes.Compare) or len(expr.ops) != 1:
              return None
          left = expr.left
          op, right = expr.ops[0]
          if op not in {"is", "is not"}:
              return None
          # Identity is symmetric, so ``None is x`` constrains x exactly like
          # ``x is None`` does.
          matched = (_matches(left, node) and _matches(right, cls.CONST_NONE)) or (
              _matches(left, cls.CONST_NONE) and _matches(right, node)
          )
          if not matched:
              return None
          # "is not" flips the constraint; an inverted context flips it again.
          return cls(node=node, negate=negate ^ (op == "is not"))

      def satisfied_by(self, inferred: InferenceResult) -> bool:
          """Return True if this constraint is satisfied by the given inferred value.

          Uninferable values are assumed to satisfy the constraint. Otherwise an
          "is None" constraint holds only for a ``None`` constant, and an
          "is not None" constraint holds for anything else.
          """
          # Assume true if uninferable
          if isinstance(inferred, util.UninferableBase):
              return True

          # XOR: a plain constraint needs None, a negated one needs non-None.
          return self.negate ^ _matches(inferred, self.CONST_NONE)
  def get_constraints(
      expr: _NameNodes, frame: nodes.LocalsDictNodeNG
  ) -> dict[nodes.If, set[Constraint]]:
      """Collect the constraints on expr imposed by enclosing ``if`` statements.

      Starting at expr, walk up the parent chain until frame is reached (or a
      node without a parent). For every ``If`` ancestor, the branch holding the
      current node decides how its test is read: in ``body`` the test is used as
      written, in ``orelse`` it is inverted, and any other branch contributes
      nothing.

      Returns a dict mapping each such ``If`` node to the non-empty set of
      constraints derived from its test. Constraints are computed statically;
      only ``if`` conditions are currently supported.
      """
      result: dict[nodes.If, set[Constraint]] = {}
      child: nodes.NodeNG | None = expr
      while child is not None and child is not frame:
          ancestor = child.parent
          if isinstance(ancestor, nodes.If):
              branch, _ = ancestor.locate_child(child)
              # Only the body/orelse branches are governed by the test.
              if branch in {"body", "orelse"}:
                  in_else = branch == "orelse"
                  found = set(_match_constraint(expr, ancestor.test, invert=in_else))
                  if found:
                      result[ancestor] = found
          child = ancestor
      return result
  #: All supported constraint types. ``_match_constraint`` tries each of them
  #: against an ``if`` test; new Constraint subclasses must be listed here.
  ALL_CONSTRAINT_CLASSES = frozenset({NoneConstraint})
  def _matches(node1: nodes.NodeNG | bases.Proxy, node2: nodes.NodeNG) -> bool:
      """Returns True if the two nodes match.

      Two nodes match when they are of the same kind and:

      * ``Name``: they use the same identifier;
      * ``Attribute``: the attribute names agree and their owner expressions
        match recursively (so ``a.b.c`` matches ``a.b.c``);
      * ``Const``: their values compare equal with ``==``.

      Every other combination, including mixed kinds, does not match.
      """
      # NOTE: because Const values are compared with ``==``, Const(1) matches
      # Const(True); this is harmless for the ``None`` checks made here.
      if isinstance(node1, nodes.Name) and isinstance(node2, nodes.Name):
          outcome = node1.name == node2.name
      elif isinstance(node1, nodes.Attribute) and isinstance(node2, nodes.Attribute):
          outcome = node1.attrname == node2.attrname and _matches(
              node1.expr, node2.expr
          )
      elif isinstance(node1, nodes.Const) and isinstance(node2, nodes.Const):
          outcome = node1.value == node2.value
      else:
          outcome = False
      return outcome
  def _match_constraint(
      node: _NameNodes, expr: nodes.NodeNG, invert: bool = False
  ) -> Iterator[Constraint]:
      """Yield each constraint on node that a supported constraint class finds
      in expr.

      Every class in ALL_CONSTRAINT_CLASSES is asked to ``match``. ``invert`` is
      passed through as its ``negate`` argument, and falsy results (no match)
      are skipped. The classes are visited in frozenset order, which is
      unspecified.
      """
      candidates = (
          constraint_cls.match(node, expr, invert)
          for constraint_cls in ALL_CONSTRAINT_CLASSES
      )
      yield from (candidate for candidate in candidates if candidate)