diff options
| author | bors <bors@rust-lang.org> | 2025-09-08 19:39:36 +0000 | 
|---|---|---|
| committer | bors <bors@rust-lang.org> | 2025-09-08 19:39:36 +0000 | 
| commit | 9c27f27ea3bab79a2fec827ef3ae0009959d60f4 (patch) | |
| tree | e1f00e3dec434fe16f9bcadeafd1a7cdf6c8a6cf /compiler/rustc_infer/src | |
| parent | a78f9aa87fa828ad4a5c11f1e3b93e94d9352ad6 (diff) | |
| parent | b51a3a565a056235f3864e2cefdb9449f6b0dcb1 (diff) | |
| download | rust-9c27f27ea3bab79a2fec827ef3ae0009959d60f4.tar.gz rust-9c27f27ea3bab79a2fec827ef3ae0009959d60f4.zip | |
Auto merge of #140375 - lcnr:subrelations-infcx, r=BoxyUwU
eagerly compute `sub_unification_table` again Previously called `sub_relations`. We still only using them for diagnostics right now. This mostly reverts rust-lang/rust#119989. Necessary for type inference guidance due to not-yet defined opaque types, cc https://github.com/rust-lang/trait-system-refactor-initiative/issues/182. We could use them for cycle detection in generalization and it seems desirable to do so in the future. However, this is unsound with the old trait solver as its cache does not track these `sub_unification_table` in any way. We now properly track the `sub_unification_table` when canonicalizing so using them in the new solver is totally sound and the performance impact is far more manageable than I thought back in rust-lang/rust#119989. r? `@compiler-errors`
Diffstat (limited to 'compiler/rustc_infer/src')
| -rw-r--r-- | compiler/rustc_infer/src/infer/canonical/canonicalizer.rs | 34 | ||||
| -rw-r--r-- | compiler/rustc_infer/src/infer/canonical/mod.rs | 56 | ||||
| -rw-r--r-- | compiler/rustc_infer/src/infer/canonical/query_response.rs | 77 | ||||
| -rw-r--r-- | compiler/rustc_infer/src/infer/context.rs | 8 | ||||
| -rw-r--r-- | compiler/rustc_infer/src/infer/mod.rs | 9 | ||||
| -rw-r--r-- | compiler/rustc_infer/src/infer/relate/generalize.rs | 4 | ||||
| -rw-r--r-- | compiler/rustc_infer/src/infer/snapshot/undo_log.rs | 4 | ||||
| -rw-r--r-- | compiler/rustc_infer/src/infer/type_variable.rs | 116 | 
8 files changed, 222 insertions, 86 deletions
| diff --git a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs index 3ad14dc79d5..3c5e4a91c98 100644 --- a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs +++ b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs @@ -6,6 +6,7 @@ //! [c]: https://rust-lang.github.io/chalk/book/canonical_queries/canonicalization.html use rustc_data_structures::fx::FxHashMap; +use rustc_data_structures::sso::SsoHashMap; use rustc_index::Idx; use rustc_middle::bug; use rustc_middle::ty::{ @@ -17,7 +18,7 @@ use tracing::debug; use crate::infer::InferCtxt; use crate::infer::canonical::{ - Canonical, CanonicalQueryInput, CanonicalTyVarKind, CanonicalVarKind, OriginalQueryValues, + Canonical, CanonicalQueryInput, CanonicalVarKind, OriginalQueryValues, }; impl<'tcx> InferCtxt<'tcx> { @@ -293,6 +294,13 @@ struct Canonicalizer<'cx, 'tcx> { // Note that indices is only used once `var_values` is big enough to be // heap-allocated. indices: FxHashMap<GenericArg<'tcx>, BoundVar>, + /// Maps each `sub_unification_table_root_var` to the index of the first + /// variable which used it. + /// + /// This means in case two type variables have the same sub relations root, + /// we set the `sub_root` of the second variable to the position of the first. + /// Otherwise the `sub_root` of each type variable is just its own position. + sub_root_lookup_table: SsoHashMap<ty::TyVid, usize>, canonicalize_mode: &'cx dyn CanonicalizeMode, needs_canonical_flags: TypeFlags, @@ -361,10 +369,8 @@ impl<'cx, 'tcx> TypeFolder<TyCtxt<'tcx>> for Canonicalizer<'cx, 'tcx> { // FIXME: perf problem described in #55921. ui = ty::UniverseIndex::ROOT; } - self.canonicalize_ty_var( - CanonicalVarKind::Ty(CanonicalTyVarKind::General(ui)), - t, - ) + let sub_root = self.get_or_insert_sub_root(vid); + self.canonicalize_ty_var(CanonicalVarKind::Ty { ui, sub_root }, t) } } } @@ -374,7 +380,7 @@ impl<'cx, 'tcx> TypeFolder<TyCtxt<'tcx>> for Canonicalizer<'cx, 'tcx> { if nt != t { return self.fold_ty(nt); } else { - self.canonicalize_ty_var(CanonicalVarKind::Ty(CanonicalTyVarKind::Int), t) + self.canonicalize_ty_var(CanonicalVarKind::Int, t) } } ty::Infer(ty::FloatVar(vid)) => { @@ -382,7 +388,7 @@ impl<'cx, 'tcx> TypeFolder<TyCtxt<'tcx>> for Canonicalizer<'cx, 'tcx> { if nt != t { return self.fold_ty(nt); } else { - self.canonicalize_ty_var(CanonicalVarKind::Ty(CanonicalTyVarKind::Float), t) + self.canonicalize_ty_var(CanonicalVarKind::Float, t) } } @@ -562,6 +568,7 @@ impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> { variables: SmallVec::from_slice(base.variables), query_state, indices: FxHashMap::default(), + sub_root_lookup_table: Default::default(), binder_index: ty::INNERMOST, }; if canonicalizer.query_state.var_values.spilled() { @@ -660,6 +667,13 @@ impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> { } } + fn get_or_insert_sub_root(&mut self, vid: ty::TyVid) -> ty::BoundVar { + let root_vid = self.infcx.unwrap().sub_unification_table_root_var(vid); + let idx = + *self.sub_root_lookup_table.entry(root_vid).or_insert_with(|| self.variables.len()); + ty::BoundVar::from(idx) + } + /// Replaces the universe indexes used in `var_values` with their index in /// `query_state.universe_map`. This minimizes the maximum universe used in /// the canonicalized value. @@ -679,11 +693,11 @@ impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> { self.variables .iter() .map(|&kind| match kind { - CanonicalVarKind::Ty(CanonicalTyVarKind::Int | CanonicalTyVarKind::Float) => { + CanonicalVarKind::Int | CanonicalVarKind::Float => { return kind; } - CanonicalVarKind::Ty(CanonicalTyVarKind::General(u)) => { - CanonicalVarKind::Ty(CanonicalTyVarKind::General(reverse_universe_map[&u])) + CanonicalVarKind::Ty { ui, sub_root } => { + CanonicalVarKind::Ty { ui: reverse_universe_map[&ui], sub_root } } CanonicalVarKind::Region(u) => CanonicalVarKind::Region(reverse_universe_map[&u]), CanonicalVarKind::Const(u) => CanonicalVarKind::Const(reverse_universe_map[&u]), diff --git a/compiler/rustc_infer/src/infer/canonical/mod.rs b/compiler/rustc_infer/src/infer/canonical/mod.rs index 79a2aa54ef8..f99f228e19d 100644 --- a/compiler/rustc_infer/src/infer/canonical/mod.rs +++ b/compiler/rustc_infer/src/infer/canonical/mod.rs @@ -24,7 +24,7 @@ pub use instantiate::CanonicalExt; use rustc_index::IndexVec; pub use rustc_middle::infer::canonical::*; -use rustc_middle::ty::{self, GenericArg, List, Ty, TyCtxt, TypeFoldable}; +use rustc_middle::ty::{self, GenericArg, Ty, TyCtxt, TypeFoldable}; use rustc_span::Span; use crate::infer::{InferCtxt, RegionVariableOrigin}; @@ -67,30 +67,12 @@ impl<'tcx> InferCtxt<'tcx> { .chain((1..=canonical.max_universe.as_u32()).map(|_| self.create_next_universe())) .collect(); - let canonical_inference_vars = - self.instantiate_canonical_vars(span, canonical.variables, |ui| universes[ui]); - let result = canonical.instantiate(self.tcx, &canonical_inference_vars); - (result, canonical_inference_vars) - } - - /// Given the "infos" about the canonical variables from some - /// canonical, creates fresh variables with the same - /// characteristics (see `instantiate_canonical_var` for - /// details). You can then use `instantiate` to instantiate the - /// canonical variable with these inference variables. - fn instantiate_canonical_vars( - &self, - span: Span, - variables: &List<CanonicalVarKind<'tcx>>, - universe_map: impl Fn(ty::UniverseIndex) -> ty::UniverseIndex, - ) -> CanonicalVarValues<'tcx> { - CanonicalVarValues { - var_values: self.tcx.mk_args_from_iter( - variables - .iter() - .map(|kind| self.instantiate_canonical_var(span, kind, &universe_map)), - ), - } + let var_values = + CanonicalVarValues::instantiate(self.tcx, &canonical.variables, |var_values, info| { + self.instantiate_canonical_var(span, info, &var_values, |ui| universes[ui]) + }); + let result = canonical.instantiate(self.tcx, &var_values); + (result, var_values) } /// Given the "info" about a canonical variable, creates a fresh @@ -105,21 +87,27 @@ impl<'tcx> InferCtxt<'tcx> { &self, span: Span, kind: CanonicalVarKind<'tcx>, + previous_var_values: &[GenericArg<'tcx>], universe_map: impl Fn(ty::UniverseIndex) -> ty::UniverseIndex, ) -> GenericArg<'tcx> { match kind { - CanonicalVarKind::Ty(ty_kind) => { - let ty = match ty_kind { - CanonicalTyVarKind::General(ui) => { - self.next_ty_var_in_universe(span, universe_map(ui)) + CanonicalVarKind::Ty { ui, sub_root } => { + let vid = self.next_ty_vid_in_universe(span, universe_map(ui)); + // If this inference variable is related to an earlier variable + // via subtyping, we need to add that info to the inference context. + if let Some(prev) = previous_var_values.get(sub_root.as_usize()) { + if let &ty::Infer(ty::TyVar(sub_root)) = prev.expect_ty().kind() { + self.sub_unify_ty_vids_raw(vid, sub_root); + } else { + unreachable!() } + } + Ty::new_var(self.tcx, vid).into() + } - CanonicalTyVarKind::Int => self.next_int_var(), + CanonicalVarKind::Int => self.next_int_var().into(), - CanonicalTyVarKind::Float => self.next_float_var(), - }; - ty.into() - } + CanonicalVarKind::Float => self.next_float_var().into(), CanonicalVarKind::PlaceholderTy(ty::PlaceholderType { universe, bound }) => { let universe_mapped = universe_map(universe); diff --git a/compiler/rustc_infer/src/infer/canonical/query_response.rs b/compiler/rustc_infer/src/infer/canonical/query_response.rs index 09578598114..5d1b4be9e57 100644 --- a/compiler/rustc_infer/src/infer/canonical/query_response.rs +++ b/compiler/rustc_infer/src/infer/canonical/query_response.rs @@ -13,6 +13,7 @@ use std::iter; use rustc_index::{Idx, IndexVec}; use rustc_middle::arena::ArenaAllocatable; use rustc_middle::bug; +use rustc_middle::infer::canonical::CanonicalVarKind; use rustc_middle::ty::{self, BoundVar, GenericArg, GenericArgKind, Ty, TyCtxt, TypeFoldable}; use tracing::{debug, instrument}; @@ -413,26 +414,27 @@ impl<'tcx> InferCtxt<'tcx> { let mut opt_values: IndexVec<BoundVar, Option<GenericArg<'tcx>>> = IndexVec::from_elem_n(None, query_response.variables.len()); - // In terms of our example above, we are iterating over pairs like: - // [(?A, Vec<?0>), ('static, '?1), (?B, ?0)] for (original_value, result_value) in iter::zip(&original_values.var_values, result_values) { match result_value.kind() { GenericArgKind::Type(result_value) => { - // e.g., here `result_value` might be `?0` in the example above... - if let ty::Bound(debruijn, b) = *result_value.kind() { - // ...in which case we would set `canonical_vars[0]` to `Some(?U)`. - + // We disable the instantiation guess for inference variables + // and only use it for placeholders. We need to handle the + // `sub_root` of type inference variables which would make this + // more involved. They are also a lot rarer than region variables. + if let ty::Bound(debruijn, b) = *result_value.kind() + && !matches!( + query_response.variables[b.var.as_usize()], + CanonicalVarKind::Ty { .. } + ) + { // We only allow a `ty::INNERMOST` index in generic parameters. assert_eq!(debruijn, ty::INNERMOST); opt_values[b.var] = Some(*original_value); } } GenericArgKind::Lifetime(result_value) => { - // e.g., here `result_value` might be `'?1` in the example above... if let ty::ReBound(debruijn, b) = result_value.kind() { - // ... in which case we would set `canonical_vars[0]` to `Some('static)`. - // We only allow a `ty::INNERMOST` index in generic parameters. assert_eq!(debruijn, ty::INNERMOST); opt_values[b.var] = Some(*original_value); @@ -440,8 +442,6 @@ impl<'tcx> InferCtxt<'tcx> { } GenericArgKind::Const(result_value) => { if let ty::ConstKind::Bound(debruijn, b) = result_value.kind() { - // ...in which case we would set `canonical_vars[0]` to `Some(const X)`. - // We only allow a `ty::INNERMOST` index in generic parameters. assert_eq!(debruijn, ty::INNERMOST); opt_values[b.var] = Some(*original_value); @@ -453,39 +453,36 @@ impl<'tcx> InferCtxt<'tcx> { // Create result arguments: if we found a value for a // given variable in the loop above, use that. Otherwise, use // a fresh inference variable. - let result_args = CanonicalVarValues { - var_values: self.tcx.mk_args_from_iter( - query_response.variables.iter().enumerate().map(|(index, var_kind)| { - if var_kind.universe() != ty::UniverseIndex::ROOT { - // A variable from inside a binder of the query. While ideally these shouldn't - // exist at all, we have to deal with them for now. - self.instantiate_canonical_var(cause.span, var_kind, |u| { - universe_map[u.as_usize()] - }) - } else if var_kind.is_existential() { - match opt_values[BoundVar::new(index)] { - Some(k) => k, - None => self.instantiate_canonical_var(cause.span, var_kind, |u| { - universe_map[u.as_usize()] - }), - } - } else { - // For placeholders which were already part of the input, we simply map this - // universal bound variable back the placeholder of the input. - opt_values[BoundVar::new(index)].expect( - "expected placeholder to be unified with itself during response", - ) - } - }), - ), - }; + let tcx = self.tcx; + let variables = query_response.variables; + let var_values = CanonicalVarValues::instantiate(tcx, variables, |var_values, kind| { + if kind.universe() != ty::UniverseIndex::ROOT { + // A variable from inside a binder of the query. While ideally these shouldn't + // exist at all, we have to deal with them for now. + self.instantiate_canonical_var(cause.span, kind, &var_values, |u| { + universe_map[u.as_usize()] + }) + } else if kind.is_existential() { + match opt_values[BoundVar::new(var_values.len())] { + Some(k) => k, + None => self.instantiate_canonical_var(cause.span, kind, &var_values, |u| { + universe_map[u.as_usize()] + }), + } + } else { + // For placeholders which were already part of the input, we simply map this + // universal bound variable back the placeholder of the input. + opt_values[BoundVar::new(var_values.len())] + .expect("expected placeholder to be unified with itself during response") + } + }); let mut obligations = PredicateObligations::new(); // Carry all newly resolved opaque types to the caller's scope for &(a, b) in &query_response.value.opaque_types { - let a = instantiate_value(self.tcx, &result_args, a); - let b = instantiate_value(self.tcx, &result_args, b); + let a = instantiate_value(self.tcx, &var_values, a); + let b = instantiate_value(self.tcx, &var_values, b); debug!(?a, ?b, "constrain opaque type"); // We use equate here instead of, for example, just registering the // opaque type's hidden value directly, because the hidden type may have been an inference @@ -502,7 +499,7 @@ impl<'tcx> InferCtxt<'tcx> { ); } - Ok(InferOk { value: result_args, obligations }) + Ok(InferOk { value: var_values, obligations }) } /// Given a "guess" at the values for the canonical variables in diff --git a/compiler/rustc_infer/src/infer/context.rs b/compiler/rustc_infer/src/infer/context.rs index 8265fccabc9..14cc590720a 100644 --- a/compiler/rustc_infer/src/infer/context.rs +++ b/compiler/rustc_infer/src/infer/context.rs @@ -59,6 +59,10 @@ impl<'tcx> rustc_type_ir::InferCtxtLike for InferCtxt<'tcx> { self.root_var(var) } + fn sub_unification_table_root_var(&self, var: ty::TyVid) -> ty::TyVid { + self.sub_unification_table_root_var(var) + } + fn root_const_var(&self, var: ty::ConstVid) -> ty::ConstVid { self.root_const_var(var) } @@ -179,6 +183,10 @@ impl<'tcx> rustc_type_ir::InferCtxtLike for InferCtxt<'tcx> { self.inner.borrow_mut().type_variables().equate(a, b); } + fn sub_unify_ty_vids_raw(&self, a: ty::TyVid, b: ty::TyVid) { + self.sub_unify_ty_vids_raw(a, b); + } + fn equate_int_vids_raw(&self, a: ty::IntVid, b: ty::IntVid) { self.inner.borrow_mut().int_unification_table().union(a, b); } diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs index d105d24bed7..9d3886aff1c 100644 --- a/compiler/rustc_infer/src/infer/mod.rs +++ b/compiler/rustc_infer/src/infer/mod.rs @@ -764,6 +764,7 @@ impl<'tcx> InferCtxt<'tcx> { let r_b = self.shallow_resolve(predicate.skip_binder().b); match (r_a.kind(), r_b.kind()) { (&ty::Infer(ty::TyVar(a_vid)), &ty::Infer(ty::TyVar(b_vid))) => { + self.sub_unify_ty_vids_raw(a_vid, b_vid); return Err((a_vid, b_vid)); } _ => {} @@ -1128,6 +1129,14 @@ impl<'tcx> InferCtxt<'tcx> { self.inner.borrow_mut().type_variables().root_var(var) } + pub fn sub_unify_ty_vids_raw(&self, a: ty::TyVid, b: ty::TyVid) { + self.inner.borrow_mut().type_variables().sub_unify(a, b); + } + + pub fn sub_unification_table_root_var(&self, var: ty::TyVid) -> ty::TyVid { + self.inner.borrow_mut().type_variables().sub_unification_table_root_var(var) + } + pub fn root_const_var(&self, var: ty::ConstVid) -> ty::ConstVid { self.inner.borrow_mut().const_unification_table().find(var).vid } diff --git a/compiler/rustc_infer/src/infer/relate/generalize.rs b/compiler/rustc_infer/src/infer/relate/generalize.rs index a75fd8dfa18..cc41957c110 100644 --- a/compiler/rustc_infer/src/infer/relate/generalize.rs +++ b/compiler/rustc_infer/src/infer/relate/generalize.rs @@ -558,6 +558,10 @@ impl<'tcx> TypeRelation<TyCtxt<'tcx>> for Generalizer<'_, 'tcx> { let origin = inner.type_variables().var_origin(vid); let new_var_id = inner.type_variables().new_var(self.for_universe, origin); + // Record that `vid` and `new_var_id` have to be subtypes + // of each other. This is currently only used for diagnostics. + // To see why, see the docs in the `type_variables` module. + inner.type_variables().sub_unify(vid, new_var_id); // If we're in the new solver and create a new inference // variable inside of an alias we eagerly constrain that // inference variable to prevent unexpected ambiguity errors. diff --git a/compiler/rustc_infer/src/infer/snapshot/undo_log.rs b/compiler/rustc_infer/src/infer/snapshot/undo_log.rs index fcc0ab3af41..22c815fb87c 100644 --- a/compiler/rustc_infer/src/infer/snapshot/undo_log.rs +++ b/compiler/rustc_infer/src/infer/snapshot/undo_log.rs @@ -20,7 +20,7 @@ pub struct Snapshot<'tcx> { pub(crate) enum UndoLog<'tcx> { DuplicateOpaqueType, OpaqueTypes(OpaqueTypeKey<'tcx>, Option<OpaqueHiddenType<'tcx>>), - TypeVariables(sv::UndoLog<ut::Delegate<type_variable::TyVidEqKey<'tcx>>>), + TypeVariables(type_variable::UndoLog<'tcx>), ConstUnificationTable(sv::UndoLog<ut::Delegate<ConstVidKey<'tcx>>>), IntUnificationTable(sv::UndoLog<ut::Delegate<ty::IntVid>>), FloatUnificationTable(sv::UndoLog<ut::Delegate<ty::FloatVid>>), @@ -49,6 +49,8 @@ impl_from! { RegionConstraintCollector(region_constraints::UndoLog<'tcx>), TypeVariables(sv::UndoLog<ut::Delegate<type_variable::TyVidEqKey<'tcx>>>), + TypeVariables(sv::UndoLog<ut::Delegate<type_variable::TyVidSubKey>>), + TypeVariables(type_variable::UndoLog<'tcx>), IntUnificationTable(sv::UndoLog<ut::Delegate<ty::IntVid>>), FloatUnificationTable(sv::UndoLog<ut::Delegate<ty::FloatVid>>), diff --git a/compiler/rustc_infer/src/infer/type_variable.rs b/compiler/rustc_infer/src/infer/type_variable.rs index 6f6791804d3..65f77fe8e25 100644 --- a/compiler/rustc_infer/src/infer/type_variable.rs +++ b/compiler/rustc_infer/src/infer/type_variable.rs @@ -13,12 +13,48 @@ use tracing::debug; use crate::infer::InferCtxtUndoLogs; +/// Represents a single undo-able action that affects a type inference variable. +#[derive(Clone)] +pub(crate) enum UndoLog<'tcx> { + EqRelation(sv::UndoLog<ut::Delegate<TyVidEqKey<'tcx>>>), + SubRelation(sv::UndoLog<ut::Delegate<TyVidSubKey>>), +} + +/// Convert from a specific kind of undo to the more general UndoLog +impl<'tcx> From<sv::UndoLog<ut::Delegate<TyVidEqKey<'tcx>>>> for UndoLog<'tcx> { + fn from(l: sv::UndoLog<ut::Delegate<TyVidEqKey<'tcx>>>) -> Self { + UndoLog::EqRelation(l) + } +} + +/// Convert from a specific kind of undo to the more general UndoLog +impl<'tcx> From<sv::UndoLog<ut::Delegate<TyVidSubKey>>> for UndoLog<'tcx> { + fn from(l: sv::UndoLog<ut::Delegate<TyVidSubKey>>) -> Self { + UndoLog::SubRelation(l) + } +} + impl<'tcx> Rollback<sv::UndoLog<ut::Delegate<TyVidEqKey<'tcx>>>> for TypeVariableStorage<'tcx> { fn reverse(&mut self, undo: sv::UndoLog<ut::Delegate<TyVidEqKey<'tcx>>>) { self.eq_relations.reverse(undo) } } +impl<'tcx> Rollback<sv::UndoLog<ut::Delegate<TyVidSubKey>>> for TypeVariableStorage<'tcx> { + fn reverse(&mut self, undo: sv::UndoLog<ut::Delegate<TyVidSubKey>>) { + self.sub_unification_table.reverse(undo) + } +} + +impl<'tcx> Rollback<UndoLog<'tcx>> for TypeVariableStorage<'tcx> { + fn reverse(&mut self, undo: UndoLog<'tcx>) { + match undo { + UndoLog::EqRelation(undo) => self.eq_relations.reverse(undo), + UndoLog::SubRelation(undo) => self.sub_unification_table.reverse(undo), + } + } +} + #[derive(Clone, Default)] pub(crate) struct TypeVariableStorage<'tcx> { /// The origins of each type variable. @@ -27,6 +63,25 @@ pub(crate) struct TypeVariableStorage<'tcx> { /// constraint `?X == ?Y`. This table also stores, for each key, /// the known value. eq_relations: ut::UnificationTableStorage<TyVidEqKey<'tcx>>, + /// Only used by `-Znext-solver` and for diagnostics. Tracks whether + /// type variables are related via subtyping at all, ignoring which of + /// the two is the subtype. + /// + /// When reporting ambiguity errors, we sometimes want to + /// treat all inference vars which are subtypes of each + /// others as if they are equal. For this case we compute + /// the transitive closure of our subtype obligations here. + /// + /// E.g. when encountering ambiguity errors, we want to suggest + /// specifying some method argument or to add a type annotation + /// to a local variable. Because subtyping cannot change the + /// shape of a type, it's fine if the cause of the ambiguity error + /// is only related to the suggested variable via subtyping. + /// + /// Even for something like `let x = returns_arg(); x.method();` the + /// type of `x` is only a supertype of the argument of `returns_arg`. We + /// still want to suggest specifying the type of the argument. + sub_unification_table: ut::UnificationTableStorage<TyVidSubKey>, } pub(crate) struct TypeVariableTable<'a, 'tcx> { @@ -102,13 +157,24 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> { self.storage.values[vid].origin } - /// Records that `a == b`, depending on `dir`. + /// Records that `a == b`. /// /// Precondition: neither `a` nor `b` are known. pub(crate) fn equate(&mut self, a: ty::TyVid, b: ty::TyVid) { debug_assert!(self.probe(a).is_unknown()); debug_assert!(self.probe(b).is_unknown()); self.eq_relations().union(a, b); + self.sub_unification_table().union(a, b); + } + + /// Records that `a` and `b` are related via subtyping. We don't track + /// which of the two is the subtype. + /// + /// Precondition: neither `a` nor `b` are known. + pub(crate) fn sub_unify(&mut self, a: ty::TyVid, b: ty::TyVid) { + debug_assert!(self.probe(a).is_unknown()); + debug_assert!(self.probe(b).is_unknown()); + self.sub_unification_table().union(a, b); } /// Instantiates `vid` with the type `ty`. @@ -142,6 +208,10 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> { origin: TypeVariableOrigin, ) -> ty::TyVid { let eq_key = self.eq_relations().new_key(TypeVariableValue::Unknown { universe }); + + let sub_key = self.sub_unification_table().new_key(()); + debug_assert_eq!(eq_key.vid, sub_key.vid); + let index = self.storage.values.push(TypeVariableData { origin }); debug_assert_eq!(eq_key.vid, index); @@ -164,6 +234,18 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> { self.eq_relations().find(vid).vid } + /// Returns the "root" variable of `vid` in the `sub_unification_table` + /// equivalence table. All type variables that have been are related via + /// equality or subtyping will yield the same root variable (per the + /// union-find algorithm), so `sub_unification_table_root_var(a) + /// == sub_unification_table_root_var(b)` implies that: + /// ```text + /// exists X. (a <: X || X <: a) && (b <: X || X <: b) + /// ``` + pub(crate) fn sub_unification_table_root_var(&mut self, vid: ty::TyVid) -> ty::TyVid { + self.sub_unification_table().find(vid).vid + } + /// Retrieves the type to which `vid` has been instantiated, if /// any. pub(crate) fn probe(&mut self, vid: ty::TyVid) -> TypeVariableValue<'tcx> { @@ -181,6 +263,11 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> { self.storage.eq_relations.with_log(self.undo_log) } + #[inline] + fn sub_unification_table(&mut self) -> super::UnificationTable<'_, 'tcx, TyVidSubKey> { + self.storage.sub_unification_table.with_log(self.undo_log) + } + /// Returns a range of the type variables created during the snapshot. pub(crate) fn vars_since_snapshot( &mut self, @@ -243,6 +330,33 @@ impl<'tcx> ut::UnifyKey for TyVidEqKey<'tcx> { } } +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub(crate) struct TyVidSubKey { + vid: ty::TyVid, +} + +impl From<ty::TyVid> for TyVidSubKey { + #[inline] // make this function eligible for inlining - it is quite hot. + fn from(vid: ty::TyVid) -> Self { + TyVidSubKey { vid } + } +} + +impl ut::UnifyKey for TyVidSubKey { + type Value = (); + #[inline] + fn index(&self) -> u32 { + self.vid.as_u32() + } + #[inline] + fn from_index(i: u32) -> TyVidSubKey { + TyVidSubKey { vid: ty::TyVid::from_u32(i) } + } + fn tag() -> &'static str { + "TyVidSubKey" + } +} + impl<'tcx> ut::UnifyValue for TypeVariableValue<'tcx> { type Error = ut::NoError; | 
