about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2024-10-20 00:59:01 +0000
committerbors <bors@rust-lang.org>2024-10-20 00:59:01 +0000
commit54791efd8235805dcfbdad3b8788e08f2142c50b (patch)
treefd89cc5cda3d728b9798b6306caaeaa5b3ff029d /compiler
parentb596184f3b156ef8ad99a463814f18f7bbf5dc27 (diff)
parentd836d35739571dc8a16d9fe12e900c9b45b48336 (diff)
downloadrust-54791efd8235805dcfbdad3b8788e08f2142c50b.tar.gz
rust-54791efd8235805dcfbdad3b8788e08f2142c50b.zip
Auto merge of #131911 - lcnr:probe-no-more-leak-2, r=compiler-errors
refactor fudge_inference, handle effect vars

this makes it easier to use fudging outside of `fudge_inference_if_ok`, which is likely necessary to handle inference variable leaks on rollback.

We now also uses exhaustive matches where possible and improve the code to handle effect vars.

r? `@compiler-errors` `@BoxyUwU`
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_infer/src/infer/mod.rs11
-rw-r--r--compiler/rustc_infer/src/infer/snapshot/fudge.rs267
-rw-r--r--compiler/rustc_infer/src/infer/snapshot/mod.rs21
3 files changed, 175 insertions, 124 deletions
diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs
index f1195d0d4c0..5afdf3c2454 100644
--- a/compiler/rustc_infer/src/infer/mod.rs
+++ b/compiler/rustc_infer/src/infer/mod.rs
@@ -899,6 +899,13 @@ impl<'tcx> InferCtxt<'tcx> {
         ty::Const::new_var(self.tcx, vid)
     }
 
+    fn next_effect_var(&self) -> ty::Const<'tcx> {
+        let effect_vid =
+            self.inner.borrow_mut().effect_unification_table().new_key(EffectVarValue::Unknown).vid;
+
+        ty::Const::new_infer(self.tcx, ty::InferConst::EffectVar(effect_vid))
+    }
+
     pub fn next_int_var(&self) -> Ty<'tcx> {
         let next_int_var_id =
             self.inner.borrow_mut().int_unification_table().new_key(ty::IntVarValue::Unknown);
@@ -1001,15 +1008,13 @@ impl<'tcx> InferCtxt<'tcx> {
     }
 
     pub fn var_for_effect(&self, param: &ty::GenericParamDef) -> GenericArg<'tcx> {
-        let effect_vid =
-            self.inner.borrow_mut().effect_unification_table().new_key(EffectVarValue::Unknown).vid;
         let ty = self
             .tcx
             .type_of(param.def_id)
             .no_bound_vars()
             .expect("const parameter types cannot be generic");
         debug_assert_eq!(self.tcx.types.bool, ty);
-        ty::Const::new_infer(self.tcx, ty::InferConst::EffectVar(effect_vid)).into()
+        self.next_effect_var().into()
     }
 
     /// Given a set of generics defined on a type or impl, returns the generic parameters mapping
diff --git a/compiler/rustc_infer/src/infer/snapshot/fudge.rs b/compiler/rustc_infer/src/infer/snapshot/fudge.rs
index 8e330a084c6..613cebc266d 100644
--- a/compiler/rustc_infer/src/infer/snapshot/fudge.rs
+++ b/compiler/rustc_infer/src/infer/snapshot/fudge.rs
@@ -4,9 +4,12 @@ use rustc_data_structures::{snapshot_vec as sv, unify as ut};
 use rustc_middle::infer::unify_key::{ConstVariableValue, ConstVidKey};
 use rustc_middle::ty::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
 use rustc_middle::ty::{self, ConstVid, FloatVid, IntVid, RegionVid, Ty, TyCtxt, TyVid};
+use rustc_type_ir::EffectVid;
+use rustc_type_ir::visit::TypeVisitableExt;
 use tracing::instrument;
 use ut::UnifyKey;
 
+use super::VariableLengths;
 use crate::infer::type_variable::TypeVariableOrigin;
 use crate::infer::{ConstVariableOrigin, InferCtxt, RegionVariableOrigin, UnificationTable};
 
@@ -40,26 +43,7 @@ fn const_vars_since_snapshot<'tcx>(
     )
 }
 
-struct VariableLengths {
-    type_var_len: usize,
-    const_var_len: usize,
-    int_var_len: usize,
-    float_var_len: usize,
-    region_constraints_len: usize,
-}
-
 impl<'tcx> InferCtxt<'tcx> {
-    fn variable_lengths(&self) -> VariableLengths {
-        let mut inner = self.inner.borrow_mut();
-        VariableLengths {
-            type_var_len: inner.type_variables().num_vars(),
-            const_var_len: inner.const_unification_table().len(),
-            int_var_len: inner.int_unification_table().len(),
-            float_var_len: inner.float_unification_table().len(),
-            region_constraints_len: inner.unwrap_region_constraints().num_region_vars(),
-        }
-    }
-
     /// This rather funky routine is used while processing expected
     /// types. What happens here is that we want to propagate a
     /// coercion through the return type of a fn to its
@@ -106,78 +90,94 @@ impl<'tcx> InferCtxt<'tcx> {
         T: TypeFoldable<TyCtxt<'tcx>>,
     {
         let variable_lengths = self.variable_lengths();
-        let (mut fudger, value) = self.probe(|_| {
-            match f() {
-                Ok(value) => {
-                    let value = self.resolve_vars_if_possible(value);
-
-                    // At this point, `value` could in principle refer
-                    // to inference variables that have been created during
-                    // the snapshot. Once we exit `probe()`, those are
-                    // going to be popped, so we will have to
-                    // eliminate any references to them.
-
-                    let mut inner = self.inner.borrow_mut();
-                    let type_vars =
-                        inner.type_variables().vars_since_snapshot(variable_lengths.type_var_len);
-                    let int_vars = vars_since_snapshot(
-                        &inner.int_unification_table(),
-                        variable_lengths.int_var_len,
-                    );
-                    let float_vars = vars_since_snapshot(
-                        &inner.float_unification_table(),
-                        variable_lengths.float_var_len,
-                    );
-                    let region_vars = inner
-                        .unwrap_region_constraints()
-                        .vars_since_snapshot(variable_lengths.region_constraints_len);
-                    let const_vars = const_vars_since_snapshot(
-                        &mut inner.const_unification_table(),
-                        variable_lengths.const_var_len,
-                    );
-
-                    let fudger = InferenceFudger {
-                        infcx: self,
-                        type_vars,
-                        int_vars,
-                        float_vars,
-                        region_vars,
-                        const_vars,
-                    };
-
-                    Ok((fudger, value))
-                }
-                Err(e) => Err(e),
-            }
+        let (snapshot_vars, value) = self.probe(|_| {
+            let value = f()?;
+            // At this point, `value` could in principle refer
+            // to inference variables that have been created during
+            // the snapshot. Once we exit `probe()`, those are
+            // going to be popped, so we will have to
+            // eliminate any references to them.
+            let snapshot_vars = SnapshotVarData::new(self, variable_lengths);
+            Ok((snapshot_vars, self.resolve_vars_if_possible(value)))
         })?;
 
         // At this point, we need to replace any of the now-popped
         // type/region variables that appear in `value` with a fresh
         // variable of the appropriate kind. We can't do this during
         // the probe because they would just get popped then too. =)
+        Ok(self.fudge_inference(snapshot_vars, value))
+    }
 
+    fn fudge_inference<T: TypeFoldable<TyCtxt<'tcx>>>(
+        &self,
+        snapshot_vars: SnapshotVarData,
+        value: T,
+    ) -> T {
         // Micro-optimization: if no variables have been created, then
         // `value` can't refer to any of them. =) So we can just return it.
-        if fudger.type_vars.0.is_empty()
-            && fudger.int_vars.is_empty()
-            && fudger.float_vars.is_empty()
-            && fudger.region_vars.0.is_empty()
-            && fudger.const_vars.0.is_empty()
-        {
-            Ok(value)
+        if snapshot_vars.is_empty() {
+            value
         } else {
-            Ok(value.fold_with(&mut fudger))
+            value.fold_with(&mut InferenceFudger { infcx: self, snapshot_vars })
         }
     }
 }
 
-struct InferenceFudger<'a, 'tcx> {
-    infcx: &'a InferCtxt<'tcx>,
+struct SnapshotVarData {
+    region_vars: (Range<RegionVid>, Vec<RegionVariableOrigin>),
     type_vars: (Range<TyVid>, Vec<TypeVariableOrigin>),
     int_vars: Range<IntVid>,
     float_vars: Range<FloatVid>,
-    region_vars: (Range<RegionVid>, Vec<RegionVariableOrigin>),
     const_vars: (Range<ConstVid>, Vec<ConstVariableOrigin>),
+    effect_vars: Range<EffectVid>,
+}
+
+impl SnapshotVarData {
+    fn new(infcx: &InferCtxt<'_>, vars_pre_snapshot: VariableLengths) -> SnapshotVarData {
+        let mut inner = infcx.inner.borrow_mut();
+        let region_vars = inner
+            .unwrap_region_constraints()
+            .vars_since_snapshot(vars_pre_snapshot.region_constraints_len);
+        let type_vars = inner.type_variables().vars_since_snapshot(vars_pre_snapshot.type_var_len);
+        let int_vars =
+            vars_since_snapshot(&inner.int_unification_table(), vars_pre_snapshot.int_var_len);
+        let float_vars =
+            vars_since_snapshot(&inner.float_unification_table(), vars_pre_snapshot.float_var_len);
+
+        let const_vars = const_vars_since_snapshot(
+            &mut inner.const_unification_table(),
+            vars_pre_snapshot.const_var_len,
+        );
+        let effect_vars = vars_since_snapshot(
+            &inner.effect_unification_table(),
+            vars_pre_snapshot.effect_var_len,
+        );
+        let effect_vars = effect_vars.start.vid..effect_vars.end.vid;
+
+        SnapshotVarData { region_vars, type_vars, int_vars, float_vars, const_vars, effect_vars }
+    }
+
+    fn is_empty(&self) -> bool {
+        let SnapshotVarData {
+            region_vars,
+            type_vars,
+            int_vars,
+            float_vars,
+            const_vars,
+            effect_vars,
+        } = self;
+        region_vars.0.is_empty()
+            && type_vars.0.is_empty()
+            && int_vars.is_empty()
+            && float_vars.is_empty()
+            && const_vars.0.is_empty()
+            && effect_vars.is_empty()
+    }
+}
+
+struct InferenceFudger<'a, 'tcx> {
+    infcx: &'a InferCtxt<'tcx>,
+    snapshot_vars: SnapshotVarData,
 }
 
 impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
@@ -186,68 +186,93 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
     }
 
     fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
-        match *ty.kind() {
-            ty::Infer(ty::InferTy::TyVar(vid)) => {
-                if self.type_vars.0.contains(&vid) {
-                    // This variable was created during the fudging.
-                    // Recreate it with a fresh variable here.
-                    let idx = vid.as_usize() - self.type_vars.0.start.as_usize();
-                    let origin = self.type_vars.1[idx];
-                    self.infcx.next_ty_var_with_origin(origin)
-                } else {
-                    // This variable was created before the
-                    // "fudging". Since we refresh all type
-                    // variables to their binding anyhow, we know
-                    // that it is unbound, so we can just return
-                    // it.
-                    debug_assert!(
-                        self.infcx.inner.borrow_mut().type_variables().probe(vid).is_unknown()
-                    );
-                    ty
+        if let &ty::Infer(infer_ty) = ty.kind() {
+            match infer_ty {
+                ty::TyVar(vid) => {
+                    if self.snapshot_vars.type_vars.0.contains(&vid) {
+                        // This variable was created during the fudging.
+                        // Recreate it with a fresh variable here.
+                        let idx = vid.as_usize() - self.snapshot_vars.type_vars.0.start.as_usize();
+                        let origin = self.snapshot_vars.type_vars.1[idx];
+                        self.infcx.next_ty_var_with_origin(origin)
+                    } else {
+                        // This variable was created before the
+                        // "fudging". Since we refresh all type
+                        // variables to their binding anyhow, we know
+                        // that it is unbound, so we can just return
+                        // it.
+                        debug_assert!(
+                            self.infcx.inner.borrow_mut().type_variables().probe(vid).is_unknown()
+                        );
+                        ty
+                    }
                 }
-            }
-            ty::Infer(ty::InferTy::IntVar(vid)) => {
-                if self.int_vars.contains(&vid) {
-                    self.infcx.next_int_var()
-                } else {
-                    ty
+                ty::IntVar(vid) => {
+                    if self.snapshot_vars.int_vars.contains(&vid) {
+                        self.infcx.next_int_var()
+                    } else {
+                        ty
+                    }
                 }
-            }
-            ty::Infer(ty::InferTy::FloatVar(vid)) => {
-                if self.float_vars.contains(&vid) {
-                    self.infcx.next_float_var()
-                } else {
-                    ty
+                ty::FloatVar(vid) => {
+                    if self.snapshot_vars.float_vars.contains(&vid) {
+                        self.infcx.next_float_var()
+                    } else {
+                        ty
+                    }
+                }
+                ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_) => {
+                    unreachable!("unexpected fresh infcx var")
                 }
             }
-            _ => ty.super_fold_with(self),
+        } else if ty.has_infer() {
+            ty.super_fold_with(self)
+        } else {
+            ty
         }
     }
 
     fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
-        if let ty::ReVar(vid) = *r
-            && self.region_vars.0.contains(&vid)
-        {
-            let idx = vid.index() - self.region_vars.0.start.index();
-            let origin = self.region_vars.1[idx];
-            return self.infcx.next_region_var(origin);
+        if let ty::ReVar(vid) = r.kind() {
+            if self.snapshot_vars.region_vars.0.contains(&vid) {
+                let idx = vid.index() - self.snapshot_vars.region_vars.0.start.index();
+                let origin = self.snapshot_vars.region_vars.1[idx];
+                self.infcx.next_region_var(origin)
+            } else {
+                r
+            }
+        } else {
+            r
         }
-        r
     }
 
     fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
-        if let ty::ConstKind::Infer(ty::InferConst::Var(vid)) = ct.kind() {
-            if self.const_vars.0.contains(&vid) {
-                // This variable was created during the fudging.
-                // Recreate it with a fresh variable here.
-                let idx = vid.index() - self.const_vars.0.start.index();
-                let origin = self.const_vars.1[idx];
-                self.infcx.next_const_var_with_origin(origin)
-            } else {
-                ct
+        if let ty::ConstKind::Infer(infer_ct) = ct.kind() {
+            match infer_ct {
+                ty::InferConst::Var(vid) => {
+                    if self.snapshot_vars.const_vars.0.contains(&vid) {
+                        let idx = vid.index() - self.snapshot_vars.const_vars.0.start.index();
+                        let origin = self.snapshot_vars.const_vars.1[idx];
+                        self.infcx.next_const_var_with_origin(origin)
+                    } else {
+                        ct
+                    }
+                }
+                ty::InferConst::EffectVar(vid) => {
+                    if self.snapshot_vars.effect_vars.contains(&vid) {
+                        self.infcx.next_effect_var()
+                    } else {
+                        ct
+                    }
+                }
+                ty::InferConst::Fresh(_) => {
+                    unreachable!("unexpected fresh infcx var")
+                }
             }
-        } else {
+        } else if ct.has_infer() {
             ct.super_fold_with(self)
+        } else {
+            ct
         }
     }
 }
diff --git a/compiler/rustc_infer/src/infer/snapshot/mod.rs b/compiler/rustc_infer/src/infer/snapshot/mod.rs
index 07a482c2f9a..fa813500c54 100644
--- a/compiler/rustc_infer/src/infer/snapshot/mod.rs
+++ b/compiler/rustc_infer/src/infer/snapshot/mod.rs
@@ -17,7 +17,28 @@ pub struct CombinedSnapshot<'tcx> {
     universe: ty::UniverseIndex,
 }
 
+struct VariableLengths {
+    region_constraints_len: usize,
+    type_var_len: usize,
+    int_var_len: usize,
+    float_var_len: usize,
+    const_var_len: usize,
+    effect_var_len: usize,
+}
+
 impl<'tcx> InferCtxt<'tcx> {
+    fn variable_lengths(&self) -> VariableLengths {
+        let mut inner = self.inner.borrow_mut();
+        VariableLengths {
+            region_constraints_len: inner.unwrap_region_constraints().num_region_vars(),
+            type_var_len: inner.type_variables().num_vars(),
+            int_var_len: inner.int_unification_table().len(),
+            float_var_len: inner.float_unification_table().len(),
+            const_var_len: inner.const_unification_table().len(),
+            effect_var_len: inner.effect_unification_table().len(),
+        }
+    }
+
     pub fn in_snapshot(&self) -> bool {
         UndoLogs::<UndoLog<'tcx>>::in_snapshot(&self.inner.borrow_mut().undo_log)
     }