about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_hir_analysis/src/astconv/mod.rs2
-rw-r--r--compiler/rustc_hir_analysis/src/autoderef.rs2
-rw-r--r--compiler/rustc_hir_analysis/src/collect.rs2
-rw-r--r--compiler/rustc_hir_typeck/src/coercion.rs2
-rw-r--r--compiler/rustc_hir_typeck/src/expr.rs2
-rw-r--r--compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs2
-rw-r--r--compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs2
-rw-r--r--compiler/rustc_infer/src/infer/at.rs1
-rw-r--r--compiler/rustc_infer/src/infer/mod.rs41
-rw-r--r--compiler/rustc_infer/src/infer/outlives/obligations.rs5
-rw-r--r--compiler/rustc_trait_selection/src/solve/fulfill.rs14
-rw-r--r--compiler/rustc_trait_selection/src/traits/chalk_fulfill.rs38
-rw-r--r--compiler/rustc_trait_selection/src/traits/const_evaluatable.rs4
-rw-r--r--compiler/rustc_trait_selection/src/traits/engine.rs28
-rw-r--r--compiler/rustc_trait_selection/src/traits/error_reporting/ambiguity.rs4
-rw-r--r--compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs4
-rw-r--r--compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs2
-rw-r--r--compiler/rustc_trait_selection/src/traits/fulfill.rs37
-rw-r--r--compiler/rustc_trait_selection/src/traits/mod.rs2
-rw-r--r--compiler/rustc_trait_selection/src/traits/query/evaluate_obligation.rs2
-rw-r--r--compiler/rustc_trait_selection/src/traits/query/type_op/custom.rs2
-rw-r--r--compiler/rustc_trait_selection/src/traits/select/mod.rs2
-rw-r--r--compiler/rustc_trait_selection/src/traits/specialize/mod.rs2
23 files changed, 80 insertions, 122 deletions
diff --git a/compiler/rustc_hir_analysis/src/astconv/mod.rs b/compiler/rustc_hir_analysis/src/astconv/mod.rs
index 924f6e723e8..2f5bcf8d647 100644
--- a/compiler/rustc_hir_analysis/src/astconv/mod.rs
+++ b/compiler/rustc_hir_analysis/src/astconv/mod.rs
@@ -1984,7 +1984,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
                 .copied()
                 .filter(|&(impl_, _)| {
                     infcx.probe(|_| {
-                        let ocx = ObligationCtxt::new_in_snapshot(&infcx);
+                        let ocx = ObligationCtxt::new(&infcx);
                         ocx.register_obligations(obligations.clone());
 
                         let impl_substs = infcx.fresh_substs_for_item(span, impl_);
diff --git a/compiler/rustc_hir_analysis/src/autoderef.rs b/compiler/rustc_hir_analysis/src/autoderef.rs
index 6eb18aecd65..8aa9a2c2734 100644
--- a/compiler/rustc_hir_analysis/src/autoderef.rs
+++ b/compiler/rustc_hir_analysis/src/autoderef.rs
@@ -161,7 +161,7 @@ impl<'a, 'tcx> Autoderef<'a, 'tcx> {
         &self,
         ty: Ty<'tcx>,
     ) -> Option<(Ty<'tcx>, Vec<traits::PredicateObligation<'tcx>>)> {
-        let mut fulfill_cx = <dyn TraitEngine<'tcx>>::new_in_snapshot(self.infcx);
+        let mut fulfill_cx = <dyn TraitEngine<'tcx>>::new(self.infcx);
 
         let cause = traits::ObligationCause::misc(self.span, self.body_id);
         let normalized_ty = match self
diff --git a/compiler/rustc_hir_analysis/src/collect.rs b/compiler/rustc_hir_analysis/src/collect.rs
index eba943e0072..d7ac9e7ce73 100644
--- a/compiler/rustc_hir_analysis/src/collect.rs
+++ b/compiler/rustc_hir_analysis/src/collect.rs
@@ -1328,7 +1328,7 @@ fn suggest_impl_trait<'tcx>(
         {
             continue;
         }
-        let ocx = ObligationCtxt::new_in_snapshot(&infcx);
+        let ocx = ObligationCtxt::new(&infcx);
         let item_ty = ocx.normalize(
             &ObligationCause::misc(span, def_id),
             param_env,
diff --git a/compiler/rustc_hir_typeck/src/coercion.rs b/compiler/rustc_hir_typeck/src/coercion.rs
index c0c839b1f18..5f98bacaf2a 100644
--- a/compiler/rustc_hir_typeck/src/coercion.rs
+++ b/compiler/rustc_hir_typeck/src/coercion.rs
@@ -1033,7 +1033,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
             let Ok(ok) = coerce.coerce(source, target) else {
                 return false;
             };
-            let ocx = ObligationCtxt::new_in_snapshot(self);
+            let ocx = ObligationCtxt::new(self);
             ocx.register_obligations(ok.obligations);
             ocx.select_where_possible().is_empty()
         })
diff --git a/compiler/rustc_hir_typeck/src/expr.rs b/compiler/rustc_hir_typeck/src/expr.rs
index 8d621c5a42b..f5be030c1a5 100644
--- a/compiler/rustc_hir_typeck/src/expr.rs
+++ b/compiler/rustc_hir_typeck/src/expr.rs
@@ -2962,7 +2962,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         };
 
         self.commit_if_ok(|_| {
-            let ocx = ObligationCtxt::new_in_snapshot(self);
+            let ocx = ObligationCtxt::new(self);
             let impl_substs = self.fresh_substs_for_item(base_expr.span, impl_def_id);
             let impl_trait_ref =
                 self.tcx.impl_trait_ref(impl_def_id).unwrap().subst(self.tcx, impl_substs);
diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
index 6921a0bb283..0ee87173a36 100644
--- a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
+++ b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
@@ -746,7 +746,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
 
         let expect_args = self
             .fudge_inference_if_ok(|| {
-                let ocx = ObligationCtxt::new_in_snapshot(self);
+                let ocx = ObligationCtxt::new(self);
 
                 // Attempt to apply a subtyping relationship between the formal
                 // return type (likely containing type variables if the function
diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs
index 00a3f47b306..e035d233bf7 100644
--- a/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs
+++ b/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs
@@ -163,7 +163,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                     return fn_sig;
                 }
                 self.probe(|_| {
-                    let ocx = ObligationCtxt::new_in_snapshot(self);
+                    let ocx = ObligationCtxt::new(self);
                     let normalized_fn_sig =
                         ocx.normalize(&ObligationCause::dummy(), self.param_env, fn_sig);
                     if ocx.select_all_or_error().is_empty() {
diff --git a/compiler/rustc_infer/src/infer/at.rs b/compiler/rustc_infer/src/infer/at.rs
index 1f0bf4f9887..f1a187d2677 100644
--- a/compiler/rustc_infer/src/infer/at.rs
+++ b/compiler/rustc_infer/src/infer/at.rs
@@ -79,7 +79,6 @@ impl<'tcx> InferCtxt<'tcx> {
             reported_closure_mismatch: self.reported_closure_mismatch.clone(),
             tainted_by_errors: self.tainted_by_errors.clone(),
             err_count_on_creation: self.err_count_on_creation,
-            in_snapshot: self.in_snapshot.clone(),
             universe: self.universe.clone(),
             intercrate: self.intercrate,
             next_trait_solver: self.next_trait_solver,
diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs
index 7dbc18908d5..f1f5ac81fb7 100644
--- a/compiler/rustc_infer/src/infer/mod.rs
+++ b/compiler/rustc_infer/src/infer/mod.rs
@@ -6,6 +6,7 @@ pub use self::RegionVariableOrigin::*;
 pub use self::SubregionOrigin::*;
 pub use self::ValuePairs::*;
 pub use combine::ObligationEmittingRelation;
+use rustc_data_structures::undo_log::UndoLogs;
 
 use self::opaque_types::OpaqueTypeStorage;
 pub(crate) use self::undo_log::{InferCtxtUndoLogs, Snapshot, UndoLog};
@@ -297,9 +298,6 @@ pub struct InferCtxt<'tcx> {
     // FIXME(matthewjasper) Merge into `tainted_by_errors`
     err_count_on_creation: usize,
 
-    /// This flag is true while there is an active snapshot.
-    in_snapshot: Cell<bool>,
-
     /// What is the innermost universe we have created? Starts out as
     /// `UniverseIndex::root()` but grows from there as we enter
     /// universal quantifiers.
@@ -643,7 +641,6 @@ impl<'tcx> InferCtxtBuilder<'tcx> {
             reported_closure_mismatch: Default::default(),
             tainted_by_errors: Cell::new(None),
             err_count_on_creation: tcx.sess.err_count(),
-            in_snapshot: Cell::new(false),
             universe: Cell::new(ty::UniverseIndex::ROOT),
             intercrate,
             next_trait_solver,
@@ -679,7 +676,6 @@ pub struct CombinedSnapshot<'tcx> {
     undo_snapshot: Snapshot<'tcx>,
     region_constraints_snapshot: RegionSnapshot,
     universe: ty::UniverseIndex,
-    was_in_snapshot: bool,
 }
 
 impl<'tcx> InferCtxt<'tcx> {
@@ -702,10 +698,6 @@ impl<'tcx> InferCtxt<'tcx> {
         }
     }
 
-    pub fn is_in_snapshot(&self) -> bool {
-        self.in_snapshot.get()
-    }
-
     pub fn freshen<T: TypeFoldable<TyCtxt<'tcx>>>(&self, t: T) -> T {
         t.fold_with(&mut self.freshener())
     }
@@ -766,31 +758,30 @@ impl<'tcx> InferCtxt<'tcx> {
         }
     }
 
+    pub fn in_snapshot(&self) -> bool {
+        UndoLogs::<UndoLog<'tcx>>::in_snapshot(&self.inner.borrow_mut().undo_log)
+    }
+
+    pub fn num_open_snapshots(&self) -> usize {
+        UndoLogs::<UndoLog<'tcx>>::num_open_snapshots(&self.inner.borrow_mut().undo_log)
+    }
+
     fn start_snapshot(&self) -> CombinedSnapshot<'tcx> {
         debug!("start_snapshot()");
 
-        let in_snapshot = self.in_snapshot.replace(true);
-
         let mut inner = self.inner.borrow_mut();
 
         CombinedSnapshot {
             undo_snapshot: inner.undo_log.start_snapshot(),
             region_constraints_snapshot: inner.unwrap_region_constraints().start_snapshot(),
             universe: self.universe(),
-            was_in_snapshot: in_snapshot,
         }
     }
 
     #[instrument(skip(self, snapshot), level = "debug")]
     fn rollback_to(&self, cause: &str, snapshot: CombinedSnapshot<'tcx>) {
-        let CombinedSnapshot {
-            undo_snapshot,
-            region_constraints_snapshot,
-            universe,
-            was_in_snapshot,
-        } = snapshot;
-
-        self.in_snapshot.set(was_in_snapshot);
+        let CombinedSnapshot { undo_snapshot, region_constraints_snapshot, universe } = snapshot;
+
         self.universe.set(universe);
 
         let mut inner = self.inner.borrow_mut();
@@ -800,14 +791,8 @@ impl<'tcx> InferCtxt<'tcx> {
 
     #[instrument(skip(self, snapshot), level = "debug")]
     fn commit_from(&self, snapshot: CombinedSnapshot<'tcx>) {
-        let CombinedSnapshot {
-            undo_snapshot,
-            region_constraints_snapshot: _,
-            universe: _,
-            was_in_snapshot,
-        } = snapshot;
-
-        self.in_snapshot.set(was_in_snapshot);
+        let CombinedSnapshot { undo_snapshot, region_constraints_snapshot: _, universe: _ } =
+            snapshot;
 
         self.inner.borrow_mut().commit(undo_snapshot);
     }
diff --git a/compiler/rustc_infer/src/infer/outlives/obligations.rs b/compiler/rustc_infer/src/infer/outlives/obligations.rs
index 9c20c814b69..73df6d03f86 100644
--- a/compiler/rustc_infer/src/infer/outlives/obligations.rs
+++ b/compiler/rustc_infer/src/infer/outlives/obligations.rs
@@ -125,10 +125,7 @@ impl<'tcx> InferCtxt<'tcx> {
     /// right before lexical region resolution.
     #[instrument(level = "debug", skip(self, outlives_env))]
     pub fn process_registered_region_obligations(&self, outlives_env: &OutlivesEnvironment<'tcx>) {
-        assert!(
-            !self.in_snapshot.get(),
-            "cannot process registered region obligations in a snapshot"
-        );
+        assert!(!self.in_snapshot(), "cannot process registered region obligations in a snapshot");
 
         let my_region_obligations = self.take_registered_region_obligations();
 
diff --git a/compiler/rustc_trait_selection/src/solve/fulfill.rs b/compiler/rustc_trait_selection/src/solve/fulfill.rs
index 5c62ea64f99..fc848fe3080 100644
--- a/compiler/rustc_trait_selection/src/solve/fulfill.rs
+++ b/compiler/rustc_trait_selection/src/solve/fulfill.rs
@@ -26,20 +26,27 @@ use super::{Certainty, InferCtxtEvalExt};
 /// here as this will have to deal with far more root goals than `evaluate_all`.
 pub struct FulfillmentCtxt<'tcx> {
     obligations: Vec<PredicateObligation<'tcx>>,
+
+    /// The snapshot in which this context was created. Using the context
+    /// outside of this snapshot leads to subtle bugs if the snapshot
+    /// gets rolled back. Because of this we explicitly check that we only
+    /// use the context in exactly this snapshot.
+    usable_in_snapshot: usize,
 }
 
 impl<'tcx> FulfillmentCtxt<'tcx> {
-    pub fn new() -> FulfillmentCtxt<'tcx> {
-        FulfillmentCtxt { obligations: Vec::new() }
+    pub fn new(infcx: &InferCtxt<'tcx>) -> FulfillmentCtxt<'tcx> {
+        FulfillmentCtxt { obligations: Vec::new(), usable_in_snapshot: infcx.num_open_snapshots() }
     }
 }
 
 impl<'tcx> TraitEngine<'tcx> for FulfillmentCtxt<'tcx> {
     fn register_predicate_obligation(
         &mut self,
-        _infcx: &InferCtxt<'tcx>,
+        infcx: &InferCtxt<'tcx>,
         obligation: PredicateObligation<'tcx>,
     ) {
+        assert_eq!(self.usable_in_snapshot, infcx.num_open_snapshots());
         self.obligations.push(obligation);
     }
 
@@ -77,6 +84,7 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentCtxt<'tcx> {
     }
 
     fn select_where_possible(&mut self, infcx: &InferCtxt<'tcx>) -> Vec<FulfillmentError<'tcx>> {
+        assert_eq!(self.usable_in_snapshot, infcx.num_open_snapshots());
         let mut errors = Vec::new();
         for i in 0.. {
             if !infcx.tcx.recursion_limit().value_within_limit(i) {
diff --git a/compiler/rustc_trait_selection/src/traits/chalk_fulfill.rs b/compiler/rustc_trait_selection/src/traits/chalk_fulfill.rs
index 28967e1cc55..3ecae429c59 100644
--- a/compiler/rustc_trait_selection/src/traits/chalk_fulfill.rs
+++ b/compiler/rustc_trait_selection/src/traits/chalk_fulfill.rs
@@ -13,16 +13,19 @@ use rustc_middle::ty::TypeVisitableExt;
 pub struct FulfillmentContext<'tcx> {
     obligations: FxIndexSet<PredicateObligation<'tcx>>,
 
-    usable_in_snapshot: bool,
+    /// The snapshot in which this context was created. Using the context
+    /// outside of this snapshot leads to subtle bugs if the snapshot
+    /// gets rolled back. Because of this we explicitly check that we only
+    /// use the context in exactly this snapshot.
+    usable_in_snapshot: usize,
 }
 
-impl FulfillmentContext<'_> {
-    pub(super) fn new() -> Self {
-        FulfillmentContext { obligations: FxIndexSet::default(), usable_in_snapshot: false }
-    }
-
-    pub(crate) fn new_in_snapshot() -> Self {
-        FulfillmentContext { usable_in_snapshot: true, ..Self::new() }
+impl<'tcx> FulfillmentContext<'tcx> {
+    pub(super) fn new(infcx: &InferCtxt<'tcx>) -> Self {
+        FulfillmentContext {
+            obligations: FxIndexSet::default(),
+            usable_in_snapshot: infcx.num_open_snapshots(),
+        }
     }
 }
 
@@ -32,9 +35,7 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentContext<'tcx> {
         infcx: &InferCtxt<'tcx>,
         obligation: PredicateObligation<'tcx>,
     ) {
-        if !self.usable_in_snapshot {
-            assert!(!infcx.is_in_snapshot());
-        }
+        assert_eq!(self.usable_in_snapshot, infcx.num_open_snapshots());
         let obligation = infcx.resolve_vars_if_possible(obligation);
 
         self.obligations.insert(obligation);
@@ -58,9 +59,7 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentContext<'tcx> {
     }
 
     fn select_where_possible(&mut self, infcx: &InferCtxt<'tcx>) -> Vec<FulfillmentError<'tcx>> {
-        if !self.usable_in_snapshot {
-            assert!(!infcx.is_in_snapshot());
-        }
+        assert_eq!(self.usable_in_snapshot, infcx.num_open_snapshots());
 
         let mut errors = Vec::new();
         let mut next_round = FxIndexSet::default();
@@ -94,12 +93,11 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentContext<'tcx> {
                                 &orig_values,
                                 &response,
                             ) {
-                                Ok(infer_ok) => next_round.extend(
-                                    infer_ok.obligations.into_iter().map(|obligation| {
-                                        assert!(!infcx.is_in_snapshot());
-                                        infcx.resolve_vars_if_possible(obligation)
-                                    }),
-                                ),
+                                Ok(infer_ok) => {
+                                    next_round.extend(infer_ok.obligations.into_iter().map(
+                                        |obligation| infcx.resolve_vars_if_possible(obligation),
+                                    ))
+                                }
 
                                 Err(_err) => errors.push(FulfillmentError {
                                     obligation: obligation.clone(),
diff --git a/compiler/rustc_trait_selection/src/traits/const_evaluatable.rs b/compiler/rustc_trait_selection/src/traits/const_evaluatable.rs
index ab4727b8697..8dc13b827e5 100644
--- a/compiler/rustc_trait_selection/src/traits/const_evaluatable.rs
+++ b/compiler/rustc_trait_selection/src/traits/const_evaluatable.rs
@@ -176,7 +176,7 @@ fn satisfied_from_param_env<'tcx>(
         fn visit_const(&mut self, c: ty::Const<'tcx>) -> ControlFlow<Self::BreakTy> {
             debug!("is_const_evaluatable: candidate={:?}", c);
             if self.infcx.probe(|_| {
-                let ocx = ObligationCtxt::new_in_snapshot(self.infcx);
+                let ocx = ObligationCtxt::new(self.infcx);
                 ocx.eq(&ObligationCause::dummy(), self.param_env, c.ty(), self.ct.ty()).is_ok()
                     && ocx.eq(&ObligationCause::dummy(), self.param_env, c, self.ct).is_ok()
                     && ocx.select_all_or_error().is_empty()
@@ -219,7 +219,7 @@ fn satisfied_from_param_env<'tcx>(
     }
 
     if let Some(Ok(c)) = single_match {
-        let ocx = ObligationCtxt::new_in_snapshot(infcx);
+        let ocx = ObligationCtxt::new(infcx);
         assert!(ocx.eq(&ObligationCause::dummy(), param_env, c.ty(), ct.ty()).is_ok());
         assert!(ocx.eq(&ObligationCause::dummy(), param_env, c, ct).is_ok());
         assert!(ocx.select_all_or_error().is_empty());
diff --git a/compiler/rustc_trait_selection/src/traits/engine.rs b/compiler/rustc_trait_selection/src/traits/engine.rs
index 90699c3cadc..faa675054b7 100644
--- a/compiler/rustc_trait_selection/src/traits/engine.rs
+++ b/compiler/rustc_trait_selection/src/traits/engine.rs
@@ -28,36 +28,18 @@ use rustc_span::Span;
 
 pub trait TraitEngineExt<'tcx> {
     fn new(infcx: &InferCtxt<'tcx>) -> Box<Self>;
-    fn new_in_snapshot(infcx: &InferCtxt<'tcx>) -> Box<Self>;
 }
 
 impl<'tcx> TraitEngineExt<'tcx> for dyn TraitEngine<'tcx> {
     fn new(infcx: &InferCtxt<'tcx>) -> Box<Self> {
         match (infcx.tcx.sess.opts.unstable_opts.trait_solver, infcx.next_trait_solver()) {
             (TraitSolver::Classic, false) | (TraitSolver::NextCoherence, false) => {
-                Box::new(FulfillmentContext::new())
+                Box::new(FulfillmentContext::new(infcx))
             }
             (TraitSolver::Next | TraitSolver::NextCoherence, true) => {
-                Box::new(NextFulfillmentCtxt::new())
+                Box::new(NextFulfillmentCtxt::new(infcx))
             }
-            (TraitSolver::Chalk, false) => Box::new(ChalkFulfillmentContext::new()),
-            _ => bug!(
-                "incompatible combination of -Ztrait-solver flag ({:?}) and InferCtxt::next_trait_solver ({:?})",
-                infcx.tcx.sess.opts.unstable_opts.trait_solver,
-                infcx.next_trait_solver()
-            ),
-        }
-    }
-
-    fn new_in_snapshot(infcx: &InferCtxt<'tcx>) -> Box<Self> {
-        match (infcx.tcx.sess.opts.unstable_opts.trait_solver, infcx.next_trait_solver()) {
-            (TraitSolver::Classic, false) | (TraitSolver::NextCoherence, false) => {
-                Box::new(FulfillmentContext::new_in_snapshot())
-            }
-            (TraitSolver::Next | TraitSolver::NextCoherence, true) => {
-                Box::new(NextFulfillmentCtxt::new())
-            }
-            (TraitSolver::Chalk, false) => Box::new(ChalkFulfillmentContext::new_in_snapshot()),
+            (TraitSolver::Chalk, false) => Box::new(ChalkFulfillmentContext::new(infcx)),
             _ => bug!(
                 "incompatible combination of -Ztrait-solver flag ({:?}) and InferCtxt::next_trait_solver ({:?})",
                 infcx.tcx.sess.opts.unstable_opts.trait_solver,
@@ -79,10 +61,6 @@ impl<'a, 'tcx> ObligationCtxt<'a, 'tcx> {
         Self { infcx, engine: RefCell::new(<dyn TraitEngine<'_>>::new(infcx)) }
     }
 
-    pub fn new_in_snapshot(infcx: &'a InferCtxt<'tcx>) -> Self {
-        Self { infcx, engine: RefCell::new(<dyn TraitEngine<'_>>::new_in_snapshot(infcx)) }
-    }
-
     pub fn register_obligation(&self, obligation: PredicateObligation<'tcx>) {
         self.engine.borrow_mut().register_predicate_obligation(self.infcx, obligation);
     }
diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/ambiguity.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/ambiguity.rs
index 96c183f9a58..7b5d4f456ff 100644
--- a/compiler/rustc_trait_selection/src/traits/error_reporting/ambiguity.rs
+++ b/compiler/rustc_trait_selection/src/traits/error_reporting/ambiguity.rs
@@ -20,7 +20,7 @@ pub fn recompute_applicable_impls<'tcx>(
     let param_env = obligation.param_env;
 
     let impl_may_apply = |impl_def_id| {
-        let ocx = ObligationCtxt::new_in_snapshot(infcx);
+        let ocx = ObligationCtxt::new(infcx);
         let placeholder_obligation =
             infcx.instantiate_binder_with_placeholders(obligation.predicate);
         let obligation_trait_ref =
@@ -45,7 +45,7 @@ pub fn recompute_applicable_impls<'tcx>(
     };
 
     let param_env_candidate_may_apply = |poly_trait_predicate: ty::PolyTraitPredicate<'tcx>| {
-        let ocx = ObligationCtxt::new_in_snapshot(infcx);
+        let ocx = ObligationCtxt::new(infcx);
         let placeholder_obligation =
             infcx.instantiate_binder_with_placeholders(obligation.predicate);
         let obligation_trait_ref =
diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs
index c259260c1fa..f7670d51bdc 100644
--- a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs
@@ -377,7 +377,7 @@ impl<'tcx> InferCtxtExt<'tcx> for InferCtxt<'tcx> {
                     param_env,
                     ty.rebind(ty::TraitPredicate { trait_ref, constness, polarity }),
                 );
-                let ocx = ObligationCtxt::new_in_snapshot(self);
+                let ocx = ObligationCtxt::new(self);
                 ocx.register_obligation(obligation);
                 if ocx.select_all_or_error().is_empty() {
                     return Ok((
@@ -1599,7 +1599,7 @@ impl<'tcx> InferCtxtPrivExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
         }
 
         self.probe(|_| {
-            let ocx = ObligationCtxt::new_in_snapshot(self);
+            let ocx = ObligationCtxt::new(self);
 
             // try to find the mismatched types to report the error with.
             //
diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
index 43b63762ba3..619a099fcb5 100644
--- a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
+++ b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
@@ -3825,7 +3825,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
         body_id: hir::HirId,
         param_env: ty::ParamEnv<'tcx>,
     ) -> Vec<Option<(Span, (DefId, Ty<'tcx>))>> {
-        let ocx = ObligationCtxt::new_in_snapshot(self.infcx);
+        let ocx = ObligationCtxt::new(self.infcx);
         let mut assocs_in_this_method = Vec::with_capacity(type_diffs.len());
         for diff in type_diffs {
             let Sorts(expected_found) = diff else { continue; };
diff --git a/compiler/rustc_trait_selection/src/traits/fulfill.rs b/compiler/rustc_trait_selection/src/traits/fulfill.rs
index f2f99eb60e4..e3472a1c4c1 100644
--- a/compiler/rustc_trait_selection/src/traits/fulfill.rs
+++ b/compiler/rustc_trait_selection/src/traits/fulfill.rs
@@ -50,20 +50,15 @@ impl<'tcx> ForestObligation for PendingPredicateObligation<'tcx> {
 /// method `select_all_or_error` can be used to report any remaining
 /// ambiguous cases as errors.
 pub struct FulfillmentContext<'tcx> {
-    // A list of all obligations that have been registered with this
-    // fulfillment context.
+    /// A list of all obligations that have been registered with this
+    /// fulfillment context.
     predicates: ObligationForest<PendingPredicateObligation<'tcx>>,
 
-    // Is it OK to register obligations into this infcx inside
-    // an infcx snapshot?
-    //
-    // The "primary fulfillment" in many cases in typeck lives
-    // outside of any snapshot, so any use of it inside a snapshot
-    // will lead to trouble and therefore is checked against, but
-    // other fulfillment contexts sometimes do live inside of
-    // a snapshot (they don't *straddle* a snapshot, so there
-    // is no trouble there).
-    usable_in_snapshot: bool,
+    /// The snapshot in which this context was created. Using the context
+    /// outside of this snapshot leads to subtle bugs if the snapshot
+    /// gets rolled back. Because of this we explicitly check that we only
+    /// use the context in exactly this snapshot.
+    usable_in_snapshot: usize,
 }
 
 #[derive(Clone, Debug)]
@@ -80,18 +75,17 @@ pub struct PendingPredicateObligation<'tcx> {
 #[cfg(all(target_arch = "x86_64", target_pointer_width = "64"))]
 static_assert_size!(PendingPredicateObligation<'_>, 72);
 
-impl<'a, 'tcx> FulfillmentContext<'tcx> {
+impl<'tcx> FulfillmentContext<'tcx> {
     /// Creates a new fulfillment context.
-    pub(super) fn new() -> FulfillmentContext<'tcx> {
-        FulfillmentContext { predicates: ObligationForest::new(), usable_in_snapshot: false }
-    }
-
-    pub(super) fn new_in_snapshot() -> FulfillmentContext<'tcx> {
-        FulfillmentContext { predicates: ObligationForest::new(), usable_in_snapshot: true }
+    pub(super) fn new(infcx: &InferCtxt<'tcx>) -> FulfillmentContext<'tcx> {
+        FulfillmentContext {
+            predicates: ObligationForest::new(),
+            usable_in_snapshot: infcx.num_open_snapshots(),
+        }
     }
 
     /// Attempts to select obligations using `selcx`.
-    fn select(&mut self, selcx: SelectionContext<'a, 'tcx>) -> Vec<FulfillmentError<'tcx>> {
+    fn select(&mut self, selcx: SelectionContext<'_, 'tcx>) -> Vec<FulfillmentError<'tcx>> {
         let span = debug_span!("select", obligation_forest_size = ?self.predicates.len());
         let _enter = span.enter();
 
@@ -122,14 +116,13 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentContext<'tcx> {
         infcx: &InferCtxt<'tcx>,
         obligation: PredicateObligation<'tcx>,
     ) {
+        assert_eq!(self.usable_in_snapshot, infcx.num_open_snapshots());
         // this helps to reduce duplicate errors, as well as making
         // debug output much nicer to read and so on.
         let obligation = infcx.resolve_vars_if_possible(obligation);
 
         debug!(?obligation, "register_predicate_obligation");
 
-        assert!(!infcx.is_in_snapshot() || self.usable_in_snapshot);
-
         self.predicates
             .register_obligation(PendingPredicateObligation { obligation, stalled_on: vec![] });
     }
diff --git a/compiler/rustc_trait_selection/src/traits/mod.rs b/compiler/rustc_trait_selection/src/traits/mod.rs
index ae76651c336..5dc5ddbddbd 100644
--- a/compiler/rustc_trait_selection/src/traits/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/mod.rs
@@ -161,7 +161,7 @@ fn pred_known_to_hold_modulo_regions<'tcx>(
         // the we do no inference in the process of checking this obligation.
         let goal = infcx.resolve_vars_if_possible((obligation.predicate, obligation.param_env));
         infcx.probe(|_| {
-            let ocx = ObligationCtxt::new_in_snapshot(infcx);
+            let ocx = ObligationCtxt::new(infcx);
             ocx.register_obligation(obligation);
 
             let errors = ocx.select_all_or_error();
diff --git a/compiler/rustc_trait_selection/src/traits/query/evaluate_obligation.rs b/compiler/rustc_trait_selection/src/traits/query/evaluate_obligation.rs
index c93c30b7053..e29e1b25919 100644
--- a/compiler/rustc_trait_selection/src/traits/query/evaluate_obligation.rs
+++ b/compiler/rustc_trait_selection/src/traits/query/evaluate_obligation.rs
@@ -80,7 +80,7 @@ impl<'tcx> InferCtxtExt<'tcx> for InferCtxt<'tcx> {
 
         if self.next_trait_solver() {
             self.probe(|snapshot| {
-                let mut fulfill_cx = crate::solve::FulfillmentCtxt::new();
+                let mut fulfill_cx = crate::solve::FulfillmentCtxt::new(self);
                 fulfill_cx.register_predicate_obligation(self, obligation.clone());
                 // True errors
                 // FIXME(-Ztrait-solver=next): Overflows are reported as ambig here, is that OK?
diff --git a/compiler/rustc_trait_selection/src/traits/query/type_op/custom.rs b/compiler/rustc_trait_selection/src/traits/query/type_op/custom.rs
index 8b0973021bc..5420caee329 100644
--- a/compiler/rustc_trait_selection/src/traits/query/type_op/custom.rs
+++ b/compiler/rustc_trait_selection/src/traits/query/type_op/custom.rs
@@ -82,7 +82,7 @@ where
     );
 
     let value = infcx.commit_if_ok(|_| {
-        let ocx = ObligationCtxt::new_in_snapshot(infcx);
+        let ocx = ObligationCtxt::new(infcx);
         let value = op(&ocx).map_err(|_| {
             infcx.tcx.sess.delay_span_bug(span, format!("error performing operation: {name}"))
         })?;
diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs
index 7406b47e327..7cf8479b803 100644
--- a/compiler/rustc_trait_selection/src/traits/select/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs
@@ -606,7 +606,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         &mut self,
         predicates: impl IntoIterator<Item = PredicateObligation<'tcx>>,
     ) -> Result<EvaluationResult, OverflowError> {
-        let mut fulfill_cx = crate::solve::FulfillmentCtxt::new();
+        let mut fulfill_cx = crate::solve::FulfillmentCtxt::new(self.infcx);
         fulfill_cx.register_predicate_obligations(self.infcx, predicates);
         // True errors
         // FIXME(-Ztrait-solver=next): Overflows are reported as ambig here, is that OK?
diff --git a/compiler/rustc_trait_selection/src/traits/specialize/mod.rs b/compiler/rustc_trait_selection/src/traits/specialize/mod.rs
index fee38aed0e2..96f1287416f 100644
--- a/compiler/rustc_trait_selection/src/traits/specialize/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/specialize/mod.rs
@@ -238,7 +238,7 @@ fn fulfill_implication<'tcx>(
 
     // Needs to be `in_snapshot` because this function is used to rebase
     // substitutions, which may happen inside of a select within a probe.
-    let ocx = ObligationCtxt::new_in_snapshot(infcx);
+    let ocx = ObligationCtxt::new(infcx);
     // attempt to prove all of the predicates for impl2 given those for impl1
     // (which are packed up in penv)
     ocx.register_obligations(obligations.chain(more_obligations));