| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395 |
- """Type inference constraint solving"""
- from __future__ import annotations
- from typing import Iterable
- from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, neg_op
- from mypy.expandtype import expand_type
- from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
- from mypy.join import join_types
- from mypy.meet import meet_types
- from mypy.subtypes import is_subtype
- from mypy.typeops import get_type_vars
- from mypy.types import (
- AnyType,
- ProperType,
- Type,
- TypeOfAny,
- TypeVarId,
- TypeVarType,
- UninhabitedType,
- UnionType,
- get_proper_type,
- remove_dups,
- )
- from mypy.typestate import type_state
def solve_constraints(
    vars: list[TypeVarId],
    constraints: list[Constraint],
    strict: bool = True,
    allow_polymorphic: bool = False,
) -> list[Type | None]:
    """Infer a type for every type variable in `vars` from `constraints`.

    The returned list is parallel to `vars`. An entry is None when the solver
    could not find a value for that variable.

    A variable that ends up without any solution entry gets a fallback:
    an ambiguous UninhabitedType if strict=True, AnyType otherwise.

    With allow_polymorphic=True the constraints are normalized first and then
    handed to the non-linear solver; otherwise each variable is solved
    independently from its own lower/upper bounds.
    """
    if not vars:
        return []

    if allow_polymorphic:
        # T :> S and S <: T mean the same thing but are stored differently;
        # normalization makes the list uniform with respect to that.
        constraints = normalize_constraints(constraints, vars)

    # Group the constraints by the variable they restrict.
    per_var: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in vars}
    for constraint in constraints:
        if constraint.type_var in vars:
            per_var[constraint.type_var].append(constraint)

    solved: dict[TypeVarId, Type | None]
    if allow_polymorphic:
        solved = solve_non_linear(vars, constraints, per_var)
    else:
        solved = {}
        for tv, bounds in per_var.items():
            if bounds:
                lower_targets = [b.target for b in bounds if b.op == SUPERTYPE_OF]
                upper_targets = [b.target for b in bounds if b.op == SUBTYPE_OF]
                solved[tv] = solve_one(lower_targets, upper_targets, [])

    result: list[Type | None] = []
    for tv in vars:
        if tv in solved:
            result.append(solved[tv])
            continue
        # Nothing known about this variable: fall back to the most specific
        # type (strict) or to Any (non-strict).
        if strict:
            fallback = UninhabitedType()
            fallback.ambiguous = True
            result.append(fallback)
        else:
            result.append(AnyType(TypeOfAny.special_form))
    return result
def solve_non_linear(
    vars: list[TypeVarId], constraints: list[Constraint], cmap: dict[TypeVarId, list[Constraint]]
) -> dict[TypeVarId, Type | None]:
    """Solve set of constraints that may include non-linear ones, like T <: List[S].

    The whole algorithm consists of five steps:
      * Propagate via linear constraints to get all possible constraints for each variable
      * Find dependencies between type variables, group them in SCCs, and sort topologically
      * Check all SCC are intrinsically linear, we can't solve (express) T <: List[T]
      * Variables in leaf SCCs that don't have constant bounds are free (choose one per SCC)
      * Solve constraints iteratively starting from leafs, updating targets after each step.

    Returns a mapping from type variable to its solution. Free variables are
    omitted from the mapping on purpose. An empty mapping is returned when some
    SCC is not linear, or when there is nothing to solve.

    The `constraints` list passed by the caller is not modified.
    """
    propagated: list[Constraint] = []
    for tvar in vars:
        propagated.extend(propagate_constraints_for(tvar, SUBTYPE_OF, cmap))
        propagated.extend(propagate_constraints_for(tvar, SUPERTYPE_OF, cmap))
    # Build a new list: `constraints += ...` would extend the caller's list in place.
    constraints = constraints + remove_dups(propagated)

    # Recompute constraint map after propagating.
    cmap = {tv: [] for tv in vars}
    for con in constraints:
        if con.type_var in vars:
            cmap[con.type_var].append(con)

    dmap = compute_dependencies(cmap)
    sccs = list(strongly_connected_components(set(vars), dmap))
    if not all(check_linear(scc, cmap) for scc in sccs):
        return {}

    raw_batches = list(topsort(prepare_sccs(sccs, dmap)))
    if not raw_batches:
        # No SCCs at all (e.g. empty `vars`): indexing the first batch would fail.
        return {}

    leafs = raw_batches[0]
    free_vars = []
    for scc in leafs:
        # If all constrain targets in this SCC are type variables within the
        # same SCC then the only meaningful solution we can express, is that
        # each variable is equal to a new free variable. For example if we
        # have T <: S, S <: U, we deduce: T = S = U = <free>.
        if all(
            isinstance(c.target, TypeVarType) and c.target.id in vars
            for tv in scc
            for c in cmap[tv]
        ):
            # For convenience with current type application machinery, we pick
            # one of the existing type variables in SCC (the one with the smallest
            # raw_id) and designate it as free instead of defining a new type
            # variable as a common solution.
            # TODO: be careful about upper bounds (or values) when introducing free vars.
            free_vars.append(min(scc, key=lambda x: x.raw_id))

    # Flatten the SCCs that are independent, we can solve them together,
    # since we don't need to update any targets in between.
    batches = [[tv for scc in batch for tv in scc] for batch in raw_batches]

    solutions: dict[TypeVarId, Type | None] = {}
    for flat_batch in batches:
        solutions.update(solve_iteratively(flat_batch, cmap, free_vars))

    # We remove the solutions like T = T for free variables. This will indicate
    # to the apply function, that they should not be touched.
    # TODO: return list of free type variables explicitly, this logic is fragile
    # (but if we do, we need to be careful everything works in incremental modes).
    for tv in free_vars:
        solutions.pop(tv, None)
    return solutions
def solve_iteratively(
    batch: list[TypeVarId], cmap: dict[TypeVarId, list[Constraint]], free_vars: list[TypeVarId]
) -> dict[TypeVarId, Type | None]:
    """Solve constraints sequentially, updating constraint targets after each step.

    We solve for type variables that appear in `batch`. If a constraint target is not constant
    (i.e. constraint looks like T :> F[S, ...]), we substitute solutions found so far in
    the target F[S, ...]. This way we can gradually solve for all variables in the batch taking
    one solvable variable at a time (i.e. such a variable that has at least one constant bound).

    Importantly, variables in free_vars are considered constants, so for example if we have just
    one initial constraint T <: List[S], we will have two SCCs {T} and {S}, then we first
    designate S as free, and therefore T = List[S] is a valid solution for T.

    Returns a mapping with an entry for every variable that became solvable (the value
    is None if solving it failed). Variables that never get a usable bound are absent.

    Side effect: targets of the constraints stored in `cmap` are rewritten in place with
    each solution found.
    """
    solutions: dict[TypeVarId, Type | None] = {}
    relevant_constraints: list[Constraint] = []
    for tv in batch:
        relevant_constraints.extend(cmap.get(tv, []))
    lowers, uppers = transitive_closure(batch, relevant_constraints)
    s_batch = set(batch)
    not_allowed_vars = [v for v in batch if v not in free_vars]
    while s_batch:
        # Scan candidates in raw_id order rather than in set iteration order, so that
        # the variable solved next does not depend on hashing.
        for tv in sorted(s_batch, key=lambda x: x.raw_id):
            if any(not get_vars(l, not_allowed_vars) for l in lowers[tv]) or any(
                not get_vars(u, not_allowed_vars) for u in uppers[tv]
            ):
                solvable_tv = tv
                break
        else:
            # No remaining variable has a bound free of unsolved variables.
            break
        # Solve each solvable type variable separately.
        s_batch.remove(solvable_tv)
        result = solve_one(lowers[solvable_tv], uppers[solvable_tv], not_allowed_vars)
        solutions[solvable_tv] = result
        if result is None:
            # TODO: support backtracking lower/upper bound choices
            # (will require switching this function from iterative to recursive).
            continue
        # Update the (transitive) constraints if there is a solution.
        subs = {solvable_tv: result}
        lowers = {tv: {expand_type(l, subs) for l in lowers[tv]} for tv in lowers}
        uppers = {tv: {expand_type(u, subs) for u in uppers[tv]} for tv in uppers}
        for v in cmap:
            for c in cmap[v]:
                c.target = expand_type(c.target, subs)
    return solutions
def solve_one(
    lowers: Iterable[Type], uppers: Iterable[Type], not_allowed_vars: list[TypeVarId]
) -> Type | None:
    """Pick a single type lying between the given lower and upper bounds.

    Lower bounds are combined with join (or into a union when
    type_state.infer_unions is set); upper bounds are combined with meet.
    Bounds that still mention any of `not_allowed_vars` are skipped.

    If either combined bound is Any, an AnyType derived from it is returned.
    Otherwise the combined lower bound is preferred; the combined upper bound
    is used only when there is no lower bound. None is returned when no bound
    is usable, or when the lower bound is not a subtype of the upper bound.
    """
    # Fold the usable lower bounds into one type.
    joined: Type | None = None
    for lower in lowers:
        # Solving a (linear) SCC may take several steps; bounds that point
        # to variables not solved yet are ignored for now.
        if get_vars(lower, not_allowed_vars):
            continue
        if joined is None:
            joined = lower
        elif type_state.infer_unions:
            # Deliberate deviation from the usual join semantics: recursive
            # types are union-heavy in the vast majority of cases.
            joined = UnionType.make_union([joined, lower])
        else:
            joined = join_types(joined, lower)

    # Fold the usable upper bounds into one type.
    met: Type | None = None
    for upper in uppers:
        if get_vars(upper, not_allowed_vars):
            continue
        met = upper if met is None else meet_types(met, upper)

    proper_met = get_proper_type(met)
    proper_joined = get_proper_type(joined)
    met_is_any = isinstance(proper_met, AnyType)
    if met_is_any or isinstance(proper_joined, AnyType):
        any_source = met if met_is_any else joined
        assert isinstance(any_source, ProperType) and isinstance(any_source, AnyType)
        return AnyType(TypeOfAny.from_another_any, source_any=any_source)

    if joined is None:
        # Only an upper bound (if anything) is available.
        return met if met else None
    if met is None:
        return joined
    # Both bounds exist: the lower bound is the answer only if it fits.
    return joined if is_subtype(joined, met) else None
def normalize_constraints(
    constraints: list[Constraint], vars: list[TypeVarId]
) -> list[Constraint]:
    """Normalize list of constraints (to simplify life for the non-linear solver).

    This includes three things currently:
      * Complement T :> S by S <: T
      * Remove strict duplicates
      * Drop constraints whose type variable is not in `vars`

    The input list is not modified; a new list is returned.
    """
    res = constraints.copy()
    for c in constraints:
        if isinstance(c.target, TypeVarType):
            # The target is itself a type variable: add the mirrored constraint
            # that restricts the target in terms of the original variable.
            res.append(Constraint(c.target, neg_op(c.op), c.origin_type_var))
    # De-duplicate the complemented list; using `constraints` here would throw
    # away every mirrored constraint added above.
    return [c for c in remove_dups(res) if c.type_var in vars]
def propagate_constraints_for(
    var: TypeVarId, direction: int, cmap: dict[TypeVarId, list[Constraint]]
) -> list[Constraint]:
    """Derive extra constraints for `var` by following linear constraints.

    Starting at `var`, walk breadth-first through constraints of the given
    `direction` whose target is another type variable present in `cmap`.
    Every other constraint of that direction met on the way is re-stated
    for `var`, unless `var` already has an equal constraint.

    For example, from
        [T <: int, S <: T, S :> str]
    we can additionally derive
        [S <: int, T :> str]

    Returns the list of newly derived constraints (empty if `var` has no
    constraints at all).
    """
    own = cmap[var]
    if not own:
        return []
    origin = own[0].origin_type_var

    derived: list[Constraint] = []
    visited: set[TypeVarId] = set()
    # FIFO work list; `cursor` marks the next entry to process.
    queue = [var]
    cursor = 0
    while cursor < len(queue):
        current = queue[cursor]
        cursor += 1
        for con in cmap[current]:
            if con.op != direction:
                continue
            target = con.target
            if (
                isinstance(target, TypeVarType)
                and target.id not in visited
                and target.id in cmap
            ):
                # Linear link to a variable not visited yet: keep walking.
                queue.append(target.id)
                visited.add(target.id)
                continue
            candidate = Constraint(origin, direction, target)
            if candidate not in own:
                derived.append(candidate)
    return derived
def transitive_closure(
    tvars: list[TypeVarId], constraints: list[Constraint]
) -> tuple[dict[TypeVarId, set[Type]], dict[TypeVarId, set[Type]]]:
    """Compute lower and upper bound sets for each variable in `tvars`.

    Bounds are chained along linear constraints between the variables, so
    that a bound of one variable is also recorded for the variables linked
    to it. For example, for [T <: S, S <: U, U <: int] the intended closure is:
      * {} <: T <: {S, U, int}
      * {T} <: S <: {U, int}
      * {T, S} <: U <: {int}

    Returns a pair (lowers, uppers), each mapping a variable to a set of types.
    """
    # TODO: merge propagate_constraints_for() into this function.
    # TODO: add secondary constraints here to make the algorithm complete.
    lowers: dict[TypeVarId, set[Type]] = {tv: set() for tv in tvars}
    uppers: dict[TypeVarId, set[Type]] = {tv: set() for tv in tvars}
    edges: set[tuple[TypeVarId, TypeVarId]] = set()

    def linear_edge(con: Constraint) -> tuple[TypeVarId, TypeVarId] | None:
        """Return (lower, upper) if `con` links two variables of `tvars`, else None."""
        target = con.target
        if not (isinstance(target, TypeVarType) and target.id in tvars):
            return None
        if con.op == SUBTYPE_OF:
            return con.type_var, target.id
        return target.id, con.type_var

    # First pass: record the direct edges and the direct bounds.
    for con in constraints:
        edge = linear_edge(con)
        if edge is not None:
            edges.add(edge)
        direct = uppers if con.op == SUBTYPE_OF else lowers
        direct[con.type_var].add(con.target)

    # Second pass: constant bounds are expected to be propagated already, so
    # only the linear constraints need to be chained here.
    for con in constraints:
        edge = linear_edge(con)
        if edge is None:
            continue
        below, above = edge
        edges |= {
            (lo, hi) for lo in tvars for hi in tvars if (lo, below) in edges and (above, hi) in edges
        }
        for hi in tvars:
            if (above, hi) in edges:
                lowers[hi] |= lowers[below]
        for lo in tvars:
            if (lo, below) in edges:
                uppers[lo] |= uppers[above]
    return lowers, uppers
def compute_dependencies(
    cmap: dict[TypeVarId, list[Constraint]]
) -> dict[TypeVarId, list[TypeVarId]]:
    """Map each type variable to the variables its constraint targets mention.

    A constraint such as T <: List[S] makes T depend on S: S has to be solved
    before T can be. Only variables that are keys of `cmap` are counted.
    """
    known = list(cmap)
    dependencies = {}
    for tv, cons in cmap.items():
        mentioned = set()
        for con in cons:
            mentioned |= get_vars(con.target, known)
        dependencies[tv] = list(mentioned)
    return dependencies
def check_linear(scc: set[TypeVarId], cmap: dict[TypeVarId, list[Constraint]]) -> bool:
    """Tell whether all constraints inside the SCC are linear.

    A constraint is linear when it relates a variable directly to another
    variable (T <: S). A target that mentions a variable of the SCC without
    being a bare type variable (T <: F[S]) makes the SCC non-linear.
    """
    members = list(scc)
    for tv in scc:
        for con in cmap[tv]:
            if get_vars(con.target, members) and not isinstance(con.target, TypeVarType):
                return False
    return True
def get_vars(target: Type, vars: list[TypeVarId]) -> set[TypeVarId]:
    """Return the ids from `vars` that occur as type variables in `target`."""
    mentioned = {tv.id for tv in get_type_vars(target)}
    wanted = set(vars)
    return mentioned & wanted
|