inference_tip.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  2. # For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
  3. # Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt
  4. """Transform utilities (filters and decorator)."""
  5. from __future__ import annotations
  6. import typing
  7. from collections.abc import Iterator
  8. import wrapt
  9. from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault
  10. from astroid.nodes import NodeNG
  11. from astroid.typing import InferenceResult, InferFn
  12. _cache: dict[tuple[InferFn, NodeNG], list[InferenceResult] | None] = {}
  13. def clear_inference_tip_cache() -> None:
  14. """Clear the inference tips cache."""
  15. _cache.clear()
  16. @wrapt.decorator
  17. def _inference_tip_cached(
  18. func: InferFn, instance: None, args: typing.Any, kwargs: typing.Any
  19. ) -> Iterator[InferenceResult]:
  20. """Cache decorator used for inference tips."""
  21. node = args[0]
  22. try:
  23. result = _cache[func, node]
  24. # If through recursion we end up trying to infer the same
  25. # func + node we raise here.
  26. if result is None:
  27. raise UseInferenceDefault()
  28. except KeyError:
  29. _cache[func, node] = None
  30. result = _cache[func, node] = list(func(*args, **kwargs))
  31. assert result
  32. return iter(result)
  33. def inference_tip(infer_function: InferFn, raise_on_overwrite: bool = False) -> InferFn:
  34. """Given an instance specific inference function, return a function to be
  35. given to AstroidManager().register_transform to set this inference function.
  36. :param bool raise_on_overwrite: Raise an `InferenceOverwriteError`
  37. if the inference tip will overwrite another. Used for debugging
  38. Typical usage
  39. .. sourcecode:: python
  40. AstroidManager().register_transform(Call, inference_tip(infer_named_tuple),
  41. predicate)
  42. .. Note::
  43. Using an inference tip will override
  44. any previously set inference tip for the given
  45. node. Use a predicate in the transform to prevent
  46. excess overwrites.
  47. """
  48. def transform(node: NodeNG, infer_function: InferFn = infer_function) -> NodeNG:
  49. if (
  50. raise_on_overwrite
  51. and node._explicit_inference is not None
  52. and node._explicit_inference is not infer_function
  53. ):
  54. raise InferenceOverwriteError(
  55. "Inference already set to {existing_inference}. "
  56. "Trying to overwrite with {new_inference} for {node}".format(
  57. existing_inference=infer_function,
  58. new_inference=node._explicit_inference,
  59. node=node,
  60. )
  61. )
  62. # pylint: disable=no-value-for-parameter
  63. node._explicit_inference = _inference_tip_cached(infer_function)
  64. return node
  65. return transform