utils.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. # Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
  2. # For details: https://github.com/pylint-dev/pylint/blob/main/LICENSE
  3. # Copyright (c) https://github.com/pylint-dev/pylint/blob/main/CONTRIBUTORS.txt
  4. """Generic classes/functions for pyreverse core/extensions."""
  5. from __future__ import annotations
  6. import os
  7. import re
  8. import shutil
  9. import subprocess
  10. import sys
  11. from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
  12. import astroid
  13. from astroid import nodes
  14. from astroid.typing import InferenceResult
if TYPE_CHECKING:
    # Everything in this branch exists for static type checkers only; the
    # aliases live here because they reference names imported only here.
    from pylint.pyreverse.diagrams import ClassDiagram, PackageDiagram

    # A visitor callback: takes a single node and returns either a tuple of
    # diagrams (class diagram alone, or package + class diagram) or None.
    _CallbackT = Callable[
        [nodes.NodeNG],
        Union[Tuple[ClassDiagram], Tuple[PackageDiagram, ClassDiagram], None],
    ]
    # Pair of optional callbacks; used below as the value type of
    # LocalsVisitor._cache and the return type of LocalsVisitor.get_callbacks.
    _CallbackTupleT = Tuple[Optional[_CallbackT], Optional[_CallbackT]]
  22. RCFILE = ".pyreverserc"
  23. def get_default_options() -> list[str]:
  24. """Read config file and return list of options."""
  25. options = []
  26. home = os.environ.get("HOME", "")
  27. if home:
  28. rcfile = os.path.join(home, RCFILE)
  29. try:
  30. with open(rcfile, encoding="utf-8") as file_handle:
  31. options = file_handle.read().split()
  32. except OSError:
  33. pass # ignore if no config file found
  34. return options
  35. def insert_default_options() -> None:
  36. """Insert default options to sys.argv."""
  37. options = get_default_options()
  38. options.reverse()
  39. for arg in options:
  40. sys.argv.insert(1, arg)
  41. # astroid utilities ###########################################################
  42. SPECIAL = re.compile(r"^__([^\W_]_*)+__$")
  43. PRIVATE = re.compile(r"^__(_*[^\W_])+_?$")
  44. PROTECTED = re.compile(r"^_\w*$")
  45. def get_visibility(name: str) -> str:
  46. """Return the visibility from a name: public, protected, private or special."""
  47. if SPECIAL.match(name):
  48. visibility = "special"
  49. elif PRIVATE.match(name):
  50. visibility = "private"
  51. elif PROTECTED.match(name):
  52. visibility = "protected"
  53. else:
  54. visibility = "public"
  55. return visibility
  56. def is_interface(node: nodes.ClassDef) -> bool:
  57. # bw compatibility
  58. return node.type == "interface" # type: ignore[no-any-return]
  59. def is_exception(node: nodes.ClassDef) -> bool:
  60. # bw compatibility
  61. return node.type == "exception" # type: ignore[no-any-return]
  62. # Helpers #####################################################################
  63. _SPECIAL = 2
  64. _PROTECTED = 4
  65. _PRIVATE = 8
  66. MODES = {
  67. "ALL": 0,
  68. "PUB_ONLY": _SPECIAL + _PROTECTED + _PRIVATE,
  69. "SPECIAL": _SPECIAL,
  70. "OTHER": _PROTECTED + _PRIVATE,
  71. }
  72. VIS_MOD = {
  73. "special": _SPECIAL,
  74. "protected": _PROTECTED,
  75. "private": _PRIVATE,
  76. "public": 0,
  77. }
  78. class FilterMixIn:
  79. """Filter nodes according to a mode and nodes' visibility."""
  80. def __init__(self, mode: str) -> None:
  81. """Init filter modes."""
  82. __mode = 0
  83. for nummod in mode.split("+"):
  84. try:
  85. __mode += MODES[nummod]
  86. except KeyError as ex:
  87. print(f"Unknown filter mode {ex}", file=sys.stderr)
  88. self.__mode = __mode
  89. def show_attr(self, node: nodes.NodeNG | str) -> bool:
  90. """Return true if the node should be treated."""
  91. visibility = get_visibility(getattr(node, "name", node))
  92. return not self.__mode & VIS_MOD[visibility]
  93. class LocalsVisitor:
  94. """Visit a project by traversing the locals dictionary.
  95. * visit_<class name> on entering a node, where class name is the class of
  96. the node in lower case
  97. * leave_<class name> on leaving a node, where class name is the class of
  98. the node in lower case
  99. """
  100. def __init__(self) -> None:
  101. self._cache: dict[type[nodes.NodeNG], _CallbackTupleT] = {}
  102. self._visited: set[nodes.NodeNG] = set()
  103. def get_callbacks(self, node: nodes.NodeNG) -> _CallbackTupleT:
  104. """Get callbacks from handler for the visited node."""
  105. klass = node.__class__
  106. methods = self._cache.get(klass)
  107. if methods is None:
  108. kid = klass.__name__.lower()
  109. e_method = getattr(
  110. self, f"visit_{kid}", getattr(self, "visit_default", None)
  111. )
  112. l_method = getattr(
  113. self, f"leave_{kid}", getattr(self, "leave_default", None)
  114. )
  115. self._cache[klass] = (e_method, l_method)
  116. else:
  117. e_method, l_method = methods
  118. return e_method, l_method
  119. def visit(self, node: nodes.NodeNG) -> Any:
  120. """Launch the visit starting from the given node."""
  121. if node in self._visited:
  122. return None
  123. self._visited.add(node)
  124. methods = self.get_callbacks(node)
  125. if methods[0] is not None:
  126. methods[0](node)
  127. if hasattr(node, "locals"): # skip Instance and other proxy
  128. for local_node in node.values():
  129. self.visit(local_node)
  130. if methods[1] is not None:
  131. return methods[1](node)
  132. return None
  133. def get_annotation_label(ann: nodes.Name | nodes.NodeNG) -> str:
  134. if isinstance(ann, nodes.Name) and ann.name is not None:
  135. return ann.name # type: ignore[no-any-return]
  136. if isinstance(ann, nodes.NodeNG):
  137. return ann.as_string() # type: ignore[no-any-return]
  138. return ""
def get_annotation(
    node: nodes.AssignAttr | nodes.AssignName,
) -> nodes.Name | nodes.Subscript | None:
    """Return the annotation for `node`.

    The annotation is taken from the parent ``AnnAssign`` when there is one.
    Otherwise, for an ``AssignAttr``, it is looked up among the argument
    annotations of ``node.parent.parent`` using the name of the assigned
    value. For any other node ``None`` is returned immediately.

    When an annotation is found and the first inferred value of `node` has
    ``value`` equal to ``None``, the label is wrapped in ``Optional[...]``
    unless it already starts with ``Optional`` or the annotation is a
    ``BinOp`` that has a ``None`` constant among its children.

    NOTE: the computed label is written to ``ann.name``, i.e. the returned
    annotation node is modified in place.
    """
    ann = None
    if isinstance(node.parent, nodes.AnnAssign):
        ann = node.parent.annotation
    elif isinstance(node, nodes.AssignAttr):
        # presumably the enclosing method (named after the typical
        # ``self.x = x`` in ``__init__``) -- TODO confirm against callers
        init_method = node.parent.parent
        try:
            # Pairs the keys of ``locals`` with the argument annotations by
            # position; assumes both are in the same order -- TODO confirm.
            annotations = dict(zip(init_method.locals, init_method.args.annotations))
            ann = annotations.get(node.parent.value.name)
        except AttributeError:
            # Some object in the chain lacks the expected attribute:
            # carry on without an annotation.
            pass
    else:
        return ann

    try:
        # Only the first inferred result is used.
        default, *_ = node.infer()
    except astroid.InferenceError:
        # A str has no ``value`` attribute, so the ``is None`` test below
        # cannot succeed for this placeholder.
        default = ""

    label = get_annotation_label(ann)

    if (
        ann
        # The "value" fallback keeps objects without a ``value`` attribute
        # from being mistaken for a ``None`` default.
        and getattr(default, "value", "value") is None
        and not label.startswith("Optional")
        and (
            not isinstance(ann, nodes.BinOp)
            or not any(
                isinstance(child, nodes.Const) and child.value is None
                for child in ann.get_children()
            )
        )
    ):
        label = rf"Optional[{label}]"

    if label and ann:
        # Store the display label on the annotation node itself.
        ann.name = label
    return ann
  176. def infer_node(node: nodes.AssignAttr | nodes.AssignName) -> set[InferenceResult]:
  177. """Return a set containing the node annotation if it exists
  178. otherwise return a set of the inferred types using the NodeNG.infer method.
  179. """
  180. ann = get_annotation(node)
  181. try:
  182. if ann:
  183. if isinstance(ann, nodes.Subscript) or (
  184. isinstance(ann, nodes.BinOp) and ann.op == "|"
  185. ):
  186. return {ann}
  187. return set(ann.infer())
  188. return set(node.infer())
  189. except astroid.InferenceError:
  190. return {ann} if ann else set()
  191. def check_graphviz_availability() -> None:
  192. """Check if the ``dot`` command is available on the machine.
  193. This is needed if image output is desired and ``dot`` is used to convert
  194. from *.dot or *.gv into the final output format.
  195. """
  196. if shutil.which("dot") is None:
  197. print("'Graphviz' needs to be installed for your chosen output format.")
  198. sys.exit(32)
  199. def check_if_graphviz_supports_format(output_format: str) -> None:
  200. """Check if the ``dot`` command supports the requested output format.
  201. This is needed if image output is desired and ``dot`` is used to convert
  202. from *.gv into the final output format.
  203. """
  204. dot_output = subprocess.run(
  205. ["dot", "-T?"], capture_output=True, check=False, encoding="utf-8"
  206. )
  207. match = re.match(
  208. pattern=r".*Use one of: (?P<formats>(\S*\s?)+)",
  209. string=dot_output.stderr.strip(),
  210. )
  211. if not match:
  212. print(
  213. "Unable to determine Graphviz supported output formats. "
  214. "Pyreverse will continue, but subsequent error messages "
  215. "regarding the output format may come from Graphviz directly."
  216. )
  217. return
  218. supported_formats = match.group("formats")
  219. if output_format not in supported_formats.split():
  220. print(
  221. f"Format {output_format} is not supported by Graphviz. It supports: {supported_formats}"
  222. )
  223. sys.exit(32)