about summary refs log tree commit diff
path: root/compiler/rustc_const_eval/src/interpret/eval_context.rs
diff options
context:
space:
mode:
authorRalf Jung <post@ralfj.de>2024-07-17 11:48:22 +0200
committerRalf Jung <post@ralfj.de>2024-07-18 11:41:10 +0200
commita7b80819e9c7c9fcc822d2e0cb2dfb4426a9a911 (patch)
tree829e6e6002a1d887a646f45db4145ea9b71faa15 /compiler/rustc_const_eval/src/interpret/eval_context.rs
parent4cd8dc63353a9859e3e3c2d5296024c810fc0923 (diff)
downloadrust-a7b80819e9c7c9fcc822d2e0cb2dfb4426a9a911.tar.gz
rust-a7b80819e9c7c9fcc822d2e0cb2dfb4426a9a911.zip
interpret: add sanity check in dyn upcast to double-check what codegen does
Diffstat (limited to 'compiler/rustc_const_eval/src/interpret/eval_context.rs')
-rw-r--r--compiler/rustc_const_eval/src/interpret/eval_context.rs30
1 files changed, 30 insertions, 0 deletions
diff --git a/compiler/rustc_const_eval/src/interpret/eval_context.rs b/compiler/rustc_const_eval/src/interpret/eval_context.rs
index 6d3e5ea1031..9fddeec2973 100644
--- a/compiler/rustc_const_eval/src/interpret/eval_context.rs
+++ b/compiler/rustc_const_eval/src/interpret/eval_context.rs
@@ -2,11 +2,15 @@ use std::cell::Cell;
 use std::{fmt, mem};
 
 use either::{Either, Left, Right};
+use rustc_infer::infer::at::ToTrace;
+use rustc_infer::traits::ObligationCause;
+use rustc_trait_selection::traits::ObligationCtxt;
 use tracing::{debug, info, info_span, instrument, trace};
 
 use rustc_errors::DiagCtxtHandle;
 use rustc_hir::{self as hir, def_id::DefId, definitions::DefPathData};
 use rustc_index::IndexVec;
+use rustc_infer::infer::TyCtxtInferExt;
 use rustc_middle::mir;
 use rustc_middle::mir::interpret::{
     CtfeProvenance, ErrorHandled, InvalidMetaKind, ReportedErrorInfo,
@@ -640,6 +644,32 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
         }
     }
 
+    /// Check if the two things are equal in the current param_env, using an infctx to get proper
+    /// equality checks.
+    pub(super) fn eq_in_param_env<T>(&self, a: T, b: T) -> bool
+    where
+        T: PartialEq + TypeFoldable<TyCtxt<'tcx>> + ToTrace<'tcx>,
+    {
+        // Fast path: compare directly.
+        if a == b {
+            return true;
+        }
+        // Slow path: spin up an inference context to check if these traits are sufficiently equal.
+        let infcx = self.tcx.infer_ctxt().build();
+        let ocx = ObligationCtxt::new(&infcx);
+        let cause = ObligationCause::dummy_with_span(self.cur_span());
+        // equate the two trait refs after normalization
+        let a = ocx.normalize(&cause, self.param_env, a);
+        let b = ocx.normalize(&cause, self.param_env, b);
+        if ocx.eq(&cause, self.param_env, a, b).is_ok() {
+            if ocx.select_all_or_error().is_empty() {
+                // All good.
+                return true;
+            }
+        }
+        return false;
+    }
+
     /// Walks up the callstack from the intrinsic's callsite, searching for the first callsite in a
     /// frame which is not `#[track_caller]`. This matches the `caller_location` intrinsic,
     /// and is primarily intended for the panic machinery.