transforms.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  2. # For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
  3. # Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt
  4. from __future__ import annotations
  5. import collections
  6. from typing import TYPE_CHECKING
  7. from astroid.context import _invalidate_cache
  8. if TYPE_CHECKING:
  9. from astroid import NodeNG
  10. class TransformVisitor:
  11. """A visitor for handling transforms.
  12. The standard approach of using it is to call
  13. :meth:`~visit` with an *astroid* module and the class
  14. will take care of the rest, walking the tree and running the
  15. transforms for each encountered node.
  16. Based on its usage in AstroidManager.brain, it should not be reinstantiated.
  17. """
  18. def __init__(self):
  19. self.transforms = collections.defaultdict(list)
  20. def _transform(self, node: NodeNG) -> NodeNG:
  21. """Call matching transforms for the given node if any and return the
  22. transformed node.
  23. """
  24. cls = node.__class__
  25. transforms = self.transforms[cls]
  26. for transform_func, predicate in transforms:
  27. if predicate is None or predicate(node):
  28. ret = transform_func(node)
  29. # if the transformation function returns something, it's
  30. # expected to be a replacement for the node
  31. if ret is not None:
  32. _invalidate_cache()
  33. node = ret
  34. if ret.__class__ != cls:
  35. # Can no longer apply the rest of the transforms.
  36. break
  37. return node
  38. def _visit(self, node):
  39. if hasattr(node, "_astroid_fields"):
  40. for name in node._astroid_fields:
  41. value = getattr(node, name)
  42. visited = self._visit_generic(value)
  43. if visited != value:
  44. setattr(node, name, visited)
  45. return self._transform(node)
  46. def _visit_generic(self, node):
  47. if isinstance(node, list):
  48. return [self._visit_generic(child) for child in node]
  49. if isinstance(node, tuple):
  50. return tuple(self._visit_generic(child) for child in node)
  51. if not node or isinstance(node, str):
  52. return node
  53. return self._visit(node)
  54. def register_transform(self, node_class, transform, predicate=None) -> None:
  55. """Register `transform(node)` function to be applied on the given
  56. astroid's `node_class` if `predicate` is None or returns true
  57. when called with the node as argument.
  58. The transform function may return a value which is then used to
  59. substitute the original node in the tree.
  60. """
  61. self.transforms[node_class].append((transform, predicate))
  62. def unregister_transform(self, node_class, transform, predicate=None) -> None:
  63. """Unregister the given transform."""
  64. self.transforms[node_class].remove((transform, predicate))
  65. def visit(self, module):
  66. """Walk the given astroid *tree* and transform each encountered node.
  67. Only the nodes which have transforms registered will actually
  68. be replaced or changed.
  69. """
  70. return self._visit(module)