about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_middle/src/traits/mod.rs48
-rw-r--r--compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs2
-rw-r--r--compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs20
-rw-r--r--compiler/rustc_typeck/src/check/fn_ctxt/checks.rs5
4 files changed, 33 insertions, 42 deletions
diff --git a/compiler/rustc_middle/src/traits/mod.rs b/compiler/rustc_middle/src/traits/mod.rs
index 664eef7ca56..4b67fc84d2e 100644
--- a/compiler/rustc_middle/src/traits/mod.rs
+++ b/compiler/rustc_middle/src/traits/mod.rs
@@ -97,9 +97,7 @@ pub struct ObligationCause<'tcx> {
     /// information.
     pub body_id: hir::HirId,
 
-    /// `None` for `MISC_OBLIGATION_CAUSE_CODE` (a common case, occurs ~60% of
-    /// the time). `Some` otherwise.
-    code: Option<Lrc<ObligationCauseCode<'tcx>>>,
+    code: InternedObligationCauseCode<'tcx>,
 }
 
 // This custom hash function speeds up hashing for `Obligation` deduplication
@@ -123,11 +121,7 @@ impl<'tcx> ObligationCause<'tcx> {
         body_id: hir::HirId,
         code: ObligationCauseCode<'tcx>,
     ) -> ObligationCause<'tcx> {
-        ObligationCause {
-            span,
-            body_id,
-            code: if code == MISC_OBLIGATION_CAUSE_CODE { None } else { Some(Lrc::new(code)) },
-        }
+        ObligationCause { span, body_id, code: code.into() }
     }
 
     pub fn misc(span: Span, body_id: hir::HirId) -> ObligationCause<'tcx> {
@@ -136,11 +130,11 @@ impl<'tcx> ObligationCause<'tcx> {
 
     #[inline(always)]
     pub fn dummy() -> ObligationCause<'tcx> {
-        ObligationCause { span: DUMMY_SP, body_id: hir::CRATE_HIR_ID, code: None }
+        ObligationCause::dummy_with_span(DUMMY_SP)
     }
 
     pub fn dummy_with_span(span: Span) -> ObligationCause<'tcx> {
-        ObligationCause { span, body_id: hir::CRATE_HIR_ID, code: None }
+        ObligationCause { span, body_id: hir::CRATE_HIR_ID, code: Default::default() }
     }
 
     pub fn span(&self, tcx: TyCtxt<'tcx>) -> Span {
@@ -160,14 +154,14 @@ impl<'tcx> ObligationCause<'tcx> {
 
     #[inline]
     pub fn code(&self) -> &ObligationCauseCode<'tcx> {
-        self.code.as_deref().unwrap_or(&MISC_OBLIGATION_CAUSE_CODE)
+        &self.code
     }
 
     pub fn map_code(
         &mut self,
-        f: impl FnOnce(InternedObligationCauseCode<'tcx>) -> Lrc<ObligationCauseCode<'tcx>>,
+        f: impl FnOnce(InternedObligationCauseCode<'tcx>) -> ObligationCauseCode<'tcx>,
     ) {
-        self.code = Some(f(InternedObligationCauseCode { code: self.code.take() }));
+        self.code = f(std::mem::take(&mut self.code)).into();
     }
 
     pub fn derived_cause(
@@ -188,10 +182,8 @@ impl<'tcx> ObligationCause<'tcx> {
         // NOTE(flaper87): As of now, it keeps track of the whole error
         // chain. Ideally, we should have a way to configure this either
         // by using -Z verbose or just a CLI argument.
-        self.code = Some(
-            variant(DerivedObligationCause { parent_trait_pred, parent_code: self.code.take() })
-                .into(),
-        );
+        self.code =
+            variant(DerivedObligationCause { parent_trait_pred, parent_code: self.code }).into();
         self
     }
 }
@@ -203,11 +195,19 @@ pub struct UnifyReceiverContext<'tcx> {
     pub substs: SubstsRef<'tcx>,
 }
 
-#[derive(Clone, Debug, PartialEq, Eq, Hash, Lift)]
+#[derive(Clone, Debug, PartialEq, Eq, Hash, Lift, Default)]
 pub struct InternedObligationCauseCode<'tcx> {
+    /// `None` for `MISC_OBLIGATION_CAUSE_CODE` (a common case, occurs ~60% of
+    /// the time). `Some` otherwise.
     code: Option<Lrc<ObligationCauseCode<'tcx>>>,
 }
 
+impl<'tcx> From<ObligationCauseCode<'tcx>> for InternedObligationCauseCode<'tcx> {
+    fn from(code: ObligationCauseCode<'tcx>) -> Self {
+        Self { code: if code == MISC_OBLIGATION_CAUSE_CODE { None } else { Some(Lrc::new(code)) } }
+    }
+}
+
 impl<'tcx> std::ops::Deref for InternedObligationCauseCode<'tcx> {
     type Target = ObligationCauseCode<'tcx>;
 
@@ -454,7 +454,7 @@ impl<'tcx> ObligationCauseCode<'tcx> {
             BuiltinDerivedObligation(derived)
             | DerivedObligation(derived)
             | ImplDerivedObligation(box ImplDerivedObligationCause { derived, .. }) => {
-                Some((derived.parent_code(), Some(derived.parent_trait_pred)))
+                Some((&derived.parent_code, Some(derived.parent_trait_pred)))
             }
             _ => None,
         }
@@ -508,15 +508,7 @@ pub struct DerivedObligationCause<'tcx> {
     pub parent_trait_pred: ty::PolyTraitPredicate<'tcx>,
 
     /// The parent trait had this cause.
-    parent_code: Option<Lrc<ObligationCauseCode<'tcx>>>,
-}
-
-impl<'tcx> DerivedObligationCause<'tcx> {
-    /// Get a reference to the derived obligation cause's parent code.
-    #[must_use]
-    pub fn parent_code(&self) -> &ObligationCauseCode<'tcx> {
-        self.parent_code.as_deref().unwrap_or(&MISC_OBLIGATION_CAUSE_CODE)
-    }
+    pub parent_code: InternedObligationCauseCode<'tcx>,
 }
 
 #[derive(Clone, Debug, TypeFoldable, Lift)]
diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs
index 81e62f6da06..6082d7529c3 100644
--- a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs
@@ -1868,7 +1868,7 @@ impl<'a, 'tcx> InferCtxtPrivExt<'a, 'tcx> for InferCtxt<'a, 'tcx> {
         match code {
             ObligationCauseCode::BuiltinDerivedObligation(data) => {
                 let parent_trait_ref = self.resolve_vars_if_possible(data.parent_trait_pred);
-                match self.get_parent_trait_ref(data.parent_code()) {
+                match self.get_parent_trait_ref(&data.parent_code) {
                     Some(t) => Some(t),
                     None => {
                         let ty = parent_trait_ref.skip_binder().self_ty();
diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
index ee3e9544b4d..833e232e636 100644
--- a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
+++ b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
@@ -1683,7 +1683,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
                         _ => {}
                     }
 
-                    next_code = Some(cause.derived.parent_code());
+                    next_code = Some(&cause.derived.parent_code);
                 }
                 ObligationCauseCode::DerivedObligation(derived_obligation)
                 | ObligationCauseCode::BuiltinDerivedObligation(derived_obligation) => {
@@ -1715,7 +1715,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
                         _ => {}
                     }
 
-                    next_code = Some(derived_obligation.parent_code());
+                    next_code = Some(&derived_obligation.parent_code);
                 }
                 _ => break,
             }
@@ -2365,7 +2365,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
                 let is_upvar_tys_infer_tuple = if !matches!(ty.kind(), ty::Tuple(..)) {
                     false
                 } else {
-                    if let ObligationCauseCode::BuiltinDerivedObligation(data) = data.parent_code()
+                    if let ObligationCauseCode::BuiltinDerivedObligation(data) = &*data.parent_code
                     {
                         let parent_trait_ref =
                             self.resolve_vars_if_possible(data.parent_trait_pred);
@@ -2392,14 +2392,14 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
                 obligated_types.push(ty);
 
                 let parent_predicate = parent_trait_ref.to_predicate(tcx);
-                if !self.is_recursive_obligation(obligated_types, data.parent_code()) {
+                if !self.is_recursive_obligation(obligated_types, &data.parent_code) {
                     // #74711: avoid a stack overflow
                     ensure_sufficient_stack(|| {
                         self.note_obligation_cause_code(
                             err,
                             &parent_predicate,
                             param_env,
-                            data.parent_code(),
+                            &data.parent_code,
                             obligated_types,
                             seen_requirements,
                         )
@@ -2410,7 +2410,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
                             err,
                             &parent_predicate,
                             param_env,
-                            &cause_code.peel_derives(),
+                            cause_code.peel_derives(),
                             obligated_types,
                             seen_requirements,
                         )
@@ -2461,7 +2461,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
                     // We don't want to point at the ADT saying "required because it appears within
                     // the type `X`", like we would otherwise do in test `supertrait-auto-trait.rs`.
                     while let ObligationCauseCode::BuiltinDerivedObligation(derived) =
-                        data.parent_code()
+                        &*data.parent_code
                     {
                         let child_trait_ref =
                             self.resolve_vars_if_possible(derived.parent_trait_pred);
@@ -2474,7 +2474,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
                         parent_trait_pred = child_trait_ref;
                     }
                 }
-                while let ObligationCauseCode::ImplDerivedObligation(child) = data.parent_code() {
+                while let ObligationCauseCode::ImplDerivedObligation(child) = &*data.parent_code {
                     // Skip redundant recursive obligation notes. See `ui/issue-20413.rs`.
                     let child_trait_pred =
                         self.resolve_vars_if_possible(child.derived.parent_trait_pred);
@@ -2505,7 +2505,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
                         err,
                         &parent_predicate,
                         param_env,
-                        data.parent_code(),
+                        &data.parent_code,
                         obligated_types,
                         seen_requirements,
                     )
@@ -2520,7 +2520,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
                         err,
                         &parent_predicate,
                         param_env,
-                        data.parent_code(),
+                        &data.parent_code,
                         obligated_types,
                         seen_requirements,
                     )
diff --git a/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs b/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs
index 300b87aa465..7c180bd1643 100644
--- a/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs
+++ b/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs
@@ -1606,9 +1606,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                 let mut result_code = code;
                 loop {
                     let parent = match code {
-                        ObligationCauseCode::ImplDerivedObligation(c) => c.derived.parent_code(),
+                        ObligationCauseCode::ImplDerivedObligation(c) => &c.derived.parent_code,
                         ObligationCauseCode::BuiltinDerivedObligation(c)
-                        | ObligationCauseCode::DerivedObligation(c) => c.parent_code(),
+                        | ObligationCauseCode::DerivedObligation(c) => &c.parent_code,
                         _ => break result_code,
                     };
                     (result_code, code) = (code, parent);
@@ -1670,7 +1670,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                         call_hir_id: expr.hir_id,
                         parent_code,
                     }
-                    .into()
                 });
             } else if error.obligation.cause.span == call_sp {
                 // Make function calls point at the callee, not the whole thing.