about summary refs log tree commit diff
path: root/compiler/rustc_trait_selection/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_trait_selection/src')
-rw-r--r--compiler/rustc_trait_selection/src/traits/auto_trait.rs62
-rw-r--r--compiler/rustc_trait_selection/src/traits/select/confirmation.rs120
-rw-r--r--compiler/rustc_trait_selection/src/traits/select/mod.rs20
3 files changed, 84 insertions, 118 deletions
diff --git a/compiler/rustc_trait_selection/src/traits/auto_trait.rs b/compiler/rustc_trait_selection/src/traits/auto_trait.rs
index c909a0b49e2..73e94da165f 100644
--- a/compiler/rustc_trait_selection/src/traits/auto_trait.rs
+++ b/compiler/rustc_trait_selection/src/traits/auto_trait.rs
@@ -6,13 +6,13 @@ use super::*;
 use crate::errors::UnableToConstructConstantValue;
 use crate::infer::region_constraints::{Constraint, RegionConstraintData};
 use crate::traits::project::ProjectAndUnifyResult;
+
+use rustc_data_structures::fx::{FxIndexMap, FxIndexSet, IndexEntry};
+use rustc_data_structures::unord::UnordSet;
 use rustc_infer::infer::DefineOpaqueTypes;
 use rustc_middle::mir::interpret::ErrorHandled;
 use rustc_middle::ty::{Region, RegionVid};
 
-use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet};
-
-use std::collections::hash_map::Entry;
 use std::collections::VecDeque;
 use std::iter;
 
@@ -25,8 +25,8 @@ pub enum RegionTarget<'tcx> {
 
 #[derive(Default, Debug, Clone)]
 pub struct RegionDeps<'tcx> {
-    larger: FxIndexSet<RegionTarget<'tcx>>,
-    smaller: FxIndexSet<RegionTarget<'tcx>>,
+    pub larger: FxIndexSet<RegionTarget<'tcx>>,
+    pub smaller: FxIndexSet<RegionTarget<'tcx>>,
 }
 
 pub enum AutoTraitResult<A> {
@@ -35,17 +35,10 @@ pub enum AutoTraitResult<A> {
     NegativeImpl,
 }
 
-#[allow(dead_code)]
-impl<A> AutoTraitResult<A> {
-    fn is_auto(&self) -> bool {
-        matches!(self, AutoTraitResult::PositiveImpl(_) | AutoTraitResult::NegativeImpl)
-    }
-}
-
 pub struct AutoTraitInfo<'cx> {
     pub full_user_env: ty::ParamEnv<'cx>,
     pub region_data: RegionConstraintData<'cx>,
-    pub vid_to_region: FxHashMap<ty::RegionVid, ty::Region<'cx>>,
+    pub vid_to_region: FxIndexMap<ty::RegionVid, ty::Region<'cx>>,
 }
 
 pub struct AutoTraitFinder<'tcx> {
@@ -88,19 +81,12 @@ impl<'tcx> AutoTraitFinder<'tcx> {
 
         let infcx = tcx.infer_ctxt().build();
         let mut selcx = SelectionContext::new(&infcx);
-        for polarity in [true, false] {
+        for polarity in [ty::PredicatePolarity::Positive, ty::PredicatePolarity::Negative] {
             let result = selcx.select(&Obligation::new(
                 tcx,
                 ObligationCause::dummy(),
                 orig_env,
-                ty::TraitPredicate {
-                    trait_ref,
-                    polarity: if polarity {
-                        ty::PredicatePolarity::Positive
-                    } else {
-                        ty::PredicatePolarity::Negative
-                    },
-                },
+                ty::TraitPredicate { trait_ref, polarity },
             ));
             if let Ok(Some(ImplSource::UserDefined(_))) = result {
                 debug!(
@@ -114,7 +100,7 @@ impl<'tcx> AutoTraitFinder<'tcx> {
         }
 
         let infcx = tcx.infer_ctxt().build();
-        let mut fresh_preds = FxHashSet::default();
+        let mut fresh_preds = FxIndexSet::default();
 
         // Due to the way projections are handled by SelectionContext, we need to run
         // evaluate_predicates twice: once on the original param env, and once on the result of
@@ -239,7 +225,7 @@ impl<'tcx> AutoTraitFinder<'tcx> {
         ty: Ty<'tcx>,
         param_env: ty::ParamEnv<'tcx>,
         user_env: ty::ParamEnv<'tcx>,
-        fresh_preds: &mut FxHashSet<ty::Predicate<'tcx>>,
+        fresh_preds: &mut FxIndexSet<ty::Predicate<'tcx>>,
     ) -> Option<(ty::ParamEnv<'tcx>, ty::ParamEnv<'tcx>)> {
         let tcx = infcx.tcx;
 
@@ -252,7 +238,7 @@ impl<'tcx> AutoTraitFinder<'tcx> {
 
         let mut select = SelectionContext::new(infcx);
 
-        let mut already_visited = FxHashSet::default();
+        let mut already_visited = UnordSet::new();
         let mut predicates = VecDeque::new();
         predicates.push_back(ty::Binder::dummy(ty::TraitPredicate {
             trait_ref: ty::TraitRef::new(infcx.tcx, trait_did, [ty]),
@@ -473,9 +459,9 @@ impl<'tcx> AutoTraitFinder<'tcx> {
     fn map_vid_to_region<'cx>(
         &self,
         regions: &RegionConstraintData<'cx>,
-    ) -> FxHashMap<ty::RegionVid, ty::Region<'cx>> {
-        let mut vid_map: FxHashMap<RegionTarget<'cx>, RegionDeps<'cx>> = FxHashMap::default();
-        let mut finished_map = FxHashMap::default();
+    ) -> FxIndexMap<ty::RegionVid, ty::Region<'cx>> {
+        let mut vid_map = FxIndexMap::<RegionTarget<'cx>, RegionDeps<'cx>>::default();
+        let mut finished_map = FxIndexMap::default();
 
         for (constraint, _) in &regions.constraints {
             match constraint {
@@ -513,25 +499,22 @@ impl<'tcx> AutoTraitFinder<'tcx> {
         }
 
         while !vid_map.is_empty() {
-            #[allow(rustc::potential_query_instability)]
-            let target = *vid_map.keys().next().expect("Keys somehow empty");
-            let deps = vid_map.remove(&target).expect("Entry somehow missing");
+            let target = *vid_map.keys().next().unwrap();
+            let deps = vid_map.swap_remove(&target).unwrap();
 
             for smaller in deps.smaller.iter() {
                 for larger in deps.larger.iter() {
                     match (smaller, larger) {
                         (&RegionTarget::Region(_), &RegionTarget::Region(_)) => {
-                            if let Entry::Occupied(v) = vid_map.entry(*smaller) {
+                            if let IndexEntry::Occupied(v) = vid_map.entry(*smaller) {
                                 let smaller_deps = v.into_mut();
                                 smaller_deps.larger.insert(*larger);
-                                // FIXME(#120456) - is `swap_remove` correct?
                                 smaller_deps.larger.swap_remove(&target);
                             }
 
-                            if let Entry::Occupied(v) = vid_map.entry(*larger) {
+                            if let IndexEntry::Occupied(v) = vid_map.entry(*larger) {
                                 let larger_deps = v.into_mut();
                                 larger_deps.smaller.insert(*smaller);
-                                // FIXME(#120456) - is `swap_remove` correct?
                                 larger_deps.smaller.swap_remove(&target);
                             }
                         }
@@ -542,17 +525,15 @@ impl<'tcx> AutoTraitFinder<'tcx> {
                             // Do nothing; we don't care about regions that are smaller than vids.
                         }
                         (&RegionTarget::RegionVid(_), &RegionTarget::RegionVid(_)) => {
-                            if let Entry::Occupied(v) = vid_map.entry(*smaller) {
+                            if let IndexEntry::Occupied(v) = vid_map.entry(*smaller) {
                                 let smaller_deps = v.into_mut();
                                 smaller_deps.larger.insert(*larger);
-                                // FIXME(#120456) - is `swap_remove` correct?
                                 smaller_deps.larger.swap_remove(&target);
                             }
 
-                            if let Entry::Occupied(v) = vid_map.entry(*larger) {
+                            if let IndexEntry::Occupied(v) = vid_map.entry(*larger) {
                                 let larger_deps = v.into_mut();
                                 larger_deps.smaller.insert(*smaller);
-                                // FIXME(#120456) - is `swap_remove` correct?
                                 larger_deps.smaller.swap_remove(&target);
                             }
                         }
@@ -560,6 +541,7 @@ impl<'tcx> AutoTraitFinder<'tcx> {
                 }
             }
         }
+
         finished_map
     }
 
@@ -588,7 +570,7 @@ impl<'tcx> AutoTraitFinder<'tcx> {
         ty: Ty<'_>,
         nested: impl Iterator<Item = PredicateObligation<'tcx>>,
         computed_preds: &mut FxIndexSet<ty::Predicate<'tcx>>,
-        fresh_preds: &mut FxHashSet<ty::Predicate<'tcx>>,
+        fresh_preds: &mut FxIndexSet<ty::Predicate<'tcx>>,
         predicates: &mut VecDeque<ty::PolyTraitPredicate<'tcx>>,
         selcx: &mut SelectionContext<'_, 'tcx>,
     ) -> bool {
diff --git a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
index 6f512a1173f..0459246553b 100644
--- a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
@@ -28,7 +28,7 @@ use crate::traits::{
     BuiltinDerivedObligation, ImplDerivedObligation, ImplDerivedObligationCause, ImplSource,
     ImplSourceUserDefinedData, Normalized, Obligation, ObligationCause, PolyTraitObligation,
     PredicateObligation, Selection, SelectionError, SignatureMismatch, TraitNotObjectSafe,
-    Unimplemented,
+    TraitObligation, Unimplemented,
 };
 
 use super::BuiltinImplConditions;
@@ -678,17 +678,10 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         fn_host_effect: ty::Const<'tcx>,
     ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
         debug!(?obligation, "confirm_fn_pointer_candidate");
+        let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
+        let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
 
         let tcx = self.tcx();
-
-        let Some(self_ty) = self.infcx.shallow_resolve(obligation.self_ty().no_bound_vars()) else {
-            // FIXME: Ideally we'd support `for<'a> fn(&'a ()): Fn(&'a ())`,
-            // but we do not currently. Luckily, such a bound is not
-            // particularly useful, so we don't expect users to write
-            // them often.
-            return Err(SelectionError::Unimplemented);
-        };
-
         let sig = self_ty.fn_sig(tcx);
         let trait_ref = closure_trait_ref_and_return_type(
             tcx,
@@ -700,7 +693,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         )
         .map_bound(|(trait_ref, _)| trait_ref);
 
-        let mut nested = self.confirm_poly_trait_refs(obligation, trait_ref)?;
+        let mut nested =
+            self.equate_trait_refs(obligation.with(tcx, placeholder_predicate), trait_ref)?;
         let cause = obligation.derived_cause(BuiltinDerivedObligation);
 
         // Confirm the `type Output: Sized;` bound that is present on `FnOnce`
@@ -748,10 +742,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         &mut self,
         obligation: &PolyTraitObligation<'tcx>,
     ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
-        // Okay to skip binder because the args on coroutine types never
-        // touch bound regions, they just capture the in-scope
-        // type/region parameters.
-        let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
+        let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
+        let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
         let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
             bug!("closure candidate for non-closure {:?}", obligation);
         };
@@ -760,15 +752,6 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
 
         let coroutine_sig = args.as_coroutine().sig();
 
-        // NOTE: The self-type is a coroutine type and hence is
-        // in fact unparameterized (or at least does not reference any
-        // regions bound in the obligation).
-        let self_ty = obligation
-            .predicate
-            .self_ty()
-            .no_bound_vars()
-            .expect("unboxed closure type should not capture bound vars from the predicate");
-
         let (trait_ref, _, _) = super::util::coroutine_trait_ref_and_outputs(
             self.tcx(),
             obligation.predicate.def_id(),
@@ -776,7 +759,10 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
             coroutine_sig,
         );
 
-        let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
+        let nested = self.equate_trait_refs(
+            obligation.with(self.tcx(), placeholder_predicate),
+            ty::Binder::dummy(trait_ref),
+        )?;
         debug!(?trait_ref, ?nested, "coroutine candidate obligations");
 
         Ok(nested)
@@ -786,10 +772,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         &mut self,
         obligation: &PolyTraitObligation<'tcx>,
     ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
-        // Okay to skip binder because the args on coroutine types never
-        // touch bound regions, they just capture the in-scope
-        // type/region parameters.
-        let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
+        let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
+        let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
         let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
             bug!("closure candidate for non-closure {:?}", obligation);
         };
@@ -801,11 +785,14 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         let (trait_ref, _) = super::util::future_trait_ref_and_outputs(
             self.tcx(),
             obligation.predicate.def_id(),
-            obligation.predicate.no_bound_vars().expect("future has no bound vars").self_ty(),
+            self_ty,
             coroutine_sig,
         );
 
-        let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
+        let nested = self.equate_trait_refs(
+            obligation.with(self.tcx(), placeholder_predicate),
+            ty::Binder::dummy(trait_ref),
+        )?;
         debug!(?trait_ref, ?nested, "future candidate obligations");
 
         Ok(nested)
@@ -815,10 +802,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         &mut self,
         obligation: &PolyTraitObligation<'tcx>,
     ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
-        // Okay to skip binder because the args on coroutine types never
-        // touch bound regions, they just capture the in-scope
-        // type/region parameters.
-        let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
+        let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
+        let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
         let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
             bug!("closure candidate for non-closure {:?}", obligation);
         };
@@ -830,11 +815,14 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         let (trait_ref, _) = super::util::iterator_trait_ref_and_outputs(
             self.tcx(),
             obligation.predicate.def_id(),
-            obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(),
+            self_ty,
             gen_sig,
         );
 
-        let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
+        let nested = self.equate_trait_refs(
+            obligation.with(self.tcx(), placeholder_predicate),
+            ty::Binder::dummy(trait_ref),
+        )?;
         debug!(?trait_ref, ?nested, "iterator candidate obligations");
 
         Ok(nested)
@@ -844,10 +832,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         &mut self,
         obligation: &PolyTraitObligation<'tcx>,
     ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
-        // Okay to skip binder because the args on coroutine types never
-        // touch bound regions, they just capture the in-scope
-        // type/region parameters.
-        let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
+        let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
+        let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
         let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
             bug!("closure candidate for non-closure {:?}", obligation);
         };
@@ -859,11 +845,14 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         let (trait_ref, _) = super::util::async_iterator_trait_ref_and_outputs(
             self.tcx(),
             obligation.predicate.def_id(),
-            obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(),
+            self_ty,
             gen_sig,
         );
 
-        let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
+        let nested = self.equate_trait_refs(
+            obligation.with(self.tcx(), placeholder_predicate),
+            ty::Binder::dummy(trait_ref),
+        )?;
         debug!(?trait_ref, ?nested, "iterator candidate obligations");
 
         Ok(nested)
@@ -874,14 +863,15 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         &mut self,
         obligation: &PolyTraitObligation<'tcx>,
     ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
-        // Okay to skip binder because the args on closure types never
-        // touch bound regions, they just capture the in-scope
-        // type/region parameters.
-        let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
+        let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
+        let self_ty: Ty<'_> = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
+
         let trait_ref = match *self_ty.kind() {
-            ty::Closure(_, args) => {
-                self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_)
-            }
+            ty::Closure(..) => self.closure_trait_ref_unnormalized(
+                self_ty,
+                obligation.predicate.def_id(),
+                self.tcx().consts.true_,
+            ),
             ty::CoroutineClosure(_, args) => {
                 args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
                     ty::TraitRef::new(
@@ -896,7 +886,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
             }
         };
 
-        self.confirm_poly_trait_refs(obligation, trait_ref)
+        self.equate_trait_refs(obligation.with(self.tcx(), placeholder_predicate), trait_ref)
     }
 
     #[instrument(skip(self), level = "debug")]
@@ -904,8 +894,10 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         &mut self,
         obligation: &PolyTraitObligation<'tcx>,
     ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
+        let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
+        let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
+
         let tcx = self.tcx();
-        let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
 
         let mut nested = vec![];
         let (trait_ref, kind_ty) = match *self_ty.kind() {
@@ -972,7 +964,9 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
             _ => bug!("expected callable type for AsyncFn candidate"),
         };
 
-        nested.extend(self.confirm_poly_trait_refs(obligation, trait_ref)?);
+        nested.extend(
+            self.equate_trait_refs(obligation.with(tcx, placeholder_predicate), trait_ref)?,
+        );
 
         let goal_kind =
             self.tcx().async_fn_trait_kind_from_def_id(obligation.predicate.def_id()).unwrap();
@@ -1025,34 +1019,32 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
     /// selection of the impl. Therefore, if there is a mismatch, we
     /// report an error to the user.
     #[instrument(skip(self), level = "trace")]
-    fn confirm_poly_trait_refs(
+    fn equate_trait_refs(
         &mut self,
-        obligation: &PolyTraitObligation<'tcx>,
-        self_ty_trait_ref: ty::PolyTraitRef<'tcx>,
+        obligation: TraitObligation<'tcx>,
+        found_trait_ref: ty::PolyTraitRef<'tcx>,
     ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
-        let obligation_trait_ref =
-            self.infcx.enter_forall_and_leak_universe(obligation.predicate.to_poly_trait_ref());
-        let self_ty_trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
+        let found_trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
             obligation.cause.span,
             HigherRankedType,
-            self_ty_trait_ref,
+            found_trait_ref,
         );
         // Normalize the obligation and expected trait refs together, because why not
-        let Normalized { obligations: nested, value: (obligation_trait_ref, expected_trait_ref) } =
+        let Normalized { obligations: nested, value: (obligation_trait_ref, found_trait_ref) } =
             ensure_sufficient_stack(|| {
                 normalize_with_depth(
                     self,
                     obligation.param_env,
                     obligation.cause.clone(),
                     obligation.recursion_depth + 1,
-                    (obligation_trait_ref, self_ty_trait_ref),
+                    (obligation.predicate.trait_ref, found_trait_ref),
                 )
             });
 
         // needed to define opaque types for tests/ui/type-alias-impl-trait/assoc-projection-ice.rs
         self.infcx
             .at(&obligation.cause, obligation.param_env)
-            .eq(DefineOpaqueTypes::Yes, obligation_trait_ref, expected_trait_ref)
+            .eq(DefineOpaqueTypes::Yes, obligation_trait_ref, found_trait_ref)
             .map(|InferOk { mut obligations, .. }| {
                 obligations.extend(nested);
                 obligations
@@ -1060,7 +1052,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
             .map_err(|terr| {
                 SignatureMismatch(Box::new(SignatureMismatchData {
                     expected_trait_ref: ty::Binder::dummy(obligation_trait_ref),
-                    found_trait_ref: ty::Binder::dummy(expected_trait_ref),
+                    found_trait_ref: ty::Binder::dummy(found_trait_ref),
                     terr,
                 }))
             })
diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs
index 495ca25f696..3fbe8e39d1e 100644
--- a/compiler/rustc_trait_selection/src/traits/select/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs
@@ -2643,26 +2643,18 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
     #[instrument(skip(self), level = "debug")]
     fn closure_trait_ref_unnormalized(
         &mut self,
-        obligation: &PolyTraitObligation<'tcx>,
-        args: GenericArgsRef<'tcx>,
+        self_ty: Ty<'tcx>,
+        fn_trait_def_id: DefId,
         fn_host_effect: ty::Const<'tcx>,
     ) -> ty::PolyTraitRef<'tcx> {
+        let ty::Closure(_, args) = *self_ty.kind() else {
+            bug!("expected closure, found {self_ty}");
+        };
         let closure_sig = args.as_closure().sig();
 
-        debug!(?closure_sig);
-
-        // NOTE: The self-type is an unboxed closure type and hence is
-        // in fact unparameterized (or at least does not reference any
-        // regions bound in the obligation).
-        let self_ty = obligation
-            .predicate
-            .self_ty()
-            .no_bound_vars()
-            .expect("unboxed closure type should not capture bound vars from the predicate");
-
         closure_trait_ref_and_return_type(
             self.tcx(),
-            obligation.predicate.def_id(),
+            fn_trait_def_id,
             self_ty,
             closure_sig,
             util::TupleArgumentsFlag::No,