about summary refs log tree commit diff
diff options
context:
space:
mode:
authorJack Wrenn <jack@wrenn.fyi>2024-03-20 17:45:14 +0000
committerJack Wrenn <jack@wrenn.fyi>2024-03-22 17:01:49 +0000
commit2de9010f66febb7074706e05ff8d32ec329f4968 (patch)
tree381affdfd8c17f0cf9952f4df913ae824556566f
parent9023f908cfbe7a475f369717a61cb8eb865cfd25 (diff)
downloadrust-2de9010f66febb7074706e05ff8d32ec329f4968.tar.gz
rust-2de9010f66febb7074706e05ff8d32ec329f4968.zip
Add `tag_for_variant` query
This query allows for sharing code between `rustc_const_eval` and
`rustc_transmutability`.

Also moves `DummyMachine` to `rustc_const_eval`.
-rw-r--r--compiler/rustc_const_eval/src/const_eval/dummy_machine.rs193
-rw-r--r--compiler/rustc_const_eval/src/const_eval/eval_queries.rs20
-rw-r--r--compiler/rustc_const_eval/src/const_eval/mod.rs2
-rw-r--r--compiler/rustc_const_eval/src/interpret/discriminant.rs156
-rw-r--r--compiler/rustc_const_eval/src/lib.rs1
-rw-r--r--compiler/rustc_middle/src/query/erase.rs1
-rw-r--r--compiler/rustc_middle/src/query/keys.rs9
-rw-r--r--compiler/rustc_middle/src/query/mod.rs7
-rw-r--r--compiler/rustc_mir_transform/src/dataflow_const_prop.rs200
-rw-r--r--compiler/rustc_mir_transform/src/gvn.rs2
-rw-r--r--compiler/rustc_mir_transform/src/jump_threading.rs2
-rw-r--r--compiler/rustc_mir_transform/src/known_panics_lint.rs2
-rw-r--r--compiler/rustc_transmute/src/layout/tree.rs44
13 files changed, 347 insertions, 292 deletions
diff --git a/compiler/rustc_const_eval/src/const_eval/dummy_machine.rs b/compiler/rustc_const_eval/src/const_eval/dummy_machine.rs
new file mode 100644
index 00000000000..ba2e2a1e353
--- /dev/null
+++ b/compiler/rustc_const_eval/src/const_eval/dummy_machine.rs
@@ -0,0 +1,193 @@
+use crate::interpret::{self, HasStaticRootDefId, ImmTy, Immediate, InterpCx, PointerArithmetic};
+use rustc_middle::mir::interpret::{AllocId, ConstAllocation, InterpResult};
+use rustc_middle::mir::*;
+use rustc_middle::query::TyCtxtAt;
+use rustc_middle::ty;
+use rustc_middle::ty::layout::TyAndLayout;
+use rustc_span::def_id::DefId;
+
+/// Macro for machine-specific `InterpError` without allocation.
+/// (These will never be shown to the user, but they help diagnose ICEs.)
+pub macro throw_machine_stop_str($($tt:tt)*) {{
+    // We make a new local type for it. The type itself does not carry any information,
+    // but its vtable (for the `MachineStopType` trait) does.
+    #[derive(Debug)]
+    struct Zst;
+    // Printing this type shows the desired string.
+    impl std::fmt::Display for Zst {
+        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+            write!(f, $($tt)*)
+        }
+    }
+
+    impl rustc_middle::mir::interpret::MachineStopType for Zst {
+        fn diagnostic_message(&self) -> rustc_errors::DiagMessage {
+            self.to_string().into()
+        }
+
+        fn add_args(
+            self: Box<Self>,
+            _: &mut dyn FnMut(rustc_errors::DiagArgName, rustc_errors::DiagArgValue),
+        ) {}
+    }
+    throw_machine_stop!(Zst)
+}}
+
+pub struct DummyMachine;
+
+impl HasStaticRootDefId for DummyMachine {
+    fn static_def_id(&self) -> Option<rustc_hir::def_id::LocalDefId> {
+        None
+    }
+}
+
+impl<'mir, 'tcx: 'mir> interpret::Machine<'mir, 'tcx> for DummyMachine {
+    interpret::compile_time_machine!(<'mir, 'tcx>);
+    type MemoryKind = !;
+    const PANIC_ON_ALLOC_FAIL: bool = true;
+
+    #[inline(always)]
+    fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> bool {
+        false // no reason to enforce alignment
+    }
+
+    fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool {
+        false
+    }
+
+    fn before_access_global(
+        _tcx: TyCtxtAt<'tcx>,
+        _machine: &Self,
+        _alloc_id: AllocId,
+        alloc: ConstAllocation<'tcx>,
+        _static_def_id: Option<DefId>,
+        is_write: bool,
+    ) -> InterpResult<'tcx> {
+        if is_write {
+            throw_machine_stop_str!("can't write to global");
+        }
+
+        // If the static allocation is mutable, then we can't const prop it as its content
+        // might be different at runtime.
+        if alloc.inner().mutability.is_mut() {
+            throw_machine_stop_str!("can't access mutable globals in ConstProp");
+        }
+
+        Ok(())
+    }
+
+    fn find_mir_or_eval_fn(
+        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
+        _instance: ty::Instance<'tcx>,
+        _abi: rustc_target::spec::abi::Abi,
+        _args: &[interpret::FnArg<'tcx, Self::Provenance>],
+        _destination: &interpret::MPlaceTy<'tcx, Self::Provenance>,
+        _target: Option<BasicBlock>,
+        _unwind: UnwindAction,
+    ) -> interpret::InterpResult<'tcx, Option<(&'mir Body<'tcx>, ty::Instance<'tcx>)>> {
+        unimplemented!()
+    }
+
+    fn panic_nounwind(
+        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
+        _msg: &str,
+    ) -> interpret::InterpResult<'tcx> {
+        unimplemented!()
+    }
+
+    fn call_intrinsic(
+        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
+        _instance: ty::Instance<'tcx>,
+        _args: &[interpret::OpTy<'tcx, Self::Provenance>],
+        _destination: &interpret::MPlaceTy<'tcx, Self::Provenance>,
+        _target: Option<BasicBlock>,
+        _unwind: UnwindAction,
+    ) -> interpret::InterpResult<'tcx> {
+        unimplemented!()
+    }
+
+    fn assert_panic(
+        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
+        _msg: &rustc_middle::mir::AssertMessage<'tcx>,
+        _unwind: UnwindAction,
+    ) -> interpret::InterpResult<'tcx> {
+        unimplemented!()
+    }
+
+    fn binary_ptr_op(
+        ecx: &InterpCx<'mir, 'tcx, Self>,
+        bin_op: BinOp,
+        left: &interpret::ImmTy<'tcx, Self::Provenance>,
+        right: &interpret::ImmTy<'tcx, Self::Provenance>,
+    ) -> interpret::InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)> {
+        use rustc_middle::mir::BinOp::*;
+        Ok(match bin_op {
+            Eq | Ne | Lt | Le | Gt | Ge => {
+                // Types can differ, e.g. fn ptrs with different `for`.
+                assert_eq!(left.layout.abi, right.layout.abi);
+                let size = ecx.pointer_size();
+                // Just compare the bits. ScalarPairs are compared lexicographically.
+                // We thus always compare pairs and simply fill scalars up with 0.
+                // If the pointer has provenance, `to_bits` will return `Err` and we bail out.
+                let left = match **left {
+                    Immediate::Scalar(l) => (l.to_bits(size)?, 0),
+                    Immediate::ScalarPair(l1, l2) => (l1.to_bits(size)?, l2.to_bits(size)?),
+                    Immediate::Uninit => panic!("we should never see uninit data here"),
+                };
+                let right = match **right {
+                    Immediate::Scalar(r) => (r.to_bits(size)?, 0),
+                    Immediate::ScalarPair(r1, r2) => (r1.to_bits(size)?, r2.to_bits(size)?),
+                    Immediate::Uninit => panic!("we should never see uninit data here"),
+                };
+                let res = match bin_op {
+                    Eq => left == right,
+                    Ne => left != right,
+                    Lt => left < right,
+                    Le => left <= right,
+                    Gt => left > right,
+                    Ge => left >= right,
+                    _ => bug!(),
+                };
+                (ImmTy::from_bool(res, *ecx.tcx), false)
+            }
+
+            // Some more operations are possible with atomics.
+            // The return value always has the provenance of the *left* operand.
+            Add | Sub | BitOr | BitAnd | BitXor => {
+                throw_machine_stop_str!("pointer arithmetic is not handled")
+            }
+
+            _ => span_bug!(ecx.cur_span(), "Invalid operator on pointers: {:?}", bin_op),
+        })
+    }
+
+    fn expose_ptr(
+        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
+        _ptr: interpret::Pointer<Self::Provenance>,
+    ) -> interpret::InterpResult<'tcx> {
+        unimplemented!()
+    }
+
+    fn init_frame_extra(
+        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
+        _frame: interpret::Frame<'mir, 'tcx, Self::Provenance>,
+    ) -> interpret::InterpResult<
+        'tcx,
+        interpret::Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>,
+    > {
+        unimplemented!()
+    }
+
+    fn stack<'a>(
+        _ecx: &'a InterpCx<'mir, 'tcx, Self>,
+    ) -> &'a [interpret::Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>] {
+        // Return an empty stack instead of panicking, as `cur_span` uses it to evaluate constants.
+        &[]
+    }
+
+    fn stack_mut<'a>(
+        _ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
+    ) -> &'a mut Vec<interpret::Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>> {
+        unimplemented!()
+    }
+}
diff --git a/compiler/rustc_const_eval/src/const_eval/eval_queries.rs b/compiler/rustc_const_eval/src/const_eval/eval_queries.rs
index 5a1c7cc4209..16bd0296247 100644
--- a/compiler/rustc_const_eval/src/const_eval/eval_queries.rs
+++ b/compiler/rustc_const_eval/src/const_eval/eval_queries.rs
@@ -3,7 +3,7 @@ use either::{Left, Right};
 use rustc_hir::def::DefKind;
 use rustc_middle::mir::interpret::{AllocId, ErrorHandled, InterpErrorInfo};
 use rustc_middle::mir::{self, ConstAlloc, ConstValue};
-use rustc_middle::query::TyCtxtAt;
+use rustc_middle::query::{Key, TyCtxtAt};
 use rustc_middle::traits::Reveal;
 use rustc_middle::ty::layout::LayoutOf;
 use rustc_middle::ty::print::with_no_trimmed_paths;
@@ -243,6 +243,24 @@ pub(crate) fn turn_into_const_value<'tcx>(
     op_to_const(&ecx, &mplace.into(), /* for diagnostics */ false)
 }
 
+/// Computes the tag (if any) for a given type and variant.
+#[instrument(skip(tcx), level = "debug")]
+pub fn tag_for_variant_provider<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    (ty, variant_index): (Ty<'tcx>, abi::VariantIdx),
+) -> Option<ty::ScalarInt> {
+    assert!(ty.is_enum());
+
+    let ecx = InterpCx::new(
+        tcx,
+        ty.default_span(tcx),
+        ty::ParamEnv::reveal_all(),
+        crate::const_eval::DummyMachine,
+    );
+
+    ecx.tag_for_variant(ty, variant_index).unwrap().map(|(tag, _tag_field)| tag)
+}
+
 #[instrument(skip(tcx), level = "debug")]
 pub fn eval_to_const_value_raw_provider<'tcx>(
     tcx: TyCtxt<'tcx>,
diff --git a/compiler/rustc_const_eval/src/const_eval/mod.rs b/compiler/rustc_const_eval/src/const_eval/mod.rs
index d0d6adbfad0..b768c429070 100644
--- a/compiler/rustc_const_eval/src/const_eval/mod.rs
+++ b/compiler/rustc_const_eval/src/const_eval/mod.rs
@@ -7,12 +7,14 @@ use rustc_middle::ty::{self, Ty};
 
 use crate::interpret::format_interp_error;
 
+mod dummy_machine;
 mod error;
 mod eval_queries;
 mod fn_queries;
 mod machine;
 mod valtrees;
 
+pub use dummy_machine::*;
 pub use error::*;
 pub use eval_queries::*;
 pub use fn_queries::*;
diff --git a/compiler/rustc_const_eval/src/interpret/discriminant.rs b/compiler/rustc_const_eval/src/interpret/discriminant.rs
index 6d4f6d0cb3c..40469c6632c 100644
--- a/compiler/rustc_const_eval/src/interpret/discriminant.rs
+++ b/compiler/rustc_const_eval/src/interpret/discriminant.rs
@@ -2,7 +2,7 @@
 
 use rustc_middle::mir;
 use rustc_middle::ty::layout::{LayoutOf, PrimitiveExt};
-use rustc_middle::ty::{self, Ty};
+use rustc_middle::ty::{self, ScalarInt, Ty};
 use rustc_target::abi::{self, TagEncoding};
 use rustc_target::abi::{VariantIdx, Variants};
 
@@ -28,78 +28,27 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
             throw_ub!(UninhabitedEnumVariantWritten(variant_index))
         }
 
-        match dest.layout().variants {
-            abi::Variants::Single { index } => {
-                assert_eq!(index, variant_index);
-            }
-            abi::Variants::Multiple {
-                tag_encoding: TagEncoding::Direct,
-                tag: tag_layout,
-                tag_field,
-                ..
-            } => {
+        match self.tag_for_variant(dest.layout().ty, variant_index)? {
+            Some((tag, tag_field)) => {
                 // No need to validate that the discriminant here because the
-                // `TyAndLayout::for_variant()` call earlier already checks the variant is valid.
-
-                let discr_val = dest
-                    .layout()
-                    .ty
-                    .discriminant_for_variant(*self.tcx, variant_index)
-                    .unwrap()
-                    .val;
-
-                // raw discriminants for enums are isize or bigger during
-                // their computation, but the in-memory tag is the smallest possible
-                // representation
-                let size = tag_layout.size(self);
-                let tag_val = size.truncate(discr_val);
-
+                // `TyAndLayout::for_variant()` call earlier already checks the
+                // variant is valid.
                 let tag_dest = self.project_field(dest, tag_field)?;
-                self.write_scalar(Scalar::from_uint(tag_val, size), &tag_dest)?;
+                self.write_scalar(tag, &tag_dest)
             }
-            abi::Variants::Multiple {
-                tag_encoding:
-                    TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start },
-                tag: tag_layout,
-                tag_field,
-                ..
-            } => {
-                // No need to validate that the discriminant here because the
-                // `TyAndLayout::for_variant()` call earlier already checks the variant is valid.
-
-                if variant_index != untagged_variant {
-                    let variants_start = niche_variants.start().as_u32();
-                    let variant_index_relative = variant_index
-                        .as_u32()
-                        .checked_sub(variants_start)
-                        .expect("overflow computing relative variant idx");
-                    // We need to use machine arithmetic when taking into account `niche_start`:
-                    // tag_val = variant_index_relative + niche_start_val
-                    let tag_layout = self.layout_of(tag_layout.primitive().to_int_ty(*self.tcx))?;
-                    let niche_start_val = ImmTy::from_uint(niche_start, tag_layout);
-                    let variant_index_relative_val =
-                        ImmTy::from_uint(variant_index_relative, tag_layout);
-                    let tag_val = self.wrapping_binary_op(
-                        mir::BinOp::Add,
-                        &variant_index_relative_val,
-                        &niche_start_val,
-                    )?;
-                    // Write result.
-                    let niche_dest = self.project_field(dest, tag_field)?;
-                    self.write_immediate(*tag_val, &niche_dest)?;
-                } else {
-                    // The untagged variant is implicitly encoded simply by having a value that is
-                    // outside the niche variants. But what if the data stored here does not
-                    // actually encode this variant? That would be bad! So let's double-check...
-                    let actual_variant = self.read_discriminant(&dest.to_op(self)?)?;
-                    if actual_variant != variant_index {
-                        throw_ub!(InvalidNichedEnumVariantWritten { enum_ty: dest.layout().ty });
-                    }
+            None => {
+                // No need to write the tag here, because an untagged variant is
+                // implicitly encoded. For `Niche`-optimized enums, it's by
+                // simply by having a value that is outside the niche variants.
+                // But what if the data stored here does not actually encode
+                // this variant? That would be bad! So let's double-check...
+                let actual_variant = self.read_discriminant(&dest.to_op(self)?)?;
+                if actual_variant != variant_index {
+                    throw_ub!(InvalidNichedEnumVariantWritten { enum_ty: dest.layout().ty });
                 }
+                Ok(())
             }
         }
-
-        Ok(())
     }
 
     /// Read discriminant, return the runtime value as well as the variant index.
@@ -277,4 +226,77 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
         };
         Ok(ImmTy::from_scalar(discr_value, discr_layout))
     }
+
+    /// Computes the tag value and its field number (if any) of a given variant
+    /// of type `ty`.
+    pub(crate) fn tag_for_variant(
+        &self,
+        ty: Ty<'tcx>,
+        variant_index: VariantIdx,
+    ) -> InterpResult<'tcx, Option<(ScalarInt, usize)>> {
+        match self.layout_of(ty)?.variants {
+            abi::Variants::Single { index } => {
+                assert_eq!(index, variant_index);
+                Ok(None)
+            }
+
+            abi::Variants::Multiple {
+                tag_encoding: TagEncoding::Direct,
+                tag: tag_layout,
+                tag_field,
+                ..
+            } => {
+                // raw discriminants for enums are isize or bigger during
+                // their computation, but the in-memory tag is the smallest possible
+                // representation
+                let discr = self.discriminant_for_variant(ty, variant_index)?;
+                let discr_size = discr.layout.size;
+                let discr_val = discr.to_scalar().to_bits(discr_size)?;
+                let tag_size = tag_layout.size(self);
+                let tag_val = tag_size.truncate(discr_val);
+                let tag = ScalarInt::try_from_uint(tag_val, tag_size).unwrap();
+                Ok(Some((tag, tag_field)))
+            }
+
+            abi::Variants::Multiple {
+                tag_encoding: TagEncoding::Niche { untagged_variant, .. },
+                ..
+            } if untagged_variant == variant_index => {
+                // The untagged variant is implicitly encoded simply by having a
+                // value that is outside the niche variants.
+                Ok(None)
+            }
+
+            abi::Variants::Multiple {
+                tag_encoding:
+                    TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start },
+                tag: tag_layout,
+                tag_field,
+                ..
+            } => {
+                assert!(variant_index != untagged_variant);
+                let variants_start = niche_variants.start().as_u32();
+                let variant_index_relative = variant_index
+                    .as_u32()
+                    .checked_sub(variants_start)
+                    .expect("overflow computing relative variant idx");
+                // We need to use machine arithmetic when taking into account `niche_start`:
+                // tag_val = variant_index_relative + niche_start_val
+                let tag_layout = self.layout_of(tag_layout.primitive().to_int_ty(*self.tcx))?;
+                let niche_start_val = ImmTy::from_uint(niche_start, tag_layout);
+                let variant_index_relative_val =
+                    ImmTy::from_uint(variant_index_relative, tag_layout);
+                let tag = self
+                    .wrapping_binary_op(
+                        mir::BinOp::Add,
+                        &variant_index_relative_val,
+                        &niche_start_val,
+                    )?
+                    .to_scalar()
+                    .try_to_int()
+                    .unwrap();
+                Ok(Some((tag, tag_field)))
+            }
+        }
+    }
 }
diff --git a/compiler/rustc_const_eval/src/lib.rs b/compiler/rustc_const_eval/src/lib.rs
index 1e7ee208af1..633caf8d092 100644
--- a/compiler/rustc_const_eval/src/lib.rs
+++ b/compiler/rustc_const_eval/src/lib.rs
@@ -40,6 +40,7 @@ rustc_fluent_macro::fluent_messages! { "../messages.ftl" }
 
 pub fn provide(providers: &mut Providers) {
     const_eval::provide(providers);
+    providers.tag_for_variant = const_eval::tag_for_variant_provider;
     providers.eval_to_const_value_raw = const_eval::eval_to_const_value_raw_provider;
     providers.eval_to_allocation_raw = const_eval::eval_to_allocation_raw_provider;
     providers.eval_static_initializer = const_eval::eval_static_initializer_provider;
diff --git a/compiler/rustc_middle/src/query/erase.rs b/compiler/rustc_middle/src/query/erase.rs
index 33ee3371605..d3da49c26a2 100644
--- a/compiler/rustc_middle/src/query/erase.rs
+++ b/compiler/rustc_middle/src/query/erase.rs
@@ -234,6 +234,7 @@ trivial! {
     Option<rustc_middle::middle::stability::DeprecationEntry>,
     Option<rustc_middle::ty::Destructor>,
     Option<rustc_middle::ty::ImplTraitInTraitData>,
+    Option<rustc_middle::ty::ScalarInt>,
     Option<rustc_span::def_id::CrateNum>,
     Option<rustc_span::def_id::DefId>,
     Option<rustc_span::def_id::LocalDefId>,
diff --git a/compiler/rustc_middle/src/query/keys.rs b/compiler/rustc_middle/src/query/keys.rs
index 69d3974184d..3b1d1a04d6f 100644
--- a/compiler/rustc_middle/src/query/keys.rs
+++ b/compiler/rustc_middle/src/query/keys.rs
@@ -13,6 +13,7 @@ use rustc_query_system::query::DefIdCacheSelector;
 use rustc_query_system::query::{DefaultCacheSelector, SingleCacheSelector, VecCacheSelector};
 use rustc_span::symbol::{Ident, Symbol};
 use rustc_span::{Span, DUMMY_SP};
+use rustc_target::abi;
 
 /// Placeholder for `CrateNum`'s "local" counterpart
 #[derive(Copy, Clone, Debug)]
@@ -502,6 +503,14 @@ impl<'tcx> Key for (DefId, Ty<'tcx>, GenericArgsRef<'tcx>, ty::ParamEnv<'tcx>) {
     }
 }
 
+impl<'tcx> Key for (Ty<'tcx>, abi::VariantIdx) {
+    type CacheSelector = DefaultCacheSelector<Self>;
+
+    fn default_span(&self, _tcx: TyCtxt<'_>) -> Span {
+        DUMMY_SP
+    }
+}
+
 impl<'tcx> Key for (ty::Predicate<'tcx>, traits::WellFormedLoc) {
     type CacheSelector = DefaultCacheSelector<Self>;
 
diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs
index 865299e15c8..f3c0280fb81 100644
--- a/compiler/rustc_middle/src/query/mod.rs
+++ b/compiler/rustc_middle/src/query/mod.rs
@@ -1042,6 +1042,13 @@ rustc_queries! {
         }
     }
 
+    /// Computes the tag (if any) for a given type and variant.
+    query tag_for_variant(
+        key: (Ty<'tcx>, abi::VariantIdx)
+    ) -> Option<ty::ScalarInt> {
+        desc { "computing variant tag for enum" }
+    }
+
     /// Evaluates a constant and returns the computed allocation.
     ///
     /// **Do not use this** directly, use the `eval_to_const_value` or `eval_to_valtree` instead.
diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
index f456196b282..6c7a7278ca3 100644
--- a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
+++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
@@ -2,52 +2,22 @@
 //!
 //! Currently, this pass only propagates scalar values.
 
-use rustc_const_eval::interpret::{
-    HasStaticRootDefId, ImmTy, Immediate, InterpCx, OpTy, PlaceTy, PointerArithmetic, Projectable,
-};
+use rustc_const_eval::const_eval::{throw_machine_stop_str, DummyMachine};
+use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, PlaceTy, Projectable};
 use rustc_data_structures::fx::FxHashMap;
 use rustc_hir::def::DefKind;
-use rustc_middle::mir::interpret::{AllocId, ConstAllocation, InterpResult, Scalar};
+use rustc_middle::mir::interpret::{InterpResult, Scalar};
 use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
 use rustc_middle::mir::*;
-use rustc_middle::query::TyCtxtAt;
-use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
+use rustc_middle::ty::layout::LayoutOf;
 use rustc_middle::ty::{self, Ty, TyCtxt};
 use rustc_mir_dataflow::value_analysis::{
     Map, PlaceIndex, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace,
 };
 use rustc_mir_dataflow::{lattice::FlatSet, Analysis, Results, ResultsVisitor};
-use rustc_span::def_id::DefId;
 use rustc_span::DUMMY_SP;
 use rustc_target::abi::{Abi, FieldIdx, Size, VariantIdx, FIRST_VARIANT};
 
-/// Macro for machine-specific `InterpError` without allocation.
-/// (These will never be shown to the user, but they help diagnose ICEs.)
-pub(crate) macro throw_machine_stop_str($($tt:tt)*) {{
-    // We make a new local type for it. The type itself does not carry any information,
-    // but its vtable (for the `MachineStopType` trait) does.
-    #[derive(Debug)]
-    struct Zst;
-    // Printing this type shows the desired string.
-    impl std::fmt::Display for Zst {
-        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-            write!(f, $($tt)*)
-        }
-    }
-
-    impl rustc_middle::mir::interpret::MachineStopType for Zst {
-        fn diagnostic_message(&self) -> rustc_errors::DiagMessage {
-            self.to_string().into()
-        }
-
-        fn add_args(
-            self: Box<Self>,
-            _: &mut dyn FnMut(rustc_errors::DiagArgName, rustc_errors::DiagArgValue),
-        ) {}
-    }
-    throw_machine_stop!(Zst)
-}}
-
 // These constants are somewhat random guesses and have not been optimized.
 // If `tcx.sess.mir_opt_level() >= 4`, we ignore the limits (this can become very expensive).
 const BLOCK_LIMIT: usize = 100;
@@ -888,165 +858,3 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> {
         }
     }
 }
-
-pub(crate) struct DummyMachine;
-
-impl HasStaticRootDefId for DummyMachine {
-    fn static_def_id(&self) -> Option<rustc_hir::def_id::LocalDefId> {
-        None
-    }
-}
-
-impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for DummyMachine {
-    rustc_const_eval::interpret::compile_time_machine!(<'mir, 'tcx>);
-    type MemoryKind = !;
-    const PANIC_ON_ALLOC_FAIL: bool = true;
-
-    #[inline(always)]
-    fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> bool {
-        false // no reason to enforce alignment
-    }
-
-    fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool {
-        false
-    }
-
-    fn before_access_global(
-        _tcx: TyCtxtAt<'tcx>,
-        _machine: &Self,
-        _alloc_id: AllocId,
-        alloc: ConstAllocation<'tcx>,
-        _static_def_id: Option<DefId>,
-        is_write: bool,
-    ) -> InterpResult<'tcx> {
-        if is_write {
-            throw_machine_stop_str!("can't write to global");
-        }
-
-        // If the static allocation is mutable, then we can't const prop it as its content
-        // might be different at runtime.
-        if alloc.inner().mutability.is_mut() {
-            throw_machine_stop_str!("can't access mutable globals in ConstProp");
-        }
-
-        Ok(())
-    }
-
-    fn find_mir_or_eval_fn(
-        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
-        _instance: ty::Instance<'tcx>,
-        _abi: rustc_target::spec::abi::Abi,
-        _args: &[rustc_const_eval::interpret::FnArg<'tcx, Self::Provenance>],
-        _destination: &rustc_const_eval::interpret::MPlaceTy<'tcx, Self::Provenance>,
-        _target: Option<BasicBlock>,
-        _unwind: UnwindAction,
-    ) -> interpret::InterpResult<'tcx, Option<(&'mir Body<'tcx>, ty::Instance<'tcx>)>> {
-        unimplemented!()
-    }
-
-    fn panic_nounwind(
-        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
-        _msg: &str,
-    ) -> interpret::InterpResult<'tcx> {
-        unimplemented!()
-    }
-
-    fn call_intrinsic(
-        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
-        _instance: ty::Instance<'tcx>,
-        _args: &[rustc_const_eval::interpret::OpTy<'tcx, Self::Provenance>],
-        _destination: &rustc_const_eval::interpret::MPlaceTy<'tcx, Self::Provenance>,
-        _target: Option<BasicBlock>,
-        _unwind: UnwindAction,
-    ) -> interpret::InterpResult<'tcx> {
-        unimplemented!()
-    }
-
-    fn assert_panic(
-        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
-        _msg: &rustc_middle::mir::AssertMessage<'tcx>,
-        _unwind: UnwindAction,
-    ) -> interpret::InterpResult<'tcx> {
-        unimplemented!()
-    }
-
-    fn binary_ptr_op(
-        ecx: &InterpCx<'mir, 'tcx, Self>,
-        bin_op: BinOp,
-        left: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>,
-        right: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>,
-    ) -> interpret::InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)> {
-        use rustc_middle::mir::BinOp::*;
-        Ok(match bin_op {
-            Eq | Ne | Lt | Le | Gt | Ge => {
-                // Types can differ, e.g. fn ptrs with different `for`.
-                assert_eq!(left.layout.abi, right.layout.abi);
-                let size = ecx.pointer_size();
-                // Just compare the bits. ScalarPairs are compared lexicographically.
-                // We thus always compare pairs and simply fill scalars up with 0.
-                // If the pointer has provenance, `to_bits` will return `Err` and we bail out.
-                let left = match **left {
-                    Immediate::Scalar(l) => (l.to_bits(size)?, 0),
-                    Immediate::ScalarPair(l1, l2) => (l1.to_bits(size)?, l2.to_bits(size)?),
-                    Immediate::Uninit => panic!("we should never see uninit data here"),
-                };
-                let right = match **right {
-                    Immediate::Scalar(r) => (r.to_bits(size)?, 0),
-                    Immediate::ScalarPair(r1, r2) => (r1.to_bits(size)?, r2.to_bits(size)?),
-                    Immediate::Uninit => panic!("we should never see uninit data here"),
-                };
-                let res = match bin_op {
-                    Eq => left == right,
-                    Ne => left != right,
-                    Lt => left < right,
-                    Le => left <= right,
-                    Gt => left > right,
-                    Ge => left >= right,
-                    _ => bug!(),
-                };
-                (ImmTy::from_bool(res, *ecx.tcx), false)
-            }
-
-            // Some more operations are possible with atomics.
-            // The return value always has the provenance of the *left* operand.
-            Add | Sub | BitOr | BitAnd | BitXor => {
-                throw_machine_stop_str!("pointer arithmetic is not handled")
-            }
-
-            _ => span_bug!(ecx.cur_span(), "Invalid operator on pointers: {:?}", bin_op),
-        })
-    }
-
-    fn expose_ptr(
-        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
-        _ptr: interpret::Pointer<Self::Provenance>,
-    ) -> interpret::InterpResult<'tcx> {
-        unimplemented!()
-    }
-
-    fn init_frame_extra(
-        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
-        _frame: rustc_const_eval::interpret::Frame<'mir, 'tcx, Self::Provenance>,
-    ) -> interpret::InterpResult<
-        'tcx,
-        rustc_const_eval::interpret::Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>,
-    > {
-        unimplemented!()
-    }
-
-    fn stack<'a>(
-        _ecx: &'a InterpCx<'mir, 'tcx, Self>,
-    ) -> &'a [rustc_const_eval::interpret::Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>]
-    {
-        // Return an empty stack instead of panicking, as `cur_span` uses it to evaluate constants.
-        &[]
-    }
-
-    fn stack_mut<'a>(
-        _ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
-    ) -> &'a mut Vec<
-        rustc_const_eval::interpret::Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>,
-    > {
-        unimplemented!()
-    }
-}
diff --git a/compiler/rustc_mir_transform/src/gvn.rs b/compiler/rustc_mir_transform/src/gvn.rs
index a3a2108787a..bb257d84c9f 100644
--- a/compiler/rustc_mir_transform/src/gvn.rs
+++ b/compiler/rustc_mir_transform/src/gvn.rs
@@ -82,6 +82,7 @@
 //! Second, when writing constants in MIR, we do not write `Const::Slice` or `Const`
 //! that contain `AllocId`s.
 
+use rustc_const_eval::const_eval::DummyMachine;
 use rustc_const_eval::interpret::{intern_const_alloc_for_constprop, MemoryKind};
 use rustc_const_eval::interpret::{ImmTy, InterpCx, OpTy, Projectable, Scalar};
 use rustc_data_structures::fx::FxIndexSet;
@@ -101,7 +102,6 @@ use rustc_target::abi::{self, Abi, Size, VariantIdx, FIRST_VARIANT};
 use smallvec::SmallVec;
 use std::borrow::Cow;
 
-use crate::dataflow_const_prop::DummyMachine;
 use crate::ssa::{AssignedValue, SsaLocals};
 use either::Either;
 
diff --git a/compiler/rustc_mir_transform/src/jump_threading.rs b/compiler/rustc_mir_transform/src/jump_threading.rs
index 6629face940..f680f391feb 100644
--- a/compiler/rustc_mir_transform/src/jump_threading.rs
+++ b/compiler/rustc_mir_transform/src/jump_threading.rs
@@ -36,6 +36,7 @@
 //! cost by `MAX_COST`.
 
 use rustc_arena::DroplessArena;
+use rustc_const_eval::const_eval::DummyMachine;
 use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
 use rustc_data_structures::fx::FxHashSet;
 use rustc_index::bit_set::BitSet;
@@ -50,7 +51,6 @@ use rustc_span::DUMMY_SP;
 use rustc_target::abi::{TagEncoding, Variants};
 
 use crate::cost_checker::CostChecker;
-use crate::dataflow_const_prop::DummyMachine;
 
 pub struct JumpThreading;
 
diff --git a/compiler/rustc_mir_transform/src/known_panics_lint.rs b/compiler/rustc_mir_transform/src/known_panics_lint.rs
index 4bca437ea6f..c9db374bec9 100644
--- a/compiler/rustc_mir_transform/src/known_panics_lint.rs
+++ b/compiler/rustc_mir_transform/src/known_panics_lint.rs
@@ -6,6 +6,7 @@
 
 use std::fmt::Debug;
 
+use rustc_const_eval::const_eval::DummyMachine;
 use rustc_const_eval::interpret::{
     format_interp_error, ImmTy, InterpCx, InterpResult, Projectable, Scalar,
 };
@@ -20,7 +21,6 @@ use rustc_middle::ty::{self, ConstInt, ParamEnv, ScalarInt, Ty, TyCtxt, TypeVisi
 use rustc_span::Span;
 use rustc_target::abi::{Abi, FieldIdx, HasDataLayout, Size, TargetDataLayout, VariantIdx};
 
-use crate::dataflow_const_prop::DummyMachine;
 use crate::errors::{AssertLint, AssertLintKind};
 use crate::MirLint;
 
diff --git a/compiler/rustc_transmute/src/layout/tree.rs b/compiler/rustc_transmute/src/layout/tree.rs
index 9a43d67d435..f6bc224c7e7 100644
--- a/compiler/rustc_transmute/src/layout/tree.rs
+++ b/compiler/rustc_transmute/src/layout/tree.rs
@@ -174,10 +174,10 @@ pub(crate) mod rustc {
     use crate::layout::rustc::{Def, Ref};
 
     use rustc_middle::ty::layout::LayoutError;
-    use rustc_middle::ty::util::Discr;
     use rustc_middle::ty::AdtDef;
     use rustc_middle::ty::GenericArgsRef;
     use rustc_middle::ty::ParamEnv;
+    use rustc_middle::ty::ScalarInt;
     use rustc_middle::ty::VariantDef;
     use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt};
     use rustc_span::ErrorGuaranteed;
@@ -331,14 +331,15 @@ pub(crate) mod rustc {
                             trace!(?adt_def, "treeifying enum");
                             let mut tree = Tree::uninhabited();
 
-                            for (idx, discr) in adt_def.discriminants(tcx) {
+                            for (idx, variant) in adt_def.variants().iter_enumerated() {
+                                let tag = tcx.tag_for_variant((ty, idx));
                                 tree = tree.or(Self::from_repr_c_variant(
                                     ty,
                                     *adt_def,
                                     args_ref,
                                     &layout_summary,
-                                    Some(discr),
-                                    adt_def.variant(idx),
+                                    tag,
+                                    variant,
                                     tcx,
                                 )?);
                             }
@@ -393,7 +394,7 @@ pub(crate) mod rustc {
             adt_def: AdtDef<'tcx>,
             args_ref: GenericArgsRef<'tcx>,
             layout_summary: &LayoutSummary,
-            discr: Option<Discr<'tcx>>,
+            tag: Option<ScalarInt>,
             variant_def: &'tcx VariantDef,
             tcx: TyCtxt<'tcx>,
         ) -> Result<Self, Err> {
@@ -403,9 +404,6 @@ pub(crate) mod rustc {
             let min_align = repr.align.unwrap_or(Align::ONE);
             let max_align = repr.pack.unwrap_or(Align::MAX);
 
-            let clamp =
-                |align: Align| align.clamp(min_align, max_align).bytes().try_into().unwrap();
-
             let variant_span = trace_span!(
                 "treeifying variant",
                 min_align = ?min_align,
@@ -419,17 +417,12 @@ pub(crate) mod rustc {
             )
             .unwrap();
 
-            // The layout of the variant is prefixed by the discriminant, if any.
-            if let Some(discr) = discr {
-                trace!(?discr, "treeifying discriminant");
-                let discr_layout = alloc::Layout::from_size_align(
-                    layout_summary.discriminant_size,
-                    clamp(layout_summary.discriminant_align),
-                )
-                .unwrap();
-                trace!(?discr_layout, "computed discriminant layout");
-                variant_layout = variant_layout.extend(discr_layout).unwrap().0;
-                tree = tree.then(Self::from_discr(discr, tcx, layout_summary.discriminant_size));
+            // The layout of the variant is prefixed by the tag, if any.
+            if let Some(tag) = tag {
+                let tag_layout =
+                    alloc::Layout::from_size_align(tag.size().bytes_usize(), 1).unwrap();
+                tree = tree.then(Self::from_tag(tag, tcx));
+                variant_layout = variant_layout.extend(tag_layout).unwrap().0;
             }
 
             // Next come fields.
@@ -469,18 +462,19 @@ pub(crate) mod rustc {
             Ok(tree)
         }
 
-        pub fn from_discr(discr: Discr<'tcx>, tcx: TyCtxt<'tcx>, size: usize) -> Self {
+        pub fn from_tag(tag: ScalarInt, tcx: TyCtxt<'tcx>) -> Self {
             use rustc_target::abi::Endian;
-
+            let size = tag.size();
+            let bits = tag.to_bits(size).unwrap();
             let bytes: [u8; 16];
             let bytes = match tcx.data_layout.endian {
                 Endian::Little => {
-                    bytes = discr.val.to_le_bytes();
-                    &bytes[..size]
+                    bytes = bits.to_le_bytes();
+                    &bytes[..size.bytes_usize()]
                 }
                 Endian::Big => {
-                    bytes = discr.val.to_be_bytes();
-                    &bytes[bytes.len() - size..]
+                    bytes = bits.to_be_bytes();
+                    &bytes[bytes.len() - size.bytes_usize()..]
                 }
             };
             Self::Seq(bytes.iter().map(|&b| Self::from_bits(b)).collect())