solve.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. """Type inference constraint solving"""
  2. from __future__ import annotations
  3. from typing import Iterable
  4. from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, neg_op
  5. from mypy.expandtype import expand_type
  6. from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
  7. from mypy.join import join_types
  8. from mypy.meet import meet_types
  9. from mypy.subtypes import is_subtype
  10. from mypy.typeops import get_type_vars
  11. from mypy.types import (
  12. AnyType,
  13. ProperType,
  14. Type,
  15. TypeOfAny,
  16. TypeVarId,
  17. TypeVarType,
  18. UninhabitedType,
  19. UnionType,
  20. get_proper_type,
  21. remove_dups,
  22. )
  23. from mypy.typestate import type_state
  24. def solve_constraints(
  25. vars: list[TypeVarId],
  26. constraints: list[Constraint],
  27. strict: bool = True,
  28. allow_polymorphic: bool = False,
  29. ) -> list[Type | None]:
  30. """Solve type constraints.
  31. Return the best type(s) for type variables; each type can be None if the value of the variable
  32. could not be solved.
  33. If a variable has no constraints, if strict=True then arbitrarily
  34. pick NoneType as the value of the type variable. If strict=False,
  35. pick AnyType.
  36. """
  37. if not vars:
  38. return []
  39. if allow_polymorphic:
  40. # Constraints like T :> S and S <: T are semantically the same, but they are
  41. # represented differently. Normalize the constraint list w.r.t this equivalence.
  42. constraints = normalize_constraints(constraints, vars)
  43. # Collect a list of constraints for each type variable.
  44. cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in vars}
  45. for con in constraints:
  46. if con.type_var in vars:
  47. cmap[con.type_var].append(con)
  48. if allow_polymorphic:
  49. solutions = solve_non_linear(vars, constraints, cmap)
  50. else:
  51. solutions = {}
  52. for tv, cs in cmap.items():
  53. if not cs:
  54. continue
  55. lowers = [c.target for c in cs if c.op == SUPERTYPE_OF]
  56. uppers = [c.target for c in cs if c.op == SUBTYPE_OF]
  57. solutions[tv] = solve_one(lowers, uppers, [])
  58. res: list[Type | None] = []
  59. for v in vars:
  60. if v in solutions:
  61. res.append(solutions[v])
  62. else:
  63. # No constraints for type variable -- 'UninhabitedType' is the most specific type.
  64. candidate: Type
  65. if strict:
  66. candidate = UninhabitedType()
  67. candidate.ambiguous = True
  68. else:
  69. candidate = AnyType(TypeOfAny.special_form)
  70. res.append(candidate)
  71. return res
  72. def solve_non_linear(
  73. vars: list[TypeVarId], constraints: list[Constraint], cmap: dict[TypeVarId, list[Constraint]]
  74. ) -> dict[TypeVarId, Type | None]:
  75. """Solve set of constraints that may include non-linear ones, like T <: List[S].
  76. The whole algorithm consists of five steps:
  77. * Propagate via linear constraints to get all possible constraints for each variable
  78. * Find dependencies between type variables, group them in SCCs, and sort topologically
  79. * Check all SCC are intrinsically linear, we can't solve (express) T <: List[T]
  80. * Variables in leaf SCCs that don't have constant bounds are free (choose one per SCC)
  81. * Solve constraints iteratively starting from leafs, updating targets after each step.
  82. """
  83. extra_constraints = []
  84. for tvar in vars:
  85. extra_constraints.extend(propagate_constraints_for(tvar, SUBTYPE_OF, cmap))
  86. extra_constraints.extend(propagate_constraints_for(tvar, SUPERTYPE_OF, cmap))
  87. constraints += remove_dups(extra_constraints)
  88. # Recompute constraint map after propagating.
  89. cmap = {tv: [] for tv in vars}
  90. for con in constraints:
  91. if con.type_var in vars:
  92. cmap[con.type_var].append(con)
  93. dmap = compute_dependencies(cmap)
  94. sccs = list(strongly_connected_components(set(vars), dmap))
  95. if all(check_linear(scc, cmap) for scc in sccs):
  96. raw_batches = list(topsort(prepare_sccs(sccs, dmap)))
  97. leafs = raw_batches[0]
  98. free_vars = []
  99. for scc in leafs:
  100. # If all constrain targets in this SCC are type variables within the
  101. # same SCC then the only meaningful solution we can express, is that
  102. # each variable is equal to a new free variable. For example if we
  103. # have T <: S, S <: U, we deduce: T = S = U = <free>.
  104. if all(
  105. isinstance(c.target, TypeVarType) and c.target.id in vars
  106. for tv in scc
  107. for c in cmap[tv]
  108. ):
  109. # For convenience with current type application machinery, we randomly
  110. # choose one of the existing type variables in SCC and designate it as free
  111. # instead of defining a new type variable as a common solution.
  112. # TODO: be careful about upper bounds (or values) when introducing free vars.
  113. free_vars.append(sorted(scc, key=lambda x: x.raw_id)[0])
  114. # Flatten the SCCs that are independent, we can solve them together,
  115. # since we don't need to update any targets in between.
  116. batches = []
  117. for batch in raw_batches:
  118. next_bc = []
  119. for scc in batch:
  120. next_bc.extend(list(scc))
  121. batches.append(next_bc)
  122. solutions: dict[TypeVarId, Type | None] = {}
  123. for flat_batch in batches:
  124. solutions.update(solve_iteratively(flat_batch, cmap, free_vars))
  125. # We remove the solutions like T = T for free variables. This will indicate
  126. # to the apply function, that they should not be touched.
  127. # TODO: return list of free type variables explicitly, this logic is fragile
  128. # (but if we do, we need to be careful everything works in incremental modes).
  129. for tv in free_vars:
  130. if tv in solutions:
  131. del solutions[tv]
  132. return solutions
  133. return {}
  134. def solve_iteratively(
  135. batch: list[TypeVarId], cmap: dict[TypeVarId, list[Constraint]], free_vars: list[TypeVarId]
  136. ) -> dict[TypeVarId, Type | None]:
  137. """Solve constraints sequentially, updating constraint targets after each step.
  138. We solve for type variables that appear in `batch`. If a constraint target is not constant
  139. (i.e. constraint looks like T :> F[S, ...]), we substitute solutions found so far in
  140. the target F[S, ...]. This way we can gradually solve for all variables in the batch taking
  141. one solvable variable at a time (i.e. such a variable that has at least one constant bound).
  142. Importantly, variables in free_vars are considered constants, so for example if we have just
  143. one initial constraint T <: List[S], we will have two SCCs {T} and {S}, then we first
  144. designate S as free, and therefore T = List[S] is a valid solution for T.
  145. """
  146. solutions = {}
  147. relevant_constraints = []
  148. for tv in batch:
  149. relevant_constraints.extend(cmap.get(tv, []))
  150. lowers, uppers = transitive_closure(batch, relevant_constraints)
  151. s_batch = set(batch)
  152. not_allowed_vars = [v for v in batch if v not in free_vars]
  153. while s_batch:
  154. for tv in s_batch:
  155. if any(not get_vars(l, not_allowed_vars) for l in lowers[tv]) or any(
  156. not get_vars(u, not_allowed_vars) for u in uppers[tv]
  157. ):
  158. solvable_tv = tv
  159. break
  160. else:
  161. break
  162. # Solve each solvable type variable separately.
  163. s_batch.remove(solvable_tv)
  164. result = solve_one(lowers[solvable_tv], uppers[solvable_tv], not_allowed_vars)
  165. solutions[solvable_tv] = result
  166. if result is None:
  167. # TODO: support backtracking lower/upper bound choices
  168. # (will require switching this function from iterative to recursive).
  169. continue
  170. # Update the (transitive) constraints if there is a solution.
  171. subs = {solvable_tv: result}
  172. lowers = {tv: {expand_type(l, subs) for l in lowers[tv]} for tv in lowers}
  173. uppers = {tv: {expand_type(u, subs) for u in uppers[tv]} for tv in uppers}
  174. for v in cmap:
  175. for c in cmap[v]:
  176. c.target = expand_type(c.target, subs)
  177. return solutions
  178. def solve_one(
  179. lowers: Iterable[Type], uppers: Iterable[Type], not_allowed_vars: list[TypeVarId]
  180. ) -> Type | None:
  181. """Solve constraints by finding by using meets of upper bounds, and joins of lower bounds."""
  182. bottom: Type | None = None
  183. top: Type | None = None
  184. candidate: Type | None = None
  185. # Process each bound separately, and calculate the lower and upper
  186. # bounds based on constraints. Note that we assume that the constraint
  187. # targets do not have constraint references.
  188. for target in lowers:
  189. # There may be multiple steps needed to solve all vars within a
  190. # (linear) SCC. We ignore targets pointing to not yet solved vars.
  191. if get_vars(target, not_allowed_vars):
  192. continue
  193. if bottom is None:
  194. bottom = target
  195. else:
  196. if type_state.infer_unions:
  197. # This deviates from the general mypy semantics because
  198. # recursive types are union-heavy in 95% of cases.
  199. bottom = UnionType.make_union([bottom, target])
  200. else:
  201. bottom = join_types(bottom, target)
  202. for target in uppers:
  203. # Same as above.
  204. if get_vars(target, not_allowed_vars):
  205. continue
  206. if top is None:
  207. top = target
  208. else:
  209. top = meet_types(top, target)
  210. p_top = get_proper_type(top)
  211. p_bottom = get_proper_type(bottom)
  212. if isinstance(p_top, AnyType) or isinstance(p_bottom, AnyType):
  213. source_any = top if isinstance(p_top, AnyType) else bottom
  214. assert isinstance(source_any, ProperType) and isinstance(source_any, AnyType)
  215. return AnyType(TypeOfAny.from_another_any, source_any=source_any)
  216. elif bottom is None:
  217. if top:
  218. candidate = top
  219. else:
  220. # No constraints for type variable
  221. return None
  222. elif top is None:
  223. candidate = bottom
  224. elif is_subtype(bottom, top):
  225. candidate = bottom
  226. else:
  227. candidate = None
  228. return candidate
  229. def normalize_constraints(
  230. constraints: list[Constraint], vars: list[TypeVarId]
  231. ) -> list[Constraint]:
  232. """Normalize list of constraints (to simplify life for the non-linear solver).
  233. This includes two things currently:
  234. * Complement T :> S by S <: T
  235. * Remove strict duplicates
  236. """
  237. res = constraints.copy()
  238. for c in constraints:
  239. if isinstance(c.target, TypeVarType):
  240. res.append(Constraint(c.target, neg_op(c.op), c.origin_type_var))
  241. return [c for c in remove_dups(constraints) if c.type_var in vars]
  242. def propagate_constraints_for(
  243. var: TypeVarId, direction: int, cmap: dict[TypeVarId, list[Constraint]]
  244. ) -> list[Constraint]:
  245. """Propagate via linear constraints to get additional constraints for `var`.
  246. For example if we have constraints:
  247. [T <: int, S <: T, S :> str]
  248. we can add two more
  249. [S <: int, T :> str]
  250. """
  251. extra_constraints = []
  252. seen = set()
  253. front = [var]
  254. if cmap[var]:
  255. var_def = cmap[var][0].origin_type_var
  256. else:
  257. return []
  258. while front:
  259. tv = front.pop(0)
  260. for c in cmap[tv]:
  261. if (
  262. isinstance(c.target, TypeVarType)
  263. and c.target.id not in seen
  264. and c.target.id in cmap
  265. and c.op == direction
  266. ):
  267. front.append(c.target.id)
  268. seen.add(c.target.id)
  269. elif c.op == direction:
  270. new_c = Constraint(var_def, direction, c.target)
  271. if new_c not in cmap[var]:
  272. extra_constraints.append(new_c)
  273. return extra_constraints
  274. def transitive_closure(
  275. tvars: list[TypeVarId], constraints: list[Constraint]
  276. ) -> tuple[dict[TypeVarId, set[Type]], dict[TypeVarId, set[Type]]]:
  277. """Find transitive closure for given constraints on type variables.
  278. Transitive closure gives maximal set of lower/upper bounds for each type variable,
  279. such that we cannot deduce any further bounds by chaining other existing bounds.
  280. For example if we have initial constraints [T <: S, S <: U, U <: int], the transitive
  281. closure is given by:
  282. * {} <: T <: {S, U, int}
  283. * {T} <: S <: {U, int}
  284. * {T, S} <: U <: {int}
  285. """
  286. # TODO: merge propagate_constraints_for() into this function.
  287. # TODO: add secondary constraints here to make the algorithm complete.
  288. uppers: dict[TypeVarId, set[Type]] = {tv: set() for tv in tvars}
  289. lowers: dict[TypeVarId, set[Type]] = {tv: set() for tv in tvars}
  290. graph: set[tuple[TypeVarId, TypeVarId]] = set()
  291. # Prime the closure with the initial trivial values.
  292. for c in constraints:
  293. if isinstance(c.target, TypeVarType) and c.target.id in tvars:
  294. if c.op == SUBTYPE_OF:
  295. graph.add((c.type_var, c.target.id))
  296. else:
  297. graph.add((c.target.id, c.type_var))
  298. if c.op == SUBTYPE_OF:
  299. uppers[c.type_var].add(c.target)
  300. else:
  301. lowers[c.type_var].add(c.target)
  302. # At this stage we know that constant bounds have been propagated already, so we
  303. # only need to propagate linear constraints.
  304. for c in constraints:
  305. if isinstance(c.target, TypeVarType) and c.target.id in tvars:
  306. if c.op == SUBTYPE_OF:
  307. lower, upper = c.type_var, c.target.id
  308. else:
  309. lower, upper = c.target.id, c.type_var
  310. extras = {
  311. (l, u) for l in tvars for u in tvars if (l, lower) in graph and (upper, u) in graph
  312. }
  313. graph |= extras
  314. for u in tvars:
  315. if (upper, u) in graph:
  316. lowers[u] |= lowers[lower]
  317. for l in tvars:
  318. if (l, lower) in graph:
  319. uppers[l] |= uppers[upper]
  320. return lowers, uppers
  321. def compute_dependencies(
  322. cmap: dict[TypeVarId, list[Constraint]]
  323. ) -> dict[TypeVarId, list[TypeVarId]]:
  324. """Compute dependencies between type variables induced by constraints.
  325. If we have a constraint like T <: List[S], we say that T depends on S, since
  326. we will need to solve for S first before we can solve for T.
  327. """
  328. res = {}
  329. vars = list(cmap.keys())
  330. for tv in cmap:
  331. deps = set()
  332. for c in cmap[tv]:
  333. deps |= get_vars(c.target, vars)
  334. res[tv] = list(deps)
  335. return res
  336. def check_linear(scc: set[TypeVarId], cmap: dict[TypeVarId, list[Constraint]]) -> bool:
  337. """Check there are only linear constraints between type variables in SCC.
  338. Linear are constraints like T <: S (while T <: F[S] are non-linear).
  339. """
  340. for tv in scc:
  341. if any(
  342. get_vars(c.target, list(scc)) and not isinstance(c.target, TypeVarType)
  343. for c in cmap[tv]
  344. ):
  345. return False
  346. return True
  347. def get_vars(target: Type, vars: list[TypeVarId]) -> set[TypeVarId]:
  348. """Find type variables for which we are solving in a target type."""
  349. return {tv.id for tv in get_type_vars(target)} & set(vars)