bad_chained_comparison.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. # Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
  2. # For details: https://github.com/PyCQA/pylint/blob/main/LICENSE
  3. # Copyright (c) https://github.com/PyCQA/pylint/blob/main/CONTRIBUTORS.txt
  4. from __future__ import annotations
  5. from typing import TYPE_CHECKING
  6. from astroid import nodes
  7. from pylint.checkers import BaseChecker
  8. from pylint.interfaces import HIGH
  9. if TYPE_CHECKING:
  10. from pylint.lint import PyLinter
  11. COMPARISON_OP = frozenset(("<", "<=", ">", ">=", "!=", "=="))
  12. IDENTITY_OP = frozenset(("is", "is not"))
  13. MEMBERSHIP_OP = frozenset(("in", "not in"))
  14. class BadChainedComparisonChecker(BaseChecker):
  15. """Checks for unintentional usage of chained comparison."""
  16. name = "bad-chained-comparison"
  17. msgs = {
  18. "W3601": (
  19. "Suspicious %s-part chained comparison using semantically incompatible operators (%s)",
  20. "bad-chained-comparison",
  21. "Used when there is a chained comparison where one expression is part "
  22. "of two comparisons that belong to different semantic groups "
  23. '("<" does not mean the same thing as "is", chaining them in '
  24. '"0 < x is None" is probably a mistake).',
  25. )
  26. }
  27. def _has_diff_semantic_groups(self, operators: list[str]) -> bool:
  28. # Check if comparison operators are in the same semantic group
  29. for semantic_group in (COMPARISON_OP, IDENTITY_OP, MEMBERSHIP_OP):
  30. if operators[0] in semantic_group:
  31. group = semantic_group
  32. return not all(o in group for o in operators)
  33. def visit_compare(self, node: nodes.Compare) -> None:
  34. operators = sorted({op[0] for op in node.ops})
  35. if self._has_diff_semantic_groups(operators):
  36. num_parts = f"{len(node.ops)}"
  37. incompatibles = (
  38. ", ".join(f"'{o}'" for o in operators[:-1]) + f" and '{operators[-1]}'"
  39. )
  40. self.add_message(
  41. "bad-chained-comparison",
  42. node=node,
  43. args=(num_parts, incompatibles),
  44. confidence=HIGH,
  45. )
  46. def register(linter: PyLinter) -> None:
  47. linter.register_checker(BadChainedComparisonChecker(linter))