about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_middle/src/traits/mod.rs22
-rw-r--r--compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs2
2 files changed, 17 insertions, 7 deletions
diff --git a/compiler/rustc_middle/src/traits/mod.rs b/compiler/rustc_middle/src/traits/mod.rs
index 04e3daf3045..664eef7ca56 100644
--- a/compiler/rustc_middle/src/traits/mod.rs
+++ b/compiler/rustc_middle/src/traits/mod.rs
@@ -165,12 +165,9 @@ impl<'tcx> ObligationCause<'tcx> {
 
     pub fn map_code(
         &mut self,
-        f: impl FnOnce(Lrc<ObligationCauseCode<'tcx>>) -> Lrc<ObligationCauseCode<'tcx>>,
+        f: impl FnOnce(InternedObligationCauseCode<'tcx>) -> Lrc<ObligationCauseCode<'tcx>>,
     ) {
-        self.code = Some(f(match self.code.take() {
-            Some(code) => code,
-            None => Lrc::new(MISC_OBLIGATION_CAUSE_CODE),
-        }));
+        self.code = Some(f(InternedObligationCauseCode { code: self.code.take() }));
     }
 
     pub fn derived_cause(
@@ -207,6 +204,19 @@ pub struct UnifyReceiverContext<'tcx> {
 }
 
 #[derive(Clone, Debug, PartialEq, Eq, Hash, Lift)]
+pub struct InternedObligationCauseCode<'tcx> {
+    code: Option<Lrc<ObligationCauseCode<'tcx>>>,
+}
+
+impl<'tcx> std::ops::Deref for InternedObligationCauseCode<'tcx> {
+    type Target = ObligationCauseCode<'tcx>;
+
+    fn deref(&self) -> &Self::Target {
+        self.code.as_deref().unwrap_or(&MISC_OBLIGATION_CAUSE_CODE)
+    }
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, Hash, Lift)]
 pub enum ObligationCauseCode<'tcx> {
     /// Not well classified or should be obvious from the span.
     MiscObligation,
@@ -293,7 +303,7 @@ pub enum ObligationCauseCode<'tcx> {
         /// The node of the function call.
         call_hir_id: hir::HirId,
         /// The obligation introduced by this argument.
-        parent_code: Lrc<ObligationCauseCode<'tcx>>,
+        parent_code: InternedObligationCauseCode<'tcx>,
     },
 
     /// Error derived when matching traits/impls; see ObligationCause for more details
diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
index 1522eedf7fe..ee3e9544b4d 100644
--- a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
+++ b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
@@ -1652,7 +1652,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
             debug!("maybe_note_obligation_cause_for_async_await: code={:?}", code);
             match code {
                 ObligationCauseCode::FunctionArgumentObligation { parent_code, .. } => {
-                    next_code = Some(parent_code.as_ref());
+                    next_code = Some(parent_code);
                 }
                 ObligationCauseCode::ImplDerivedObligation(cause) => {
                     let ty = cause.derived.parent_trait_pred.skip_binder().self_ty();