inspector.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. # Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
  2. # For details: https://github.com/pylint-dev/pylint/blob/main/LICENSE
  3. # Copyright (c) https://github.com/pylint-dev/pylint/blob/main/CONTRIBUTORS.txt
  4. """Visitor doing some post-processing on the astroid tree.
  5. Try to resolve definitions (namespace) dictionary, relationship...
  6. """
  7. from __future__ import annotations
  8. import collections
  9. import os
  10. import traceback
  11. import warnings
  12. from abc import ABC, abstractmethod
  13. from collections.abc import Generator
  14. from typing import Any, Callable, Optional
  15. import astroid
  16. from astroid import nodes, util
  17. from pylint import constants
  18. from pylint.pyreverse import utils
# Type of the ``func_wrapper`` argument of ``project_from_files``: it is called
# with a module-building function (``astroid.MANAGER.ast_from_file`` there) and
# a file path, and returns the built module or None (``_astroid_wrapper``
# returns None when building fails).
_WrapperFuncT = Callable[[Callable[[str], nodes.Module], str], Optional[nodes.Module]]
  20. def _astroid_wrapper(
  21. func: Callable[[str], nodes.Module], modname: str
  22. ) -> nodes.Module | None:
  23. print(f"parsing {modname}...")
  24. try:
  25. return func(modname)
  26. except astroid.exceptions.AstroidBuildingException as exc:
  27. print(exc)
  28. except Exception: # pylint: disable=broad-except
  29. traceback.print_exc()
  30. return None
  31. def interfaces(node: nodes.ClassDef) -> Generator[Any, None, None]:
  32. """Return an iterator on interfaces implemented by the given class node."""
  33. try:
  34. implements = astroid.bases.Instance(node).getattr("__implements__")[0]
  35. except astroid.exceptions.NotFoundError:
  36. return
  37. if implements.frame(future=True) is not node:
  38. return
  39. found = set()
  40. missing = False
  41. for iface in nodes.unpack_infer(implements):
  42. if isinstance(iface, util.UninferableBase):
  43. missing = True
  44. continue
  45. if iface not in found:
  46. found.add(iface)
  47. yield iface
  48. if missing:
  49. raise astroid.exceptions.InferenceError()
  50. class IdGeneratorMixIn:
  51. """Mixin adding the ability to generate integer uid."""
  52. def __init__(self, start_value: int = 0) -> None:
  53. self.id_count = start_value
  54. def init_counter(self, start_value: int = 0) -> None:
  55. """Init the id counter."""
  56. self.id_count = start_value
  57. def generate_id(self) -> int:
  58. """Generate a new identifier."""
  59. self.id_count += 1
  60. return self.id_count
  61. class Project:
  62. """A project handle a set of modules / packages."""
  63. def __init__(self, name: str = ""):
  64. self.name = name
  65. self.uid: int | None = None
  66. self.path: str = ""
  67. self.modules: list[nodes.Module] = []
  68. self.locals: dict[str, nodes.Module] = {}
  69. self.__getitem__ = self.locals.__getitem__
  70. self.__iter__ = self.locals.__iter__
  71. self.values = self.locals.values
  72. self.keys = self.locals.keys
  73. self.items = self.locals.items
  74. def add_module(self, node: nodes.Module) -> None:
  75. self.locals[node.name] = node
  76. self.modules.append(node)
  77. def get_module(self, name: str) -> nodes.Module:
  78. return self.locals[name]
  79. def get_children(self) -> list[nodes.Module]:
  80. return self.modules
  81. def __repr__(self) -> str:
  82. return f"<Project {self.name!r} at {id(self)} ({len(self.modules)} modules)>"
  83. class Linker(IdGeneratorMixIn, utils.LocalsVisitor):
  84. """Walk on the project tree and resolve relationships.
  85. According to options the following attributes may be
  86. added to visited nodes:
  87. * uid,
  88. a unique identifier for the node (on astroid.Project, astroid.Module,
  89. astroid.Class and astroid.locals_type). Only if the linker
  90. has been instantiated with tag=True parameter (False by default).
  91. * Function
  92. a mapping from locals names to their bounded value, which may be a
  93. constant like a string or an integer, or an astroid node
  94. (on astroid.Module, astroid.Class and astroid.Function).
  95. * instance_attrs_type
  96. as locals_type but for klass member attributes (only on astroid.Class)
  97. * associations_type
  98. as instance_attrs_type but for association relationships
  99. * aggregations_type
  100. as instance_attrs_type but for aggregations relationships
  101. * implements,
  102. list of implemented interface _objects_ (only on astroid.Class nodes)
  103. """
  104. def __init__(self, project: Project, tag: bool = False) -> None:
  105. IdGeneratorMixIn.__init__(self)
  106. utils.LocalsVisitor.__init__(self)
  107. # tag nodes or not
  108. self.tag = tag
  109. # visited project
  110. self.project = project
  111. self.associations_handler = AggregationsHandler()
  112. self.associations_handler.set_next(OtherAssociationsHandler())
  113. def visit_project(self, node: Project) -> None:
  114. """Visit a pyreverse.utils.Project node.
  115. * optionally tag the node with a unique id
  116. """
  117. if self.tag:
  118. node.uid = self.generate_id()
  119. for module in node.modules:
  120. self.visit(module)
  121. def visit_module(self, node: nodes.Module) -> None:
  122. """Visit an astroid.Module node.
  123. * set the locals_type mapping
  124. * set the depends mapping
  125. * optionally tag the node with a unique id
  126. """
  127. if hasattr(node, "locals_type"):
  128. return
  129. node.locals_type = collections.defaultdict(list)
  130. node.depends = []
  131. if self.tag:
  132. node.uid = self.generate_id()
  133. def visit_classdef(self, node: nodes.ClassDef) -> None:
  134. """Visit an astroid.Class node.
  135. * set the locals_type and instance_attrs_type mappings
  136. * set the implements list and build it
  137. * optionally tag the node with a unique id
  138. """
  139. if hasattr(node, "locals_type"):
  140. return
  141. node.locals_type = collections.defaultdict(list)
  142. if self.tag:
  143. node.uid = self.generate_id()
  144. # resolve ancestors
  145. for baseobj in node.ancestors(recurs=False):
  146. specializations = getattr(baseobj, "specializations", [])
  147. specializations.append(node)
  148. baseobj.specializations = specializations
  149. # resolve instance attributes
  150. node.instance_attrs_type = collections.defaultdict(list)
  151. node.aggregations_type = collections.defaultdict(list)
  152. node.associations_type = collections.defaultdict(list)
  153. for assignattrs in tuple(node.instance_attrs.values()):
  154. for assignattr in assignattrs:
  155. if not isinstance(assignattr, nodes.Unknown):
  156. self.associations_handler.handle(assignattr, node)
  157. self.handle_assignattr_type(assignattr, node)
  158. # resolve implemented interface
  159. try:
  160. ifaces = interfaces(node)
  161. if ifaces is not None:
  162. node.implements = list(ifaces)
  163. if node.implements:
  164. # TODO: 3.0: Remove support for __implements__
  165. warnings.warn(
  166. "pyreverse will drop support for resolving and displaying "
  167. "implemented interfaces in pylint 3.0. The implementation "
  168. "relies on the '__implements__' attribute proposed in PEP 245"
  169. ", which was rejected in 2006.",
  170. DeprecationWarning,
  171. )
  172. else:
  173. node.implements = []
  174. except astroid.InferenceError:
  175. node.implements = []
  176. def visit_functiondef(self, node: nodes.FunctionDef) -> None:
  177. """Visit an astroid.Function node.
  178. * set the locals_type mapping
  179. * optionally tag the node with a unique id
  180. """
  181. if hasattr(node, "locals_type"):
  182. return
  183. node.locals_type = collections.defaultdict(list)
  184. if self.tag:
  185. node.uid = self.generate_id()
  186. def visit_assignname(self, node: nodes.AssignName) -> None:
  187. """Visit an astroid.AssignName node.
  188. handle locals_type
  189. """
  190. # avoid double parsing done by different Linkers.visit
  191. # running over the same project:
  192. if hasattr(node, "_handled"):
  193. return
  194. node._handled = True
  195. if node.name in node.frame(future=True):
  196. frame = node.frame(future=True)
  197. else:
  198. # the name has been defined as 'global' in the frame and belongs
  199. # there.
  200. frame = node.root()
  201. if not hasattr(frame, "locals_type"):
  202. # If the frame doesn't have a locals_type yet,
  203. # it means it wasn't yet visited. Visit it now
  204. # to add what's missing from it.
  205. if isinstance(frame, nodes.ClassDef):
  206. self.visit_classdef(frame)
  207. elif isinstance(frame, nodes.FunctionDef):
  208. self.visit_functiondef(frame)
  209. else:
  210. self.visit_module(frame)
  211. current = frame.locals_type[node.name]
  212. frame.locals_type[node.name] = list(set(current) | utils.infer_node(node))
  213. @staticmethod
  214. def handle_assignattr_type(node: nodes.AssignAttr, parent: nodes.ClassDef) -> None:
  215. """Handle an astroid.assignattr node.
  216. handle instance_attrs_type
  217. """
  218. current = set(parent.instance_attrs_type[node.attrname])
  219. parent.instance_attrs_type[node.attrname] = list(
  220. current | utils.infer_node(node)
  221. )
  222. def visit_import(self, node: nodes.Import) -> None:
  223. """Visit an astroid.Import node.
  224. resolve module dependencies
  225. """
  226. context_file = node.root().file
  227. for name in node.names:
  228. relative = astroid.modutils.is_relative(name[0], context_file)
  229. self._imported_module(node, name[0], relative)
  230. def visit_importfrom(self, node: nodes.ImportFrom) -> None:
  231. """Visit an astroid.ImportFrom node.
  232. resolve module dependencies
  233. """
  234. basename = node.modname
  235. context_file = node.root().file
  236. if context_file is not None:
  237. relative = astroid.modutils.is_relative(basename, context_file)
  238. else:
  239. relative = False
  240. for name in node.names:
  241. if name[0] == "*":
  242. continue
  243. # analyze dependencies
  244. fullname = f"{basename}.{name[0]}"
  245. if fullname.find(".") > -1:
  246. try:
  247. fullname = astroid.modutils.get_module_part(fullname, context_file)
  248. except ImportError:
  249. continue
  250. if fullname != basename:
  251. self._imported_module(node, fullname, relative)
  252. def compute_module(self, context_name: str, mod_path: str) -> bool:
  253. """Should the module be added to dependencies ?"""
  254. package_dir = os.path.dirname(self.project.path)
  255. if context_name == mod_path:
  256. return False
  257. # astroid does return a boolean but is not typed correctly yet
  258. return astroid.modutils.module_in_path(mod_path, (package_dir,)) # type: ignore[no-any-return]
  259. def _imported_module(
  260. self, node: nodes.Import | nodes.ImportFrom, mod_path: str, relative: bool
  261. ) -> None:
  262. """Notify an imported module, used to analyze dependencies."""
  263. module = node.root()
  264. context_name = module.name
  265. if relative:
  266. mod_path = f"{'.'.join(context_name.split('.')[:-1])}.{mod_path}"
  267. if self.compute_module(context_name, mod_path):
  268. # handle dependencies
  269. if not hasattr(module, "depends"):
  270. module.depends = []
  271. mod_paths = module.depends
  272. if mod_path not in mod_paths:
  273. mod_paths.append(mod_path)
  274. class AssociationHandlerInterface(ABC):
  275. @abstractmethod
  276. def set_next(
  277. self, handler: AssociationHandlerInterface
  278. ) -> AssociationHandlerInterface:
  279. pass
  280. @abstractmethod
  281. def handle(self, node: nodes.AssignAttr, parent: nodes.ClassDef) -> None:
  282. pass
  283. class AbstractAssociationHandler(AssociationHandlerInterface):
  284. """
  285. Chain of Responsibility for handling types of association, useful
  286. to expand in the future if we want to add more distinct associations.
  287. Every link of the chain checks if it's a certain type of association.
  288. If no association is found it's set as a generic association in `associations_type`.
  289. The default chaining behavior is implemented inside the base handler
  290. class.
  291. """
  292. _next_handler: AssociationHandlerInterface
  293. def set_next(
  294. self, handler: AssociationHandlerInterface
  295. ) -> AssociationHandlerInterface:
  296. self._next_handler = handler
  297. return handler
  298. @abstractmethod
  299. def handle(self, node: nodes.AssignAttr, parent: nodes.ClassDef) -> None:
  300. if self._next_handler:
  301. self._next_handler.handle(node, parent)
  302. class AggregationsHandler(AbstractAssociationHandler):
  303. def handle(self, node: nodes.AssignAttr, parent: nodes.ClassDef) -> None:
  304. if isinstance(node.parent, (nodes.AnnAssign, nodes.Assign)) and isinstance(
  305. node.parent.value, astroid.node_classes.Name
  306. ):
  307. current = set(parent.aggregations_type[node.attrname])
  308. parent.aggregations_type[node.attrname] = list(
  309. current | utils.infer_node(node)
  310. )
  311. else:
  312. super().handle(node, parent)
  313. class OtherAssociationsHandler(AbstractAssociationHandler):
  314. def handle(self, node: nodes.AssignAttr, parent: nodes.ClassDef) -> None:
  315. current = set(parent.associations_type[node.attrname])
  316. parent.associations_type[node.attrname] = list(current | utils.infer_node(node))
  317. def project_from_files(
  318. files: list[str],
  319. func_wrapper: _WrapperFuncT = _astroid_wrapper,
  320. project_name: str = "no name",
  321. black_list: tuple[str, ...] = constants.DEFAULT_IGNORE_LIST,
  322. ) -> Project:
  323. """Return a Project from a list of files or modules."""
  324. # build the project representation
  325. astroid_manager = astroid.MANAGER
  326. project = Project(project_name)
  327. for something in files:
  328. if not os.path.exists(something):
  329. fpath = astroid.modutils.file_from_modpath(something.split("."))
  330. elif os.path.isdir(something):
  331. fpath = os.path.join(something, "__init__.py")
  332. else:
  333. fpath = something
  334. ast = func_wrapper(astroid_manager.ast_from_file, fpath)
  335. if ast is None:
  336. continue
  337. project.path = project.path or ast.file
  338. project.add_module(ast)
  339. base_name = ast.name
  340. # recurse in package except if __init__ was explicitly given
  341. if ast.package and something.find("__init__") == -1:
  342. # recurse on others packages / modules if this is a package
  343. for fpath in astroid.modutils.get_module_files(
  344. os.path.dirname(ast.file), black_list
  345. ):
  346. ast = func_wrapper(astroid_manager.ast_from_file, fpath)
  347. if ast is None or ast.name == base_name:
  348. continue
  349. project.add_module(ast)
  350. return project