mro.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. from __future__ import annotations
  2. from typing import Callable
  3. from mypy.nodes import TypeInfo
  4. from mypy.types import Instance
  5. from mypy.typestate import type_state
  6. def calculate_mro(info: TypeInfo, obj_type: Callable[[], Instance] | None = None) -> None:
  7. """Calculate and set mro (method resolution order).
  8. Raise MroError if cannot determine mro.
  9. """
  10. mro = linearize_hierarchy(info, obj_type)
  11. assert mro, f"Could not produce a MRO at all for {info}"
  12. info.mro = mro
  13. # The property of falling back to Any is inherited.
  14. info.fallback_to_any = any(baseinfo.fallback_to_any for baseinfo in info.mro)
  15. type_state.reset_all_subtype_caches_for(info)
  16. class MroError(Exception):
  17. """Raised if a consistent mro cannot be determined for a class."""
  18. def linearize_hierarchy(
  19. info: TypeInfo, obj_type: Callable[[], Instance] | None = None
  20. ) -> list[TypeInfo]:
  21. # TODO describe
  22. if info.mro:
  23. return info.mro
  24. bases = info.direct_base_classes()
  25. if not bases and info.fullname != "builtins.object" and obj_type is not None:
  26. # Probably an error, add a dummy `object` base class,
  27. # otherwise MRO calculation may spuriously fail.
  28. bases = [obj_type().type]
  29. lin_bases = []
  30. for base in bases:
  31. assert base is not None, f"Cannot linearize bases for {info.fullname} {bases}"
  32. lin_bases.append(linearize_hierarchy(base, obj_type))
  33. lin_bases.append(bases)
  34. return [info] + merge(lin_bases)
  35. def merge(seqs: list[list[TypeInfo]]) -> list[TypeInfo]:
  36. seqs = [s.copy() for s in seqs]
  37. result: list[TypeInfo] = []
  38. while True:
  39. seqs = [s for s in seqs if s]
  40. if not seqs:
  41. return result
  42. for seq in seqs:
  43. head = seq[0]
  44. if not [s for s in seqs if head in s[1:]]:
  45. break
  46. else:
  47. raise MroError()
  48. result.append(head)
  49. for s in seqs:
  50. if s[0] is head:
  51. del s[0]