writer.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. # Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
  2. # For details: https://github.com/PyCQA/pylint/blob/main/LICENSE
  3. # Copyright (c) https://github.com/PyCQA/pylint/blob/main/CONTRIBUTORS.txt
  4. """Utilities for creating VCG and Dot diagrams."""
  5. from __future__ import annotations
  6. import argparse
  7. import itertools
  8. import os
  9. from collections.abc import Iterable
  10. from astroid import modutils, nodes
  11. from pylint.pyreverse.diagrams import (
  12. ClassDiagram,
  13. ClassEntity,
  14. DiagramEntity,
  15. PackageDiagram,
  16. PackageEntity,
  17. )
  18. from pylint.pyreverse.printer import EdgeType, NodeProperties, NodeType, Printer
  19. from pylint.pyreverse.printer_factory import get_printer_for_filetype
  20. from pylint.pyreverse.utils import is_exception
  21. class DiagramWriter:
  22. """Base class for writing project diagrams."""
  23. def __init__(self, config: argparse.Namespace) -> None:
  24. self.config = config
  25. self.printer_class = get_printer_for_filetype(self.config.output_format)
  26. self.printer: Printer # defined in set_printer
  27. self.file_name = "" # defined in set_printer
  28. self.depth = self.config.max_color_depth
  29. # default colors are an adaption of the seaborn colorblind palette
  30. self.available_colors = itertools.cycle(self.config.color_palette)
  31. self.used_colors: dict[str, str] = {}
  32. def write(self, diadefs: Iterable[ClassDiagram | PackageDiagram]) -> None:
  33. """Write files for <project> according to <diadefs>."""
  34. for diagram in diadefs:
  35. basename = diagram.title.strip().replace("/", "_").replace(" ", "_")
  36. file_name = f"{basename}.{self.config.output_format}"
  37. if os.path.exists(self.config.output_directory):
  38. file_name = os.path.join(self.config.output_directory, file_name)
  39. self.set_printer(file_name, basename)
  40. if isinstance(diagram, PackageDiagram):
  41. self.write_packages(diagram)
  42. else:
  43. self.write_classes(diagram)
  44. self.save()
  45. def write_packages(self, diagram: PackageDiagram) -> None:
  46. """Write a package diagram."""
  47. # sorted to get predictable (hence testable) results
  48. for module in sorted(diagram.modules(), key=lambda x: x.title):
  49. module.fig_id = module.node.qname()
  50. self.printer.emit_node(
  51. module.fig_id,
  52. type_=NodeType.PACKAGE,
  53. properties=self.get_package_properties(module),
  54. )
  55. # package dependencies
  56. for rel in diagram.get_relationships("depends"):
  57. self.printer.emit_edge(
  58. rel.from_object.fig_id,
  59. rel.to_object.fig_id,
  60. type_=EdgeType.USES,
  61. )
  62. def write_classes(self, diagram: ClassDiagram) -> None:
  63. """Write a class diagram."""
  64. # sorted to get predictable (hence testable) results
  65. for obj in sorted(diagram.objects, key=lambda x: x.title): # type: ignore[no-any-return]
  66. obj.fig_id = obj.node.qname()
  67. type_ = NodeType.INTERFACE if obj.shape == "interface" else NodeType.CLASS
  68. self.printer.emit_node(
  69. obj.fig_id, type_=type_, properties=self.get_class_properties(obj)
  70. )
  71. # inheritance links
  72. for rel in diagram.get_relationships("specialization"):
  73. self.printer.emit_edge(
  74. rel.from_object.fig_id,
  75. rel.to_object.fig_id,
  76. type_=EdgeType.INHERITS,
  77. )
  78. # implementation links
  79. for rel in diagram.get_relationships("implements"):
  80. self.printer.emit_edge(
  81. rel.from_object.fig_id,
  82. rel.to_object.fig_id,
  83. type_=EdgeType.IMPLEMENTS,
  84. )
  85. # generate associations
  86. for rel in diagram.get_relationships("association"):
  87. self.printer.emit_edge(
  88. rel.from_object.fig_id,
  89. rel.to_object.fig_id,
  90. label=rel.name,
  91. type_=EdgeType.ASSOCIATION,
  92. )
  93. # generate aggregations
  94. for rel in diagram.get_relationships("aggregation"):
  95. self.printer.emit_edge(
  96. rel.from_object.fig_id,
  97. rel.to_object.fig_id,
  98. label=rel.name,
  99. type_=EdgeType.AGGREGATION,
  100. )
  101. def set_printer(self, file_name: str, basename: str) -> None:
  102. """Set printer."""
  103. self.printer = self.printer_class(basename)
  104. self.file_name = file_name
  105. def get_package_properties(self, obj: PackageEntity) -> NodeProperties:
  106. """Get label and shape for packages."""
  107. return NodeProperties(
  108. label=obj.title,
  109. color=self.get_shape_color(obj) if self.config.colorized else "black",
  110. )
  111. def get_class_properties(self, obj: ClassEntity) -> NodeProperties:
  112. """Get label and shape for classes."""
  113. properties = NodeProperties(
  114. label=obj.title,
  115. attrs=obj.attrs if not self.config.only_classnames else None,
  116. methods=obj.methods if not self.config.only_classnames else None,
  117. fontcolor="red" if is_exception(obj.node) else "black",
  118. color=self.get_shape_color(obj) if self.config.colorized else "black",
  119. )
  120. return properties
  121. def get_shape_color(self, obj: DiagramEntity) -> str:
  122. """Get shape color."""
  123. qualified_name = obj.node.qname()
  124. if modutils.is_stdlib_module(qualified_name.split(".", maxsplit=1)[0]):
  125. return "grey"
  126. if isinstance(obj.node, nodes.ClassDef):
  127. package = qualified_name.rsplit(".", maxsplit=2)[0]
  128. elif obj.node.package:
  129. package = qualified_name
  130. else:
  131. package = qualified_name.rsplit(".", maxsplit=1)[0]
  132. base_name = ".".join(package.split(".", self.depth)[: self.depth])
  133. if base_name not in self.used_colors:
  134. self.used_colors[base_name] = next(self.available_colors)
  135. return self.used_colors[base_name]
  136. def save(self) -> None:
  137. """Write to disk."""
  138. self.printer.generate(self.file_name)