about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
authorJack Wrenn <jack@wrenn.fyi>2024-08-14 20:10:28 +0000
committerJack Wrenn <jack@wrenn.fyi>2024-08-18 18:31:06 +0000
commit17995d5cc2e75b65b8bd6b77ad30b0c764500c0c (patch)
tree27bbb55672e96182eedb86e9f26229eaa5214606 /compiler
parent0f442e265c165c0a78633bef98de18517815150c (diff)
downloadrust-17995d5cc2e75b65b8bd6b77ad30b0c764500c0c.tar.gz
rust-17995d5cc2e75b65b8bd6b77ad30b0c764500c0c.zip
safe transmute: forbid reference lifetime extension
Modifies `BikeshedIntrinsicFrom` to forbid lifetime extensions on
references. This static check can be opted out of with the
`Assume::lifetimes` flag.

Fixes #129097
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_trait_selection/src/traits/select/confirmation.rs176
-rw-r--r--compiler/rustc_transmute/src/layout/mod.rs13
-rw-r--r--compiler/rustc_transmute/src/layout/tree.rs141
-rw-r--r--compiler/rustc_transmute/src/lib.rs2
-rw-r--r--compiler/rustc_transmute/src/maybe_transmutable/mod.rs7
5 files changed, 207 insertions, 132 deletions
diff --git a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
index 0d7ceca4301..f19cd19c99a 100644
--- a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
@@ -17,8 +17,7 @@ use rustc_infer::infer::{DefineOpaqueTypes, HigherRankedType, InferOk};
 use rustc_infer::traits::ObligationCauseCode;
 use rustc_middle::traits::{BuiltinImplSource, SignatureMismatchData};
 use rustc_middle::ty::{
-    self, GenericArgs, GenericArgsRef, GenericParamDefKind, ToPolyTraitRef, TraitPredicate, Ty,
-    TyCtxt, Upcast,
+    self, GenericArgs, GenericArgsRef, GenericParamDefKind, ToPolyTraitRef, Ty, TyCtxt, Upcast,
 };
 use rustc_middle::{bug, span_bug};
 use rustc_span::def_id::DefId;
@@ -292,90 +291,120 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         &mut self,
         obligation: &PolyTraitObligation<'tcx>,
     ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
-        use rustc_transmute::{Answer, Condition};
-        #[instrument(level = "debug", skip(tcx, obligation, predicate))]
+        use rustc_transmute::{Answer, Assume, Condition};
+
+        /// Generate sub-obligations for reference-to-reference transmutations.
+        fn reference_obligations<'tcx>(
+            tcx: TyCtxt<'tcx>,
+            obligation: &PolyTraitObligation<'tcx>,
+            (src_lifetime, src_ty, src_mut): (ty::Region<'tcx>, Ty<'tcx>, Mutability),
+            (dst_lifetime, dst_ty, dst_mut): (ty::Region<'tcx>, Ty<'tcx>, Mutability),
+            assume: Assume,
+        ) -> Vec<PredicateObligation<'tcx>> {
+            let make_transmute_obl = |src, dst| {
+                let transmute_trait = obligation.predicate.def_id();
+                let assume = obligation.predicate.skip_binder().trait_ref.args.const_at(2);
+                let trait_ref = ty::TraitRef::new(
+                    tcx,
+                    transmute_trait,
+                    [
+                        ty::GenericArg::from(dst),
+                        ty::GenericArg::from(src),
+                        ty::GenericArg::from(assume),
+                    ],
+                );
+                Obligation::with_depth(
+                    tcx,
+                    obligation.cause.clone(),
+                    obligation.recursion_depth + 1,
+                    obligation.param_env,
+                    obligation.predicate.rebind(trait_ref),
+                )
+            };
+
+            let make_freeze_obl = |ty| {
+                let trait_ref = ty::TraitRef::new(
+                    tcx,
+                    tcx.require_lang_item(LangItem::Freeze, None),
+                    [ty::GenericArg::from(ty)],
+                );
+                Obligation::with_depth(
+                    tcx,
+                    obligation.cause.clone(),
+                    obligation.recursion_depth + 1,
+                    obligation.param_env,
+                    trait_ref,
+                )
+            };
+
+            let make_outlives_obl = |target, region| {
+                let outlives = ty::OutlivesPredicate(target, region);
+                Obligation::with_depth(
+                    tcx,
+                    obligation.cause.clone(),
+                    obligation.recursion_depth + 1,
+                    obligation.param_env,
+                    obligation.predicate.rebind(outlives),
+                )
+            };
+
+            // Given a transmutation from `&'a (mut) Src` and `&'dst (mut) Dst`,
+            // it is always the case that `Src` must be transmutable into `Dst`,
+            // and that that `'src` must outlive `'dst`.
+            let mut obls = vec![make_transmute_obl(src_ty, dst_ty)];
+            if !assume.lifetimes {
+                obls.push(make_outlives_obl(src_lifetime, dst_lifetime));
+            }
+
+            // Given a transmutation from `&Src`, both `Src` and `Dst` must be
+            // `Freeze`, otherwise, using the transmuted value could lead to
+            // data races.
+            if src_mut == Mutability::Not {
+                obls.extend([make_freeze_obl(src_ty), make_freeze_obl(dst_ty)])
+            }
+
+            // Given a transmutation into `&'dst mut Dst`, it also must be the
+            // case that `Dst` is transmutable into `Src`. For example,
+            // transmuting bool -> u8 is OK as long as you can't update that u8
+            // to be > 1, because you could later transmute the u8 back to a
+            // bool and get undefined behavior. It also must be the case that
+            // `'dst` lives exactly as long as `'src`.
+            if dst_mut == Mutability::Mut {
+                obls.push(make_transmute_obl(dst_ty, src_ty));
+                if !assume.lifetimes {
+                    obls.push(make_outlives_obl(dst_lifetime, src_lifetime));
+                }
+            }
+
+            obls
+        }
+
+        /// Flatten the `Condition` tree into a conjunction of obligations.
+        #[instrument(level = "debug", skip(tcx, obligation))]
         fn flatten_answer_tree<'tcx>(
             tcx: TyCtxt<'tcx>,
             obligation: &PolyTraitObligation<'tcx>,
-            predicate: TraitPredicate<'tcx>,
             cond: Condition<rustc_transmute::layout::rustc::Ref<'tcx>>,
+            assume: Assume,
         ) -> Vec<PredicateObligation<'tcx>> {
             match cond {
                 // FIXME(bryangarza): Add separate `IfAny` case, instead of treating as `IfAll`
                 // Not possible until the trait solver supports disjunctions of obligations
                 Condition::IfAll(conds) | Condition::IfAny(conds) => conds
                     .into_iter()
-                    .flat_map(|cond| flatten_answer_tree(tcx, obligation, predicate, cond))
+                    .flat_map(|cond| flatten_answer_tree(tcx, obligation, cond, assume))
                     .collect(),
-                Condition::IfTransmutable { src, dst } => {
-                    let transmute_trait = obligation.predicate.def_id();
-                    let assume_const = predicate.trait_ref.args.const_at(2);
-                    let make_transmute_obl = |from_ty, to_ty| {
-                        let trait_ref = ty::TraitRef::new(
-                            tcx,
-                            transmute_trait,
-                            [
-                                ty::GenericArg::from(to_ty),
-                                ty::GenericArg::from(from_ty),
-                                ty::GenericArg::from(assume_const),
-                            ],
-                        );
-                        Obligation::with_depth(
-                            tcx,
-                            obligation.cause.clone(),
-                            obligation.recursion_depth + 1,
-                            obligation.param_env,
-                            trait_ref,
-                        )
-                    };
-
-                    let make_freeze_obl = |ty| {
-                        let trait_ref = ty::TraitRef::new(
-                            tcx,
-                            tcx.require_lang_item(LangItem::Freeze, None),
-                            [ty::GenericArg::from(ty)],
-                        );
-                        Obligation::with_depth(
-                            tcx,
-                            obligation.cause.clone(),
-                            obligation.recursion_depth + 1,
-                            obligation.param_env,
-                            trait_ref,
-                        )
-                    };
-
-                    let mut obls = vec![];
-
-                    // If the source is a shared reference, it must be `Freeze`;
-                    // otherwise, transmuting could lead to data races.
-                    if src.mutability == Mutability::Not {
-                        obls.extend([make_freeze_obl(src.ty), make_freeze_obl(dst.ty)])
-                    }
-
-                    // If Dst is mutable, check bidirectionally.
-                    // For example, transmuting bool -> u8 is OK as long as you can't update that u8
-                    // to be > 1, because you could later transmute the u8 back to a bool and get UB.
-                    match dst.mutability {
-                        Mutability::Not => obls.push(make_transmute_obl(src.ty, dst.ty)),
-                        Mutability::Mut => obls.extend([
-                            make_transmute_obl(src.ty, dst.ty),
-                            make_transmute_obl(dst.ty, src.ty),
-                        ]),
-                    }
-
-                    obls
-                }
+                Condition::IfTransmutable { src, dst } => reference_obligations(
+                    tcx,
+                    obligation,
+                    (src.lifetime, src.ty, src.mutability),
+                    (dst.lifetime, dst.ty, dst.mutability),
+                    assume,
+                ),
             }
         }
 
-        // We erase regions here because transmutability calls layout queries,
-        // which does not handle inference regions and doesn't particularly
-        // care about other regions. Erasing late-bound regions is equivalent
-        // to instantiating the binder with placeholders then erasing those
-        // placeholder regions.
-        let predicate = self
-            .tcx()
-            .erase_regions(self.tcx().instantiate_bound_regions_with_erased(obligation.predicate));
+        let predicate = obligation.predicate.skip_binder();
 
         let Some(assume) = rustc_transmute::Assume::from_const(
             self.infcx.tcx,
@@ -387,6 +416,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
 
         let dst = predicate.trait_ref.args.type_at(0);
         let src = predicate.trait_ref.args.type_at(1);
+
         debug!(?src, ?dst);
         let mut transmute_env = rustc_transmute::TransmuteTypeEnv::new(self.infcx);
         let maybe_transmutable = transmute_env.is_transmutable(
@@ -397,7 +427,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
 
         let fully_flattened = match maybe_transmutable {
             Answer::No(_) => Err(Unimplemented)?,
-            Answer::If(cond) => flatten_answer_tree(self.tcx(), obligation, predicate, cond),
+            Answer::If(cond) => flatten_answer_tree(self.tcx(), obligation, cond, assume),
             Answer::Yes => vec![],
         };
 
diff --git a/compiler/rustc_transmute/src/layout/mod.rs b/compiler/rustc_transmute/src/layout/mod.rs
index bbf155581f9..1cf9e0b9b70 100644
--- a/compiler/rustc_transmute/src/layout/mod.rs
+++ b/compiler/rustc_transmute/src/layout/mod.rs
@@ -63,7 +63,9 @@ pub mod rustc {
     use std::fmt::{self, Write};
 
     use rustc_middle::mir::Mutability;
-    use rustc_middle::ty::{self, Ty};
+    use rustc_middle::ty::layout::{LayoutCx, LayoutError};
+    use rustc_middle::ty::{self, Ty, TyCtxt};
+    use rustc_target::abi::Layout;
 
     /// A reference in the layout.
     #[derive(Debug, Hash, Eq, PartialEq, Clone, Copy)]
@@ -120,4 +122,13 @@ pub mod rustc {
             self != &Self::Primitive
         }
     }
+
+    pub(crate) fn layout_of<'tcx>(
+        cx: LayoutCx<'tcx, TyCtxt<'tcx>>,
+        ty: Ty<'tcx>,
+    ) -> Result<Layout<'tcx>, &'tcx LayoutError<'tcx>> {
+        use rustc_middle::ty::layout::LayoutOf;
+        let ty = cx.tcx.erase_regions(ty);
+        cx.layout_of(ty).map(|tl| tl.layout)
+    }
 }
diff --git a/compiler/rustc_transmute/src/layout/tree.rs b/compiler/rustc_transmute/src/layout/tree.rs
index 5c25f913ffe..7c73f74e629 100644
--- a/compiler/rustc_transmute/src/layout/tree.rs
+++ b/compiler/rustc_transmute/src/layout/tree.rs
@@ -171,10 +171,12 @@ where
 
 #[cfg(feature = "rustc")]
 pub(crate) mod rustc {
-    use rustc_middle::ty::layout::{HasTyCtxt, LayoutCx, LayoutError, LayoutOf};
+    use rustc_middle::ty::layout::{HasTyCtxt, LayoutCx, LayoutError};
     use rustc_middle::ty::{self, AdtDef, AdtKind, List, ScalarInt, Ty, TyCtxt, TypeVisitableExt};
     use rustc_span::ErrorGuaranteed;
-    use rustc_target::abi::{FieldsShape, Size, TyAndLayout, Variants};
+    use rustc_target::abi::{
+        FieldIdx, FieldsShape, Layout, Size, TyAndLayout, VariantIdx, Variants,
+    };
 
     use super::Tree;
     use crate::layout::rustc::{Def, Ref};
@@ -202,20 +204,18 @@ pub(crate) mod rustc {
     }
 
     impl<'tcx> Tree<Def<'tcx>, Ref<'tcx>> {
-        pub fn from_ty(
-            ty_and_layout: TyAndLayout<'tcx, Ty<'tcx>>,
-            cx: LayoutCx<'tcx, TyCtxt<'tcx>>,
-        ) -> Result<Self, Err> {
+        pub fn from_ty(ty: Ty<'tcx>, cx: LayoutCx<'tcx, TyCtxt<'tcx>>) -> Result<Self, Err> {
             use rustc_target::abi::HasDataLayout;
+            let layout = ty_layout(cx, ty);
 
-            if let Err(e) = ty_and_layout.ty.error_reported() {
+            if let Err(e) = ty.error_reported() {
                 return Err(Err::TypeError(e));
             }
 
             let target = cx.tcx.data_layout();
             let pointer_size = target.pointer_size;
 
-            match ty_and_layout.ty.kind() {
+            match ty.kind() {
                 ty::Bool => Ok(Self::bool()),
 
                 ty::Float(nty) => {
@@ -233,32 +233,30 @@ pub(crate) mod rustc {
                     Ok(Self::number(width as _))
                 }
 
-                ty::Tuple(members) => Self::from_tuple(ty_and_layout, members, cx),
+                ty::Tuple(members) => Self::from_tuple((ty, layout), members, cx),
 
                 ty::Array(inner_ty, len) => {
-                    let FieldsShape::Array { stride, count } = &ty_and_layout.fields else {
+                    let FieldsShape::Array { stride, count } = &layout.fields else {
                         return Err(Err::NotYetSupported);
                     };
-                    let inner_ty_and_layout = cx.layout_of(*inner_ty)?;
-                    assert_eq!(*stride, inner_ty_and_layout.size);
-                    let elt = Tree::from_ty(inner_ty_and_layout, cx)?;
+                    let inner_layout = ty_layout(cx, *inner_ty);
+                    assert_eq!(*stride, inner_layout.size);
+                    let elt = Tree::from_ty(*inner_ty, cx)?;
                     Ok(std::iter::repeat(elt)
                         .take(*count as usize)
                         .fold(Tree::unit(), |tree, elt| tree.then(elt)))
                 }
 
-                ty::Adt(adt_def, _args_ref) if !ty_and_layout.ty.is_box() => {
-                    match adt_def.adt_kind() {
-                        AdtKind::Struct => Self::from_struct(ty_and_layout, *adt_def, cx),
-                        AdtKind::Enum => Self::from_enum(ty_and_layout, *adt_def, cx),
-                        AdtKind::Union => Self::from_union(ty_and_layout, *adt_def, cx),
-                    }
-                }
+                ty::Adt(adt_def, _args_ref) if !ty.is_box() => match adt_def.adt_kind() {
+                    AdtKind::Struct => Self::from_struct((ty, layout), *adt_def, cx),
+                    AdtKind::Enum => Self::from_enum((ty, layout), *adt_def, cx),
+                    AdtKind::Union => Self::from_union((ty, layout), *adt_def, cx),
+                },
 
                 ty::Ref(lifetime, ty, mutability) => {
-                    let ty_and_layout = cx.layout_of(*ty)?;
-                    let align = ty_and_layout.align.abi.bytes_usize();
-                    let size = ty_and_layout.size.bytes_usize();
+                    let layout = ty_layout(cx, *ty);
+                    let align = layout.align.abi.bytes_usize();
+                    let size = layout.size.bytes_usize();
                     Ok(Tree::Ref(Ref {
                         lifetime: *lifetime,
                         ty: *ty,
@@ -274,21 +272,20 @@ pub(crate) mod rustc {
 
         /// Constructs a `Tree` from a tuple.
         fn from_tuple(
-            ty_and_layout: TyAndLayout<'tcx, Ty<'tcx>>,
+            (ty, layout): (Ty<'tcx>, Layout<'tcx>),
             members: &'tcx List<Ty<'tcx>>,
             cx: LayoutCx<'tcx, TyCtxt<'tcx>>,
         ) -> Result<Self, Err> {
-            match &ty_and_layout.fields {
+            match &layout.fields {
                 FieldsShape::Primitive => {
                     assert_eq!(members.len(), 1);
                     let inner_ty = members[0];
-                    let inner_ty_and_layout = cx.layout_of(inner_ty)?;
-                    assert_eq!(ty_and_layout.layout, inner_ty_and_layout.layout);
-                    Self::from_ty(inner_ty_and_layout, cx)
+                    let inner_layout = ty_layout(cx, inner_ty);
+                    Self::from_ty(inner_ty, cx)
                 }
                 FieldsShape::Arbitrary { offsets, .. } => {
                     assert_eq!(offsets.len(), members.len());
-                    Self::from_variant(Def::Primitive, None, ty_and_layout, ty_and_layout.size, cx)
+                    Self::from_variant(Def::Primitive, None, (ty, layout), layout.size, cx)
                 }
                 FieldsShape::Array { .. } | FieldsShape::Union(_) => Err(Err::NotYetSupported),
             }
@@ -300,13 +297,13 @@ pub(crate) mod rustc {
         ///
         /// Panics if `def` is not a struct definition.
         fn from_struct(
-            ty_and_layout: TyAndLayout<'tcx, Ty<'tcx>>,
+            (ty, layout): (Ty<'tcx>, Layout<'tcx>),
             def: AdtDef<'tcx>,
             cx: LayoutCx<'tcx, TyCtxt<'tcx>>,
         ) -> Result<Self, Err> {
             assert!(def.is_struct());
             let def = Def::Adt(def);
-            Self::from_variant(def, None, ty_and_layout, ty_and_layout.size, cx)
+            Self::from_variant(def, None, (ty, layout), layout.size, cx)
         }
 
         /// Constructs a `Tree` from an enum.
@@ -315,19 +312,18 @@ pub(crate) mod rustc {
         ///
         /// Panics if `def` is not an enum definition.
         fn from_enum(
-            ty_and_layout: TyAndLayout<'tcx, Ty<'tcx>>,
+            (ty, layout): (Ty<'tcx>, Layout<'tcx>),
             def: AdtDef<'tcx>,
             cx: LayoutCx<'tcx, TyCtxt<'tcx>>,
         ) -> Result<Self, Err> {
             assert!(def.is_enum());
-            let layout = ty_and_layout.layout;
 
             // Computes the variant of a given index.
             let layout_of_variant = |index| {
-                let tag = cx.tcx.tag_for_variant((ty_and_layout.ty, index));
+                let tag = cx.tcx.tag_for_variant((cx.tcx.erase_regions(ty), index));
                 let variant_def = Def::Variant(def.variant(index));
-                let variant_ty_and_layout = ty_and_layout.for_variant(&cx, index);
-                Self::from_variant(variant_def, tag, variant_ty_and_layout, layout.size, cx)
+                let variant_layout = ty_variant(cx, (ty, layout), index);
+                Self::from_variant(variant_def, tag, (ty, variant_layout), layout.size, cx)
             };
 
             // We consider three kinds of enums, each demanding a different
@@ -385,21 +381,20 @@ pub(crate) mod rustc {
         fn from_variant(
             def: Def<'tcx>,
             tag: Option<ScalarInt>,
-            ty_and_layout: TyAndLayout<'tcx, Ty<'tcx>>,
+            (ty, layout): (Ty<'tcx>, Layout<'tcx>),
             total_size: Size,
             cx: LayoutCx<'tcx, TyCtxt<'tcx>>,
         ) -> Result<Self, Err> {
             // This constructor does not support non-`FieldsShape::Arbitrary`
             // layouts.
-            let FieldsShape::Arbitrary { offsets, memory_index } = ty_and_layout.layout.fields()
-            else {
+            let FieldsShape::Arbitrary { offsets, memory_index } = layout.fields() else {
                 return Err(Err::NotYetSupported);
             };
 
             // When this function is invoked with enum variants,
             // `ty_and_layout.size` does not encompass the entire size of the
             // enum. We rely on `total_size` for this.
-            assert!(ty_and_layout.size <= total_size);
+            assert!(layout.size <= total_size);
 
             let mut size = Size::ZERO;
             let mut struct_tree = Self::def(def);
@@ -412,17 +407,18 @@ pub(crate) mod rustc {
 
             // Append the fields, in memory order, to the layout.
             let inverse_memory_index = memory_index.invert_bijective_mapping();
-            for (memory_idx, field_idx) in inverse_memory_index.iter_enumerated() {
+            for (memory_idx, &field_idx) in inverse_memory_index.iter_enumerated() {
                 // Add interfield padding.
-                let padding_needed = offsets[*field_idx] - size;
+                let padding_needed = offsets[field_idx] - size;
                 let padding = Self::padding(padding_needed.bytes_usize());
 
-                let field_ty_and_layout = ty_and_layout.field(&cx, field_idx.as_usize());
-                let field_tree = Self::from_ty(field_ty_and_layout, cx)?;
+                let field_ty = ty_field(cx, (ty, layout), field_idx);
+                let field_layout = ty_layout(cx, field_ty);
+                let field_tree = Self::from_ty(field_ty, cx)?;
 
                 struct_tree = struct_tree.then(padding).then(field_tree);
 
-                size += padding_needed + field_ty_and_layout.size;
+                size += padding_needed + field_layout.size;
             }
 
             // Add trailing padding.
@@ -457,28 +453,27 @@ pub(crate) mod rustc {
         ///
         /// Panics if `def` is not a union definition.
         fn from_union(
-            ty_and_layout: TyAndLayout<'tcx, Ty<'tcx>>,
+            (ty, layout): (Ty<'tcx>, Layout<'tcx>),
             def: AdtDef<'tcx>,
             cx: LayoutCx<'tcx, TyCtxt<'tcx>>,
         ) -> Result<Self, Err> {
             assert!(def.is_union());
 
-            let union_layout = ty_and_layout.layout;
-
             // This constructor does not support non-`FieldsShape::Union`
             // layouts. Fields of this shape are all placed at offset 0.
-            let FieldsShape::Union(fields) = union_layout.fields() else {
+            let FieldsShape::Union(fields) = layout.fields() else {
                 return Err(Err::NotYetSupported);
             };
 
             let fields = &def.non_enum_variant().fields;
             let fields = fields.iter_enumerated().try_fold(
                 Self::uninhabited(),
-                |fields, (idx, ref field_def)| {
+                |fields, (idx, field_def)| {
                     let field_def = Def::Field(field_def);
-                    let field_ty_and_layout = ty_and_layout.field(&cx, idx.as_usize());
-                    let field = Self::from_ty(field_ty_and_layout, cx)?;
-                    let trailing_padding_needed = union_layout.size - field_ty_and_layout.size;
+                    let field_ty = ty_field(cx, (ty, layout), idx);
+                    let field_layout = ty_layout(cx, field_ty);
+                    let field = Self::from_ty(field_ty, cx)?;
+                    let trailing_padding_needed = layout.size - field_layout.size;
                     let trailing_padding = Self::padding(trailing_padding_needed.bytes_usize());
                     let field_and_padding = field.then(trailing_padding);
                     Result::<Self, Err>::Ok(fields.or(field_and_padding))
@@ -488,4 +483,44 @@ pub(crate) mod rustc {
             Ok(Self::def(Def::Adt(def)).then(fields))
         }
     }
+
+    pub(crate) fn ty_layout<'tcx>(cx: LayoutCx<'tcx, TyCtxt<'tcx>>, ty: Ty<'tcx>) -> Layout<'tcx> {
+        crate::layout::rustc::layout_of(cx, ty).unwrap()
+    }
+
+    fn ty_field<'tcx>(
+        cx: LayoutCx<'tcx, TyCtxt<'tcx>>,
+        (ty, layout): (Ty<'tcx>, Layout<'tcx>),
+        i: FieldIdx,
+    ) -> Ty<'tcx> {
+        match ty.kind() {
+            ty::Adt(def, args) => {
+                match layout.variants {
+                    Variants::Single { index } => {
+                        let field = &def.variant(index).fields[i];
+                        field.ty(cx.tcx, args)
+                    }
+                    // Discriminant field for enums (where applicable).
+                    Variants::Multiple { tag, .. } => {
+                        assert_eq!(i.as_usize(), 0);
+                        ty::layout::PrimitiveExt::to_ty(&tag.primitive(), cx.tcx)
+                    }
+                }
+            }
+            ty::Tuple(fields) => fields[i.as_usize()],
+            kind @ _ => unimplemented!(
+                "only a subset of `Ty::ty_and_layout_field`'s functionality is implemented. implementation needed for {:?}",
+                kind
+            ),
+        }
+    }
+
+    fn ty_variant<'tcx>(
+        cx: LayoutCx<'tcx, TyCtxt<'tcx>>,
+        (ty, layout): (Ty<'tcx>, Layout<'tcx>),
+        i: VariantIdx,
+    ) -> Layout<'tcx> {
+        let ty = cx.tcx.erase_regions(ty);
+        TyAndLayout { ty, layout }.for_variant(&cx, i).layout
+    }
 }
diff --git a/compiler/rustc_transmute/src/lib.rs b/compiler/rustc_transmute/src/lib.rs
index 31664ee6c4f..bdc98bcea5e 100644
--- a/compiler/rustc_transmute/src/lib.rs
+++ b/compiler/rustc_transmute/src/lib.rs
@@ -9,7 +9,7 @@ pub(crate) use rustc_data_structures::fx::{FxIndexMap as Map, FxIndexSet as Set}
 pub mod layout;
 mod maybe_transmutable;
 
-#[derive(Default)]
+#[derive(Copy, Clone, Debug, Default)]
 pub struct Assume {
     pub alignment: bool,
     pub lifetimes: bool,
diff --git a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
index 7c66a827db9..1f3c4e3c817 100644
--- a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
+++ b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
@@ -30,7 +30,7 @@ where
 // FIXME: Nix this cfg, so we can write unit tests independently of rustc
 #[cfg(feature = "rustc")]
 mod rustc {
-    use rustc_middle::ty::layout::{LayoutCx, LayoutOf};
+    use rustc_middle::ty::layout::LayoutCx;
     use rustc_middle::ty::{ParamEnv, Ty, TyCtxt};
 
     use super::*;
@@ -45,10 +45,9 @@ mod rustc {
 
             let layout_cx = LayoutCx { tcx: context, param_env: ParamEnv::reveal_all() };
             let layout_of = |ty| {
-                layout_cx
-                    .layout_of(ty)
+                crate::layout::rustc::layout_of(layout_cx, ty)
                     .map_err(|_| Err::NotYetSupported)
-                    .and_then(|tl| Tree::from_ty(tl, layout_cx))
+                    .and_then(|_| Tree::from_ty(ty, layout_cx))
             };
 
             // Convert `src` and `dst` from their rustc representations, to `Tree`-based