ast_walker.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. # Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
  2. # For details: https://github.com/pylint-dev/pylint/blob/main/LICENSE
  3. # Copyright (c) https://github.com/pylint-dev/pylint/blob/main/CONTRIBUTORS.txt
  4. from __future__ import annotations
  5. import sys
  6. import traceback
  7. from collections import defaultdict
  8. from collections.abc import Sequence
  9. from typing import TYPE_CHECKING, Callable
  10. from astroid import nodes
  11. if TYPE_CHECKING:
  12. from pylint.checkers.base_checker import BaseChecker
  13. from pylint.lint import PyLinter
# Type of the ``visit_*`` / ``leave_*`` callbacks that ASTWalker registers:
# each is called with a single astroid node and its return value is ignored.
#
# Note: the parameter type NodeNG is not completely correct.
# Because Callable parameter types are contravariant, it should
# strictly be a Union of all NodeNG subclasses. However, the methods
# are only retrieved with getattr(checker, member) and are therefore
# inferred as Any, so NodeNG works in practice.
AstCallback = Callable[[nodes.NodeNG], None]
  21. class ASTWalker:
  22. def __init__(self, linter: PyLinter) -> None:
  23. # callbacks per node types
  24. self.nbstatements = 0
  25. self.visit_events: defaultdict[str, list[AstCallback]] = defaultdict(list)
  26. self.leave_events: defaultdict[str, list[AstCallback]] = defaultdict(list)
  27. self.linter = linter
  28. self.exception_msg = False
  29. def _is_method_enabled(self, method: AstCallback) -> bool:
  30. if not hasattr(method, "checks_msgs"):
  31. return True
  32. return any(self.linter.is_message_enabled(m) for m in method.checks_msgs)
  33. def add_checker(self, checker: BaseChecker) -> None:
  34. """Walk to the checker's dir and collect visit and leave methods."""
  35. vcids: set[str] = set()
  36. lcids: set[str] = set()
  37. visits = self.visit_events
  38. leaves = self.leave_events
  39. for member in dir(checker):
  40. cid = member[6:]
  41. if cid == "default":
  42. continue
  43. if member.startswith("visit_"):
  44. v_meth = getattr(checker, member)
  45. # don't use visit_methods with no activated message:
  46. if self._is_method_enabled(v_meth):
  47. visits[cid].append(v_meth)
  48. vcids.add(cid)
  49. elif member.startswith("leave_"):
  50. l_meth = getattr(checker, member)
  51. # don't use leave_methods with no activated message:
  52. if self._is_method_enabled(l_meth):
  53. leaves[cid].append(l_meth)
  54. lcids.add(cid)
  55. visit_default = getattr(checker, "visit_default", None)
  56. if visit_default:
  57. for cls in nodes.ALL_NODE_CLASSES:
  58. cid = cls.__name__.lower()
  59. if cid not in vcids:
  60. visits[cid].append(visit_default)
  61. # For now, we have no "leave_default" method in Pylint
  62. def walk(self, astroid: nodes.NodeNG) -> None:
  63. """Call visit events of astroid checkers for the given node, recurse on
  64. its children, then leave events.
  65. """
  66. cid = astroid.__class__.__name__.lower()
  67. # Detect if the node is a new name for a deprecated alias.
  68. # In this case, favour the methods for the deprecated
  69. # alias if any, in order to maintain backwards
  70. # compatibility.
  71. visit_events: Sequence[AstCallback] = self.visit_events.get(cid, ())
  72. leave_events: Sequence[AstCallback] = self.leave_events.get(cid, ())
  73. # pylint: disable = too-many-try-statements
  74. try:
  75. if astroid.is_statement:
  76. self.nbstatements += 1
  77. # generate events for this node on each checker
  78. for callback in visit_events:
  79. callback(astroid)
  80. # recurse on children
  81. for child in astroid.get_children():
  82. self.walk(child)
  83. for callback in leave_events:
  84. callback(astroid)
  85. except Exception:
  86. if self.exception_msg is False:
  87. file = getattr(astroid.root(), "file", None)
  88. print(
  89. f"Exception on node {repr(astroid)} in file '{file}'",
  90. file=sys.stderr,
  91. )
  92. traceback.print_exc()
  93. self.exception_msg = True
  94. raise