about summary refs log tree commit diff
diff options
context:
space:
mode:
authorManish Goregaokar <manishsmail@gmail.com>2023-05-24 15:05:03 -0700
committerGitHub <noreply@github.com>2023-05-24 15:05:03 -0700
commita9a4c9b3dbfcd7c51b377985ccc2a030f93da9fb (patch)
treed3e6d5a0421d68ba5961d197ba9fab6c382a4298
parentc373194cb6d882dc455a588bcc29c92a96b50252 (diff)
parent521a0bcd1f9b47513d65d4f711f5b21a9a2eafdc (diff)
downloadrust-a9a4c9b3dbfcd7c51b377985ccc2a030f93da9fb.tar.gz
rust-a9a4c9b3dbfcd7c51b377985ccc2a030f93da9fb.zip
Rollup merge of #111741 - compiler-errors:custom-type-op, r=lcnr
Use `ObligationCtxt` in custom type ops

We already make one when evaluating the `CustomTypeOp`, so it's simpler to just pass it to the user. Removes a redundant `ObligationCtxt::new_in_snapshot` usage and simplifies some other code.

This makes several refactorings related to opaque types in the new solver simpler, but those are not included in this PR.
-rw-r--r--compiler/rustc_borrowck/src/diagnostics/bound_region_errors.rs2
-rw-r--r--compiler/rustc_borrowck/src/type_check/canonical.rs23
-rw-r--r--compiler/rustc_borrowck/src/type_check/mod.rs15
-rw-r--r--compiler/rustc_borrowck/src/type_check/relate_tys.rs32
-rw-r--r--compiler/rustc_trait_selection/src/traits/query/type_op/custom.rs54
5 files changed, 66 insertions, 60 deletions
diff --git a/compiler/rustc_borrowck/src/diagnostics/bound_region_errors.rs b/compiler/rustc_borrowck/src/diagnostics/bound_region_errors.rs
index 84f75caa692..f41795d60a0 100644
--- a/compiler/rustc_borrowck/src/diagnostics/bound_region_errors.rs
+++ b/compiler/rustc_borrowck/src/diagnostics/bound_region_errors.rs
@@ -128,7 +128,7 @@ impl<'tcx> ToUniverseInfo<'tcx>
     }
 }
 
-impl<'tcx, F, G> ToUniverseInfo<'tcx> for Canonical<'tcx, type_op::custom::CustomTypeOp<F, G>> {
+impl<'tcx, F> ToUniverseInfo<'tcx> for Canonical<'tcx, type_op::custom::CustomTypeOp<F>> {
     fn to_universe_info(self, _base_universe: ty::UniverseIndex) -> UniverseInfo<'tcx> {
         // We can't rerun custom type ops.
         UniverseInfo::other()
diff --git a/compiler/rustc_borrowck/src/type_check/canonical.rs b/compiler/rustc_borrowck/src/type_check/canonical.rs
index b27d5d20532..95dcc8d4b19 100644
--- a/compiler/rustc_borrowck/src/type_check/canonical.rs
+++ b/compiler/rustc_borrowck/src/type_check/canonical.rs
@@ -1,13 +1,13 @@
 use std::fmt;
 
-use rustc_infer::infer::{canonical::Canonical, InferOk};
+use rustc_infer::infer::canonical::Canonical;
 use rustc_middle::mir::ConstraintCategory;
 use rustc_middle::ty::{self, ToPredicate, Ty, TyCtxt, TypeFoldable};
 use rustc_span::def_id::DefId;
 use rustc_span::Span;
 use rustc_trait_selection::traits::query::type_op::{self, TypeOpOutput};
 use rustc_trait_selection::traits::query::{Fallible, NoSolution};
-use rustc_trait_selection::traits::{ObligationCause, ObligationCtxt};
+use rustc_trait_selection::traits::ObligationCause;
 
 use crate::diagnostics::{ToUniverseInfo, UniverseInfo};
 
@@ -219,20 +219,17 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
 
         let cause = ObligationCause::dummy_with_span(span);
         let param_env = self.param_env;
-        let op = |infcx: &'_ _| {
-            let ocx = ObligationCtxt::new_in_snapshot(infcx);
-            let user_ty = ocx.normalize(&cause, param_env, user_ty);
-            ocx.eq(&cause, param_env, user_ty, mir_ty)?;
-            if !ocx.select_all_or_error().is_empty() {
-                return Err(NoSolution);
-            }
-            Ok(InferOk { value: (), obligations: vec![] })
-        };
-
         self.fully_perform_op(
             Locations::All(span),
             ConstraintCategory::Boring,
-            type_op::custom::CustomTypeOp::new(op, || "ascribe_user_type_skip_wf".to_string()),
+            type_op::custom::CustomTypeOp::new(
+                |ocx| {
+                    let user_ty = ocx.normalize(&cause, param_env, user_ty);
+                    ocx.eq(&cause, param_env, user_ty, mir_ty)?;
+                    Ok(())
+                },
+                "ascribe_user_type_skip_wf",
+            ),
         )
         .unwrap_or_else(|err| {
             span_mirbug!(
diff --git a/compiler/rustc_borrowck/src/type_check/mod.rs b/compiler/rustc_borrowck/src/type_check/mod.rs
index fa4bc926f27..ab5e01e1952 100644
--- a/compiler/rustc_borrowck/src/type_check/mod.rs
+++ b/compiler/rustc_borrowck/src/type_check/mod.rs
@@ -20,7 +20,7 @@ use rustc_infer::infer::outlives::env::RegionBoundPairs;
 use rustc_infer::infer::region_constraints::RegionConstraintData;
 use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
 use rustc_infer::infer::{
-    InferCtxt, InferOk, LateBoundRegion, LateBoundRegionConversionTime, NllRegionVariableOrigin,
+    InferCtxt, LateBoundRegion, LateBoundRegionConversionTime, NllRegionVariableOrigin,
 };
 use rustc_middle::mir::tcx::PlaceTy;
 use rustc_middle::mir::visit::{NonMutatingUseContext, PlaceContext, Visitor};
@@ -218,16 +218,16 @@ pub(crate) fn type_check<'mir, 'tcx>(
                     Locations::All(body.span),
                     ConstraintCategory::OpaqueType,
                     CustomTypeOp::new(
-                        |infcx| {
-                            infcx.register_member_constraints(
+                        |ocx| {
+                            ocx.infcx.register_member_constraints(
                                 param_env,
                                 opaque_type_key,
                                 decl.hidden_type.ty,
                                 decl.hidden_type.span,
                             );
-                            Ok(InferOk { value: (), obligations: vec![] })
+                            Ok(())
                         },
-                        || "opaque_type_map".to_string(),
+                        "opaque_type_map",
                     ),
                 )
                 .unwrap();
@@ -2713,8 +2713,9 @@ impl<'tcx> TypeOp<'tcx> for InstantiateOpaqueType<'tcx> {
     type ErrorInfo = InstantiateOpaqueType<'tcx>;
 
     fn fully_perform(mut self, infcx: &InferCtxt<'tcx>) -> Fallible<TypeOpOutput<'tcx, Self>> {
-        let (mut output, region_constraints) = scrape_region_constraints(infcx, || {
-            Ok(InferOk { value: (), obligations: self.obligations.clone() })
+        let (mut output, region_constraints) = scrape_region_constraints(infcx, |ocx| {
+            ocx.register_obligations(self.obligations.clone());
+            Ok(())
         })?;
         self.region_constraints = Some(region_constraints);
         output.error_info = Some(self);
diff --git a/compiler/rustc_borrowck/src/type_check/relate_tys.rs b/compiler/rustc_borrowck/src/type_check/relate_tys.rs
index 7158c62b548..dd1f89e5b91 100644
--- a/compiler/rustc_borrowck/src/type_check/relate_tys.rs
+++ b/compiler/rustc_borrowck/src/type_check/relate_tys.rs
@@ -185,17 +185,25 @@ impl<'tcx> TypeRelatingDelegate<'tcx> for NllTypeRelatingDelegate<'_, '_, 'tcx>
     }
 
     fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
-        self.type_checker
-            .fully_perform_op(
-                self.locations,
-                self.category,
-                InstantiateOpaqueType {
-                    obligations,
-                    // These fields are filled in during execution of the operation
-                    base_universe: None,
-                    region_constraints: None,
-                },
-            )
-            .unwrap();
+        match self.type_checker.fully_perform_op(
+            self.locations,
+            self.category,
+            InstantiateOpaqueType {
+                obligations,
+                // These fields are filled in during execution of the operation
+                base_universe: None,
+                region_constraints: None,
+            },
+        ) {
+            Ok(()) => {}
+            Err(_) => {
+                // It's a bit redundant to delay a bug here, but I'd rather
+                // delay more bugs than accidentally not delay a bug at all.
+                self.type_checker.tcx().sess.delay_span_bug(
+                    self.locations.span(self.type_checker.body),
+                    "errors selecting obligation during MIR typeck",
+                );
+            }
+        };
     }
 }
diff --git a/compiler/rustc_trait_selection/src/traits/query/type_op/custom.rs b/compiler/rustc_trait_selection/src/traits/query/type_op/custom.rs
index 1f8e756043d..3d6c1d9e2b0 100644
--- a/compiler/rustc_trait_selection/src/traits/query/type_op/custom.rs
+++ b/compiler/rustc_trait_selection/src/traits/query/type_op/custom.rs
@@ -1,32 +1,31 @@
 use crate::infer::canonical::query_response;
-use crate::infer::{InferCtxt, InferOk};
+use crate::infer::InferCtxt;
 use crate::traits::query::type_op::TypeOpOutput;
 use crate::traits::query::Fallible;
 use crate::traits::ObligationCtxt;
 use rustc_infer::infer::region_constraints::RegionConstraintData;
+use rustc_middle::traits::query::NoSolution;
 use rustc_span::source_map::DUMMY_SP;
 
 use std::fmt;
 
-pub struct CustomTypeOp<F, G> {
+pub struct CustomTypeOp<F> {
     closure: F,
-    description: G,
+    description: &'static str,
 }
 
-impl<F, G> CustomTypeOp<F, G> {
-    pub fn new<'tcx, R>(closure: F, description: G) -> Self
+impl<F> CustomTypeOp<F> {
+    pub fn new<'tcx, R>(closure: F, description: &'static str) -> Self
     where
-        F: FnOnce(&InferCtxt<'tcx>) -> Fallible<InferOk<'tcx, R>>,
-        G: Fn() -> String,
+        F: FnOnce(&ObligationCtxt<'_, 'tcx>) -> Fallible<R>,
     {
         CustomTypeOp { closure, description }
     }
 }
 
-impl<'tcx, F, R: fmt::Debug, G> super::TypeOp<'tcx> for CustomTypeOp<F, G>
+impl<'tcx, F, R: fmt::Debug> super::TypeOp<'tcx> for CustomTypeOp<F>
 where
-    F: for<'a, 'cx> FnOnce(&'a InferCtxt<'tcx>) -> Fallible<InferOk<'tcx, R>>,
-    G: Fn() -> String,
+    F: FnOnce(&ObligationCtxt<'_, 'tcx>) -> Fallible<R>,
 {
     type Output = R;
     /// We can't do any custom error reporting for `CustomTypeOp`, so
@@ -41,16 +40,13 @@ where
             info!("fully_perform({:?})", self);
         }
 
-        Ok(scrape_region_constraints(infcx, || (self.closure)(infcx))?.0)
+        Ok(scrape_region_constraints(infcx, self.closure)?.0)
     }
 }
 
-impl<F, G> fmt::Debug for CustomTypeOp<F, G>
-where
-    G: Fn() -> String,
-{
+impl<F> fmt::Debug for CustomTypeOp<F> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "{}", (self.description)())
+        self.description.fmt(f)
     }
 }
 
@@ -58,7 +54,7 @@ where
 /// constraints that result, creating query-region-constraints.
 pub fn scrape_region_constraints<'tcx, Op: super::TypeOp<'tcx, Output = R>, R>(
     infcx: &InferCtxt<'tcx>,
-    op: impl FnOnce() -> Fallible<InferOk<'tcx, R>>,
+    op: impl FnOnce(&ObligationCtxt<'_, 'tcx>) -> Fallible<R>,
 ) -> Fallible<(TypeOpOutput<'tcx, Op>, RegionConstraintData<'tcx>)> {
     // During NLL, we expect that nobody will register region
     // obligations **except** as part of a custom type op (and, at the
@@ -72,16 +68,20 @@ pub fn scrape_region_constraints<'tcx, Op: super::TypeOp<'tcx, Output = R>, R>(
         pre_obligations,
     );
 
-    let InferOk { value, obligations } = infcx.commit_if_ok(|_| op())?;
-    let ocx = ObligationCtxt::new(infcx);
-    ocx.register_obligations(obligations);
-    let errors = ocx.select_all_or_error();
-    if !errors.is_empty() {
-        infcx.tcx.sess.diagnostic().delay_span_bug(
-            DUMMY_SP,
-            format!("errors selecting obligation during MIR typeck: {:?}", errors),
-        );
-    }
+    let value = infcx.commit_if_ok(|_| {
+        let ocx = ObligationCtxt::new_in_snapshot(infcx);
+        let value = op(&ocx)?;
+        let errors = ocx.select_all_or_error();
+        if errors.is_empty() {
+            Ok(value)
+        } else {
+            infcx.tcx.sess.delay_span_bug(
+                DUMMY_SP,
+                format!("errors selecting obligation during MIR typeck: {:?}", errors),
+            );
+            Err(NoSolution)
+        }
+    })?;
 
     let region_obligations = infcx.take_registered_region_obligations();
     let region_constraint_data = infcx.take_and_reset_region_constraints();