astdiff.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. """Utilities for comparing two versions of a module symbol table.
  2. The goal is to find which AST nodes have externally visible changes, so
  3. that we can fire triggers and re-process other parts of the program
  4. that are stale because of the changes.
  5. Only look in detail at definitions in the current module -- don't
  6. recurse into other modules.
  7. A summary of the module contents:
  8. * snapshot_symbol_table(...) creates an opaque snapshot description of a
  9. module/class symbol table (recursing into nested class symbol tables).
  10. * compare_symbol_table_snapshots(...) compares two snapshots for the same
  11. module id and returns fully qualified names of differences (which act as
  12. triggers).
  13. To compare two versions of a module symbol table, take snapshots of both
  14. versions and compare the snapshots. The use of snapshots makes it easy to
  15. compare two versions of the *same* symbol table that is being mutated.
  16. Summary of how this works for certain kinds of differences:
  17. * If a symbol table node is deleted or added (only present in old/new version
  18. of the symbol table), it is considered different, of course.
  19. * If a symbol table node refers to a different sort of thing in the new version,
  20. it is considered different (for example, if a class is replaced with a
  21. function).
  22. * If the signature of a function has changed, it is considered different.
  23. * If the type of a variable changes, it is considered different.
  24. * If the MRO of a class changes, or a non-generic class is turned into a
  25. generic class, the class is considered different (there are other such "big"
  26. differences that cause a class to be considered changed). However, just changes
  27. to attributes or methods don't generally constitute a difference at the
  28. class level -- these are handled at attribute level (say, 'mod.Cls.method'
  29. is different rather than 'mod.Cls' being different).
  30. * If an imported name targets a different name (say, 'from x import y' is
  31. replaced with 'from z import y'), the name in the module is considered
  32. different. If the target of an import continues to have the same name,
  33. but its specifics change, this doesn't mean that the imported name is
  34. treated as changed. Say, there is 'from x import y' in 'm', and the
  35. type of 'x.y' has changed. This doesn't mean that 'm.y' is considered
  36. changed. Instead, processing the difference in 'm' will be handled through
  37. fine-grained dependencies.
  38. """
  39. from __future__ import annotations
  40. from typing import Sequence, Tuple, Union
  41. from typing_extensions import TypeAlias as _TypeAlias
  42. from mypy.expandtype import expand_type
  43. from mypy.nodes import (
  44. UNBOUND_IMPORTED,
  45. Decorator,
  46. FuncBase,
  47. FuncDef,
  48. FuncItem,
  49. MypyFile,
  50. OverloadedFuncDef,
  51. ParamSpecExpr,
  52. SymbolNode,
  53. SymbolTable,
  54. TypeAlias,
  55. TypeInfo,
  56. TypeVarExpr,
  57. TypeVarTupleExpr,
  58. Var,
  59. )
  60. from mypy.semanal_shared import find_dataclass_transform_spec
  61. from mypy.types import (
  62. AnyType,
  63. CallableType,
  64. DeletedType,
  65. ErasedType,
  66. Instance,
  67. LiteralType,
  68. NoneType,
  69. Overloaded,
  70. Parameters,
  71. ParamSpecType,
  72. PartialType,
  73. TupleType,
  74. Type,
  75. TypeAliasType,
  76. TypedDictType,
  77. TypeType,
  78. TypeVarId,
  79. TypeVarLikeType,
  80. TypeVarTupleType,
  81. TypeVarType,
  82. TypeVisitor,
  83. UnboundType,
  84. UninhabitedType,
  85. UnionType,
  86. UnpackType,
  87. )
  88. from mypy.util import get_prefix
# Snapshot representation of a symbol table node or type. The representation is
# opaque -- the only supported operations are comparing for equality and
# hashing (latter for type snapshots only). Snapshots can contain primitive
# objects, nested tuples, lists and dictionaries (type snapshots are
# immutable).
#
# For example, the snapshot of the 'int' type is
# ('Instance', 'builtins.int', (), ('None',)); the last item stands for a
# missing last known value.
# Type snapshots are strict, they must be hashable and ordered (e.g. for Unions).
Primitive: _TypeAlias = Union[str, float, int, bool]  # float is for Literal[3.14] support.
SnapshotItem: _TypeAlias = Tuple[Union[Primitive, "SnapshotItem"], ...]
# Symbol snapshots can be more lenient.
SymbolSnapshot: _TypeAlias = Tuple[object, ...]
  101. def compare_symbol_table_snapshots(
  102. name_prefix: str, snapshot1: dict[str, SymbolSnapshot], snapshot2: dict[str, SymbolSnapshot]
  103. ) -> set[str]:
  104. """Return names that are different in two snapshots of a symbol table.
  105. Only shallow (intra-module) differences are considered. References to things defined
  106. outside the module are compared based on the name of the target only.
  107. Recurse into class symbol tables (if the class is defined in the target module).
  108. Return a set of fully-qualified names (e.g., 'mod.func' or 'mod.Class.method').
  109. """
  110. # Find names only defined only in one version.
  111. names1 = {f"{name_prefix}.{name}" for name in snapshot1}
  112. names2 = {f"{name_prefix}.{name}" for name in snapshot2}
  113. triggers = names1 ^ names2
  114. # Look for names defined in both versions that are different.
  115. for name in set(snapshot1.keys()) & set(snapshot2.keys()):
  116. item1 = snapshot1[name]
  117. item2 = snapshot2[name]
  118. kind1 = item1[0]
  119. kind2 = item2[0]
  120. item_name = f"{name_prefix}.{name}"
  121. if kind1 != kind2:
  122. # Different kind of node in two snapshots -> trivially different.
  123. triggers.add(item_name)
  124. elif kind1 == "TypeInfo":
  125. if item1[:-1] != item2[:-1]:
  126. # Record major difference (outside class symbol tables).
  127. triggers.add(item_name)
  128. # Look for differences in nested class symbol table entries.
  129. assert isinstance(item1[-1], dict)
  130. assert isinstance(item2[-1], dict)
  131. triggers |= compare_symbol_table_snapshots(item_name, item1[-1], item2[-1])
  132. else:
  133. # Shallow node (no interesting internal structure). Just use equality.
  134. if snapshot1[name] != snapshot2[name]:
  135. triggers.add(item_name)
  136. return triggers
  137. def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> dict[str, SymbolSnapshot]:
  138. """Create a snapshot description that represents the state of a symbol table.
  139. The snapshot has a representation based on nested tuples and dicts
  140. that makes it easy and fast to find differences.
  141. Only "shallow" state is included in the snapshot -- references to
  142. things defined in other modules are represented just by the names of
  143. the targets.
  144. """
  145. result: dict[str, SymbolSnapshot] = {}
  146. for name, symbol in table.items():
  147. node = symbol.node
  148. # TODO: cross_ref?
  149. fullname = node.fullname if node else None
  150. common = (fullname, symbol.kind, symbol.module_public)
  151. if isinstance(node, MypyFile):
  152. # This is a cross-reference to another module.
  153. # If the reference is busted because the other module is missing,
  154. # the node will be a "stale_info" TypeInfo produced by fixup,
  155. # but that doesn't really matter to us here.
  156. result[name] = ("Moduleref", common)
  157. elif isinstance(node, TypeVarExpr):
  158. result[name] = (
  159. "TypeVar",
  160. node.variance,
  161. [snapshot_type(value) for value in node.values],
  162. snapshot_type(node.upper_bound),
  163. snapshot_type(node.default),
  164. )
  165. elif isinstance(node, TypeAlias):
  166. result[name] = (
  167. "TypeAlias",
  168. snapshot_types(node.alias_tvars),
  169. node.normalized,
  170. node.no_args,
  171. snapshot_optional_type(node.target),
  172. )
  173. elif isinstance(node, ParamSpecExpr):
  174. result[name] = (
  175. "ParamSpec",
  176. node.variance,
  177. snapshot_type(node.upper_bound),
  178. snapshot_type(node.default),
  179. )
  180. elif isinstance(node, TypeVarTupleExpr):
  181. result[name] = (
  182. "TypeVarTuple",
  183. node.variance,
  184. snapshot_type(node.upper_bound),
  185. snapshot_type(node.default),
  186. )
  187. else:
  188. assert symbol.kind != UNBOUND_IMPORTED
  189. if node and get_prefix(node.fullname) != name_prefix:
  190. # This is a cross-reference to a node defined in another module.
  191. result[name] = ("CrossRef", common)
  192. else:
  193. result[name] = snapshot_definition(node, common)
  194. return result
  195. def snapshot_definition(node: SymbolNode | None, common: SymbolSnapshot) -> SymbolSnapshot:
  196. """Create a snapshot description of a symbol table node.
  197. The representation is nested tuples and dicts. Only externally
  198. visible attributes are included.
  199. """
  200. if isinstance(node, FuncBase):
  201. # TODO: info
  202. if node.type:
  203. signature = snapshot_type(node.type)
  204. else:
  205. signature = snapshot_untyped_signature(node)
  206. impl: FuncDef | None = None
  207. if isinstance(node, FuncDef):
  208. impl = node
  209. elif isinstance(node, OverloadedFuncDef) and node.impl:
  210. impl = node.impl.func if isinstance(node.impl, Decorator) else node.impl
  211. is_trivial_body = impl.is_trivial_body if impl else False
  212. dataclass_transform_spec = find_dataclass_transform_spec(node)
  213. return (
  214. "Func",
  215. common,
  216. node.is_property,
  217. node.is_final,
  218. node.is_class,
  219. node.is_static,
  220. signature,
  221. is_trivial_body,
  222. dataclass_transform_spec.serialize() if dataclass_transform_spec is not None else None,
  223. )
  224. elif isinstance(node, Var):
  225. return ("Var", common, snapshot_optional_type(node.type), node.is_final)
  226. elif isinstance(node, Decorator):
  227. # Note that decorated methods are represented by Decorator instances in
  228. # a symbol table since we need to preserve information about the
  229. # decorated function (whether it's a class function, for
  230. # example). Top-level decorated functions, however, are represented by
  231. # the corresponding Var node, since that happens to provide enough
  232. # context.
  233. return (
  234. "Decorator",
  235. node.is_overload,
  236. snapshot_optional_type(node.var.type),
  237. snapshot_definition(node.func, common),
  238. )
  239. elif isinstance(node, TypeInfo):
  240. dataclass_transform_spec = node.dataclass_transform_spec
  241. if dataclass_transform_spec is None:
  242. dataclass_transform_spec = find_dataclass_transform_spec(node)
  243. attrs = (
  244. node.is_abstract,
  245. node.is_enum,
  246. node.is_protocol,
  247. node.fallback_to_any,
  248. node.meta_fallback_to_any,
  249. node.is_named_tuple,
  250. node.is_newtype,
  251. # We need this to e.g. trigger metaclass calculation in subclasses.
  252. snapshot_optional_type(node.metaclass_type),
  253. snapshot_optional_type(node.tuple_type),
  254. snapshot_optional_type(node.typeddict_type),
  255. [base.fullname for base in node.mro],
  256. # Note that the structure of type variables is a part of the external interface,
  257. # since creating instances might fail, for example:
  258. # T = TypeVar('T', bound=int)
  259. # class C(Generic[T]):
  260. # ...
  261. # x: C[str] <- this is invalid, and needs to be re-checked if `T` changes.
  262. # An alternative would be to create both deps: <...> -> C, and <...> -> <C>,
  263. # but this currently seems a bit ad hoc.
  264. tuple(snapshot_type(tdef) for tdef in node.defn.type_vars),
  265. [snapshot_type(base) for base in node.bases],
  266. [snapshot_type(p) for p in node._promote],
  267. dataclass_transform_spec.serialize() if dataclass_transform_spec is not None else None,
  268. )
  269. prefix = node.fullname
  270. symbol_table = snapshot_symbol_table(prefix, node.names)
  271. # Special dependency for abstract attribute handling.
  272. symbol_table["(abstract)"] = ("Abstract", tuple(sorted(node.abstract_attributes)))
  273. return ("TypeInfo", common, attrs, symbol_table)
  274. else:
  275. # Other node types are handled elsewhere.
  276. assert False, type(node)
  277. def snapshot_type(typ: Type) -> SnapshotItem:
  278. """Create a snapshot representation of a type using nested tuples."""
  279. return typ.accept(SnapshotTypeVisitor())
  280. def snapshot_optional_type(typ: Type | None) -> SnapshotItem:
  281. if typ:
  282. return snapshot_type(typ)
  283. else:
  284. return ("<not set>",)
  285. def snapshot_types(types: Sequence[Type]) -> SnapshotItem:
  286. return tuple(snapshot_type(item) for item in types)
  287. def snapshot_simple_type(typ: Type) -> SnapshotItem:
  288. return (type(typ).__name__,)
  289. def encode_optional_str(s: str | None) -> str:
  290. if s is None:
  291. return "<None>"
  292. else:
  293. return s
  294. class SnapshotTypeVisitor(TypeVisitor[SnapshotItem]):
  295. """Creates a read-only, self-contained snapshot of a type object.
  296. Properties of a snapshot:
  297. - Contains (nested) tuples and other immutable primitive objects only.
  298. - References to AST nodes are replaced with full names of targets.
  299. - Has no references to mutable or non-primitive objects.
  300. - Two snapshots represent the same object if and only if they are
  301. equal.
  302. - Results must be sortable. It's important that tuples have
  303. consistent types and can't arbitrarily mix str and None values,
  304. for example, since they can't be compared.
  305. """
  306. def visit_unbound_type(self, typ: UnboundType) -> SnapshotItem:
  307. return (
  308. "UnboundType",
  309. typ.name,
  310. typ.optional,
  311. typ.empty_tuple_index,
  312. snapshot_types(typ.args),
  313. )
  314. def visit_any(self, typ: AnyType) -> SnapshotItem:
  315. return snapshot_simple_type(typ)
  316. def visit_none_type(self, typ: NoneType) -> SnapshotItem:
  317. return snapshot_simple_type(typ)
  318. def visit_uninhabited_type(self, typ: UninhabitedType) -> SnapshotItem:
  319. return snapshot_simple_type(typ)
  320. def visit_erased_type(self, typ: ErasedType) -> SnapshotItem:
  321. return snapshot_simple_type(typ)
  322. def visit_deleted_type(self, typ: DeletedType) -> SnapshotItem:
  323. return snapshot_simple_type(typ)
  324. def visit_instance(self, typ: Instance) -> SnapshotItem:
  325. return (
  326. "Instance",
  327. encode_optional_str(typ.type.fullname),
  328. snapshot_types(typ.args),
  329. ("None",) if typ.last_known_value is None else snapshot_type(typ.last_known_value),
  330. )
  331. def visit_type_var(self, typ: TypeVarType) -> SnapshotItem:
  332. return (
  333. "TypeVar",
  334. typ.name,
  335. typ.fullname,
  336. typ.id.raw_id,
  337. typ.id.meta_level,
  338. snapshot_types(typ.values),
  339. snapshot_type(typ.upper_bound),
  340. snapshot_type(typ.default),
  341. typ.variance,
  342. )
  343. def visit_param_spec(self, typ: ParamSpecType) -> SnapshotItem:
  344. return (
  345. "ParamSpec",
  346. typ.id.raw_id,
  347. typ.id.meta_level,
  348. typ.flavor,
  349. snapshot_type(typ.upper_bound),
  350. snapshot_type(typ.default),
  351. )
  352. def visit_type_var_tuple(self, typ: TypeVarTupleType) -> SnapshotItem:
  353. return (
  354. "TypeVarTupleType",
  355. typ.id.raw_id,
  356. typ.id.meta_level,
  357. snapshot_type(typ.upper_bound),
  358. snapshot_type(typ.default),
  359. )
  360. def visit_unpack_type(self, typ: UnpackType) -> SnapshotItem:
  361. return ("UnpackType", snapshot_type(typ.type))
  362. def visit_parameters(self, typ: Parameters) -> SnapshotItem:
  363. return (
  364. "Parameters",
  365. snapshot_types(typ.arg_types),
  366. tuple(encode_optional_str(name) for name in typ.arg_names),
  367. tuple(k.value for k in typ.arg_kinds),
  368. )
  369. def visit_callable_type(self, typ: CallableType) -> SnapshotItem:
  370. if typ.is_generic():
  371. typ = self.normalize_callable_variables(typ)
  372. return (
  373. "CallableType",
  374. snapshot_types(typ.arg_types),
  375. snapshot_type(typ.ret_type),
  376. tuple(encode_optional_str(name) for name in typ.arg_names),
  377. tuple(k.value for k in typ.arg_kinds),
  378. typ.is_type_obj(),
  379. typ.is_ellipsis_args,
  380. snapshot_types(typ.variables),
  381. )
  382. def normalize_callable_variables(self, typ: CallableType) -> CallableType:
  383. """Normalize all type variable ids to run from -1 to -len(variables)."""
  384. tvs = []
  385. tvmap: dict[TypeVarId, Type] = {}
  386. for i, v in enumerate(typ.variables):
  387. tid = TypeVarId(-1 - i)
  388. if isinstance(v, TypeVarType):
  389. tv: TypeVarLikeType = v.copy_modified(id=tid)
  390. elif isinstance(v, TypeVarTupleType):
  391. tv = v.copy_modified(id=tid)
  392. else:
  393. assert isinstance(v, ParamSpecType)
  394. tv = v.copy_modified(id=tid)
  395. tvs.append(tv)
  396. tvmap[v.id] = tv
  397. return expand_type(typ, tvmap).copy_modified(variables=tvs)
  398. def visit_tuple_type(self, typ: TupleType) -> SnapshotItem:
  399. return ("TupleType", snapshot_types(typ.items))
  400. def visit_typeddict_type(self, typ: TypedDictType) -> SnapshotItem:
  401. items = tuple((key, snapshot_type(item_type)) for key, item_type in typ.items.items())
  402. required = tuple(sorted(typ.required_keys))
  403. return ("TypedDictType", items, required)
  404. def visit_literal_type(self, typ: LiteralType) -> SnapshotItem:
  405. return ("LiteralType", snapshot_type(typ.fallback), typ.value)
  406. def visit_union_type(self, typ: UnionType) -> SnapshotItem:
  407. # Sort and remove duplicates so that we can use equality to test for
  408. # equivalent union type snapshots.
  409. items = {snapshot_type(item) for item in typ.items}
  410. normalized = tuple(sorted(items))
  411. return ("UnionType", normalized)
  412. def visit_overloaded(self, typ: Overloaded) -> SnapshotItem:
  413. return ("Overloaded", snapshot_types(typ.items))
  414. def visit_partial_type(self, typ: PartialType) -> SnapshotItem:
  415. # A partial type is not fully defined, so the result is indeterminate. We shouldn't
  416. # get here.
  417. raise RuntimeError
  418. def visit_type_type(self, typ: TypeType) -> SnapshotItem:
  419. return ("TypeType", snapshot_type(typ.item))
  420. def visit_type_alias_type(self, typ: TypeAliasType) -> SnapshotItem:
  421. assert typ.alias is not None
  422. return ("TypeAliasType", typ.alias.fullname, snapshot_types(typ.args))
  423. def snapshot_untyped_signature(func: OverloadedFuncDef | FuncItem) -> SymbolSnapshot:
  424. """Create a snapshot of the signature of a function that has no explicit signature.
  425. If the arguments to a function without signature change, it must be
  426. considered as different. We have this special casing since we don't store
  427. the implicit signature anywhere, and we'd rather not construct new
  428. Callable objects in this module (the idea is to only read properties of
  429. the AST here).
  430. """
  431. if isinstance(func, FuncItem):
  432. return (tuple(func.arg_names), tuple(func.arg_kinds))
  433. else:
  434. result: list[SymbolSnapshot] = []
  435. for item in func.items:
  436. if isinstance(item, Decorator):
  437. if item.var.type:
  438. result.append(snapshot_type(item.var.type))
  439. else:
  440. result.append(("DecoratorWithoutType",))
  441. else:
  442. result.append(snapshot_untyped_signature(item))
  443. return tuple(result)