indirection.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. from __future__ import annotations
  2. from typing import Iterable, Set
  3. import mypy.types as types
  4. from mypy.types import TypeVisitor
  5. from mypy.util import split_module_names
  6. def extract_module_names(type_name: str | None) -> list[str]:
  7. """Returns the module names of a fully qualified type name."""
  8. if type_name is not None:
  9. # Discard the first one, which is just the qualified name of the type
  10. possible_module_names = split_module_names(type_name)
  11. return possible_module_names[1:]
  12. else:
  13. return []
  14. class TypeIndirectionVisitor(TypeVisitor[Set[str]]):
  15. """Returns all module references within a particular type."""
  16. def __init__(self) -> None:
  17. self.cache: dict[types.Type, set[str]] = {}
  18. self.seen_aliases: set[types.TypeAliasType] = set()
  19. def find_modules(self, typs: Iterable[types.Type]) -> set[str]:
  20. self.seen_aliases.clear()
  21. return self._visit(typs)
  22. def _visit(self, typ_or_typs: types.Type | Iterable[types.Type]) -> set[str]:
  23. typs = [typ_or_typs] if isinstance(typ_or_typs, types.Type) else typ_or_typs
  24. output: set[str] = set()
  25. for typ in typs:
  26. if isinstance(typ, types.TypeAliasType):
  27. # Avoid infinite recursion for recursive type aliases.
  28. if typ in self.seen_aliases:
  29. continue
  30. self.seen_aliases.add(typ)
  31. if typ in self.cache:
  32. modules = self.cache[typ]
  33. else:
  34. modules = typ.accept(self)
  35. self.cache[typ] = set(modules)
  36. output.update(modules)
  37. return output
  38. def visit_unbound_type(self, t: types.UnboundType) -> set[str]:
  39. return self._visit(t.args)
  40. def visit_any(self, t: types.AnyType) -> set[str]:
  41. return set()
  42. def visit_none_type(self, t: types.NoneType) -> set[str]:
  43. return set()
  44. def visit_uninhabited_type(self, t: types.UninhabitedType) -> set[str]:
  45. return set()
  46. def visit_erased_type(self, t: types.ErasedType) -> set[str]:
  47. return set()
  48. def visit_deleted_type(self, t: types.DeletedType) -> set[str]:
  49. return set()
  50. def visit_type_var(self, t: types.TypeVarType) -> set[str]:
  51. return self._visit(t.values) | self._visit(t.upper_bound) | self._visit(t.default)
  52. def visit_param_spec(self, t: types.ParamSpecType) -> set[str]:
  53. return self._visit(t.upper_bound) | self._visit(t.default)
  54. def visit_type_var_tuple(self, t: types.TypeVarTupleType) -> set[str]:
  55. return self._visit(t.upper_bound) | self._visit(t.default)
  56. def visit_unpack_type(self, t: types.UnpackType) -> set[str]:
  57. return t.type.accept(self)
  58. def visit_parameters(self, t: types.Parameters) -> set[str]:
  59. return self._visit(t.arg_types)
  60. def visit_instance(self, t: types.Instance) -> set[str]:
  61. out = self._visit(t.args)
  62. if t.type:
  63. # Uses of a class depend on everything in the MRO,
  64. # as changes to classes in the MRO can add types to methods,
  65. # change property types, change the MRO itself, etc.
  66. for s in t.type.mro:
  67. out.update(split_module_names(s.module_name))
  68. if t.type.metaclass_type is not None:
  69. out.update(split_module_names(t.type.metaclass_type.type.module_name))
  70. return out
  71. def visit_callable_type(self, t: types.CallableType) -> set[str]:
  72. out = self._visit(t.arg_types) | self._visit(t.ret_type)
  73. if t.definition is not None:
  74. out.update(extract_module_names(t.definition.fullname))
  75. return out
  76. def visit_overloaded(self, t: types.Overloaded) -> set[str]:
  77. return self._visit(t.items) | self._visit(t.fallback)
  78. def visit_tuple_type(self, t: types.TupleType) -> set[str]:
  79. return self._visit(t.items) | self._visit(t.partial_fallback)
  80. def visit_typeddict_type(self, t: types.TypedDictType) -> set[str]:
  81. return self._visit(t.items.values()) | self._visit(t.fallback)
  82. def visit_literal_type(self, t: types.LiteralType) -> set[str]:
  83. return self._visit(t.fallback)
  84. def visit_union_type(self, t: types.UnionType) -> set[str]:
  85. return self._visit(t.items)
  86. def visit_partial_type(self, t: types.PartialType) -> set[str]:
  87. return set()
  88. def visit_type_type(self, t: types.TypeType) -> set[str]:
  89. return self._visit(t.item)
  90. def visit_type_alias_type(self, t: types.TypeAliasType) -> set[str]:
  91. return self._visit(types.get_proper_type(t))