about summary refs log tree commit diff
path: root/compiler/rustc_transmute/src
diff options
context:
space:
mode:
authorBryan Garza <1396101+bryangarza@users.noreply.github.com>2023-04-21 16:49:36 -0700
committerBryan Garza <1396101+bryangarza@users.noreply.github.com>2023-05-24 14:52:18 -0700
commit8f1cec8d8472c3ffacedd4783c64182a407c72df (patch)
tree79f6bc12055896112e0b0898c13042411638847d /compiler/rustc_transmute/src
parent97d328012b9ed9b7d481c40e84aa1f2c65b33ec8 (diff)
downloadrust-8f1cec8d8472c3ffacedd4783c64182a407c72df.tar.gz
rust-8f1cec8d8472c3ffacedd4783c64182a407c72df.zip
Safe Transmute: Enable handling references, including recursive types
This patch enables support for references in Safe Transmute, by generating
nested obligations during trait selection. Specifically, when we call
`confirm_transmutability_candidate(...)`, we now recursively traverse the
`rustc_transmute::Answer` tree and create obligations for all the `Answer`
variants, some of which include multiple nested `Answer`s.

Also, to handle recursive types, enable support for coinduction for the Safe
Transmute trait (`BikeshedIntrinsicFrom`) by adding the `#[rustc_coinduction]`
annotation.

Also fix some small logic issues when reducing the `or` and `and` combinations
in `rustc_transmute`, so that we don't end up with additional redundant
`Answer`s in the tree.

Co-authored-by: Jack Wrenn <jack@wrenn.fyi>
Diffstat (limited to 'compiler/rustc_transmute/src')
-rw-r--r--compiler/rustc_transmute/src/layout/mod.rs37
-rw-r--r--compiler/rustc_transmute/src/layout/tree.rs11
-rw-r--r--compiler/rustc_transmute/src/lib.rs11
-rw-r--r--compiler/rustc_transmute/src/maybe_transmutable/mod.rs170
4 files changed, 178 insertions, 51 deletions
diff --git a/compiler/rustc_transmute/src/layout/mod.rs b/compiler/rustc_transmute/src/layout/mod.rs
index f8d05bc89d2..b318447e581 100644
--- a/compiler/rustc_transmute/src/layout/mod.rs
+++ b/compiler/rustc_transmute/src/layout/mod.rs
@@ -30,33 +30,46 @@ impl fmt::Debug for Byte {
 }
 
 pub(crate) trait Def: Debug + Hash + Eq + PartialEq + Copy + Clone {}
-pub trait Ref: Debug + Hash + Eq + PartialEq + Copy + Clone {}
+pub trait Ref: Debug + Hash + Eq + PartialEq + Copy + Clone {
+    fn min_align(&self) -> usize {
+        1
+    }
+
+    fn is_mutable(&self) -> bool {
+        false
+    }
+}
 
 impl Def for ! {}
 impl Ref for ! {}
 
 #[cfg(feature = "rustc")]
-pub(crate) mod rustc {
+pub mod rustc {
     use rustc_middle::mir::Mutability;
-    use rustc_middle::ty;
-    use rustc_middle::ty::Region;
-    use rustc_middle::ty::Ty;
+    use rustc_middle::ty::{self, Ty};
 
     /// A reference in the layout.
     #[derive(Debug, Hash, Eq, PartialEq, PartialOrd, Ord, Clone, Copy)]
     pub struct Ref<'tcx> {
-        lifetime: Region<'tcx>,
-        ty: Ty<'tcx>,
-        mutability: Mutability,
+        pub lifetime: ty::Region<'tcx>,
+        pub ty: Ty<'tcx>,
+        pub mutability: Mutability,
+        pub align: usize,
     }
 
-    impl<'tcx> super::Ref for Ref<'tcx> {}
+    impl<'tcx> super::Ref for Ref<'tcx> {
+        fn min_align(&self) -> usize {
+            self.align
+        }
 
-    impl<'tcx> Ref<'tcx> {
-        pub fn min_align(&self) -> usize {
-            todo!()
+        fn is_mutable(&self) -> bool {
+            match self.mutability {
+                Mutability::Mut => true,
+                Mutability::Not => false,
+            }
         }
     }
+    impl<'tcx> Ref<'tcx> {}
 
     /// A visibility node in the layout.
     #[derive(Debug, Hash, Eq, PartialEq, Clone, Copy)]
diff --git a/compiler/rustc_transmute/src/layout/tree.rs b/compiler/rustc_transmute/src/layout/tree.rs
index a6d88b1342a..ed9309b015d 100644
--- a/compiler/rustc_transmute/src/layout/tree.rs
+++ b/compiler/rustc_transmute/src/layout/tree.rs
@@ -365,6 +365,17 @@ pub(crate) mod rustc {
                         }
                     }))
                 }
+
+                ty::Ref(lifetime, ty, mutability) => {
+                    let align = layout_of(tcx, *ty)?.align();
+                    Ok(Tree::Ref(Ref {
+                        lifetime: *lifetime,
+                        ty: *ty,
+                        mutability: *mutability,
+                        align,
+                    }))
+                }
+
                 _ => Err(Err::Unspecified),
             }
         }
diff --git a/compiler/rustc_transmute/src/lib.rs b/compiler/rustc_transmute/src/lib.rs
index 77c0526e3aa..c4a99d9eb89 100644
--- a/compiler/rustc_transmute/src/lib.rs
+++ b/compiler/rustc_transmute/src/lib.rs
@@ -8,7 +8,7 @@ extern crate tracing;
 
 pub(crate) use rustc_data_structures::fx::{FxIndexMap as Map, FxIndexSet as Set};
 
-pub(crate) mod layout;
+pub mod layout;
 pub(crate) mod maybe_transmutable;
 
 #[derive(Default)]
@@ -21,10 +21,7 @@ pub struct Assume {
 
 /// The type encodes answers to the question: "Are these types transmutable?"
 #[derive(Debug, Hash, Eq, PartialEq, PartialOrd, Ord, Clone)]
-pub enum Answer<R>
-where
-    R: layout::Ref,
-{
+pub enum Answer<R> {
     /// `Src` is transmutable into `Dst`.
     Yes,
 
@@ -54,6 +51,10 @@ pub enum Reason {
     DstIsPrivate,
     /// `Dst` is larger than `Src`, and the excess bytes were not exclusively uninitialized.
     DstIsTooBig,
+    /// Src should have a stricter alignment than Dst, but it does not.
+    DstHasStricterAlignment,
+    /// Can't go from shared pointer to unique pointer
+    DstIsMoreUnique,
 }
 
 #[cfg(feature = "rustc")]
diff --git a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
index 2e2fb90e71c..d1077488c79 100644
--- a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
+++ b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
@@ -1,13 +1,13 @@
-use crate::Map;
-use crate::{Answer, Reason};
-
+pub(crate) mod query_context;
 #[cfg(test)]
 mod tests;
 
-mod query_context;
-use query_context::QueryContext;
+use crate::{
+    layout::{self, dfa, Byte, Dfa, Nfa, Ref, Tree, Uninhabited},
+    maybe_transmutable::query_context::QueryContext,
+    Answer, Map, Reason,
+};
 
-use crate::layout::{self, dfa, Byte, Dfa, Nfa, Tree, Uninhabited};
 pub(crate) struct MaybeTransmutableQuery<L, C>
 where
     C: QueryContext,
@@ -53,6 +53,7 @@ where
     }
 }
 
+// FIXME: Nix this cfg, so we can write unit tests independently of rustc
 #[cfg(feature = "rustc")]
 mod rustc {
     use super::*;
@@ -77,12 +78,11 @@ mod rustc {
                 match (src, dst) {
                     // Answer `Yes` here, because 'unknown layout' and type errors will already
                     // be reported by rustc. No need to spam the user with more errors.
-                    (Err(Err::TypeError(_)), _) => Err(Answer::Yes),
-                    (_, Err(Err::TypeError(_))) => Err(Answer::Yes),
-                    (Err(Err::Unknown), _) => Err(Answer::Yes),
-                    (_, Err(Err::Unknown)) => Err(Answer::Yes),
-                    (Err(Err::Unspecified), _) => Err(Answer::No(Reason::SrcIsUnspecified)),
-                    (_, Err(Err::Unspecified)) => Err(Answer::No(Reason::DstIsUnspecified)),
+                    (Err(Err::TypeError(_)), _) | (_, Err(Err::TypeError(_))) => Err(Answer::Yes),
+                    (Err(Err::Unknown), _) | (_, Err(Err::Unknown)) => Err(Answer::Yes),
+                    (Err(Err::Unspecified), _) | (_, Err(Err::Unspecified)) => {
+                        Err(Answer::No(Reason::SrcIsUnspecified))
+                    }
                     (Ok(src), Ok(dst)) => Ok((src, dst)),
                 }
             });
@@ -214,34 +214,99 @@ where
                     Answer::No(Reason::DstIsTooBig)
                 }
             } else {
-                let src_quantification = if self.assume.validity {
+                let src_quantifier = if self.assume.validity {
                     // if the compiler may assume that the programmer is doing additional validity checks,
                     // (e.g.: that `src != 3u8` when the destination type is `bool`)
                     // then there must exist at least one transition out of `src_state` such that the transmute is viable...
-                    there_exists
+                    Quantifier::ThereExists
                 } else {
                     // if the compiler cannot assume that the programmer is doing additional validity checks,
                     // then for all transitions out of `src_state`, such that the transmute is viable...
-                    // then there must exist at least one transition out of `src_state` such that the transmute is viable...
-                    for_all
+                    // then there must exist at least one transition out of `dst_state` such that the transmute is viable...
+                    Quantifier::ForAll
                 };
 
-                src_quantification(
-                    self.src.bytes_from(src_state).unwrap_or(&Map::default()),
-                    |(&src_validity, &src_state_prime)| {
-                        if let Some(dst_state_prime) = self.dst.byte_from(dst_state, src_validity) {
-                            self.answer_memo(cache, src_state_prime, dst_state_prime)
-                        } else if let Some(dst_state_prime) =
-                            self.dst.byte_from(dst_state, Byte::Uninit)
-                        {
-                            self.answer_memo(cache, src_state_prime, dst_state_prime)
-                        } else {
-                            Answer::No(Reason::DstIsBitIncompatible)
-                        }
-                    },
-                )
+                let bytes_answer = src_quantifier.apply(
+                    // for each of the byte transitions out of the `src_state`...
+                    self.src.bytes_from(src_state).unwrap_or(&Map::default()).into_iter().map(
+                        |(&src_validity, &src_state_prime)| {
+                            // ...try to find a matching transition out of `dst_state`.
+                            if let Some(dst_state_prime) =
+                                self.dst.byte_from(dst_state, src_validity)
+                            {
+                                self.answer_memo(cache, src_state_prime, dst_state_prime)
+                            } else if let Some(dst_state_prime) =
+                                // otherwise, see if `dst_state` has any outgoing `Uninit` transitions
+                                // (any init byte is a valid uninit byte)
+                                self.dst.byte_from(dst_state, Byte::Uninit)
+                            {
+                                self.answer_memo(cache, src_state_prime, dst_state_prime)
+                            } else {
+                                // otherwise, we've exhausted our options.
+                                // the DFAs, from this point onwards, are bit-incompatible.
+                                Answer::No(Reason::DstIsBitIncompatible)
+                            }
+                        },
+                    ),
+                );
+
+                // The below early returns reflect how this code would behave:
+                //   if self.assume.validity {
+                //       bytes_answer.or(refs_answer)
+                //   } else {
+                //       bytes_answer.and(refs_answer)
+                //   }
+                // ...if `refs_answer` was computed lazily. The below early
+                // returns can be deleted without impacting the correctness of
+                // the algoritm; only its performance.
+                match bytes_answer {
+                    Answer::No(..) if !self.assume.validity => return bytes_answer,
+                    Answer::Yes if self.assume.validity => return bytes_answer,
+                    _ => {}
+                };
+
+                let refs_answer = src_quantifier.apply(
+                    // for each reference transition out of `src_state`...
+                    self.src.refs_from(src_state).unwrap_or(&Map::default()).into_iter().map(
+                        |(&src_ref, &src_state_prime)| {
+                            // ...there exists a reference transition out of `dst_state`...
+                            Quantifier::ThereExists.apply(
+                                self.dst
+                                    .refs_from(dst_state)
+                                    .unwrap_or(&Map::default())
+                                    .into_iter()
+                                    .map(|(&dst_ref, &dst_state_prime)| {
+                                        if !src_ref.is_mutable() && dst_ref.is_mutable() {
+                                            Answer::No(Reason::DstIsMoreUnique)
+                                        } else if !self.assume.alignment
+                                            && src_ref.min_align() < dst_ref.min_align()
+                                        {
+                                            Answer::No(Reason::DstHasStricterAlignment)
+                                        } else {
+                                            // ...such that `src` is transmutable into `dst`, if
+                                            // `src_ref` is transmutability into `dst_ref`.
+                                            Answer::IfTransmutable { src: src_ref, dst: dst_ref }
+                                                .and(self.answer_memo(
+                                                    cache,
+                                                    src_state_prime,
+                                                    dst_state_prime,
+                                                ))
+                                        }
+                                    }),
+                            )
+                        },
+                    ),
+                );
+
+                if self.assume.validity {
+                    bytes_answer.or(refs_answer)
+                } else {
+                    bytes_answer.and(refs_answer)
+                }
             };
-            cache.insert((src_state, dst_state), answer.clone());
+            if let Some(..) = cache.insert((src_state, dst_state), answer.clone()) {
+                panic!("failed to correctly cache transmutability")
+            }
             answer
         }
     }
@@ -253,17 +318,21 @@ where
 {
     pub(crate) fn and(self, rhs: Self) -> Self {
         match (self, rhs) {
-            (Self::No(reason), _) | (_, Self::No(reason)) => Self::No(reason),
-            (Self::Yes, Self::Yes) => Self::Yes,
+            (_, Self::No(reason)) | (Self::No(reason), _) => Self::No(reason),
+
+            (Self::Yes, other) | (other, Self::Yes) => other,
+
             (Self::IfAll(mut lhs), Self::IfAll(ref mut rhs)) => {
                 lhs.append(rhs);
                 Self::IfAll(lhs)
             }
+
             (constraint, Self::IfAll(mut constraints))
             | (Self::IfAll(mut constraints), constraint) => {
                 constraints.push(constraint);
                 Self::IfAll(constraints)
             }
+
             (lhs, rhs) => Self::IfAll(vec![lhs, rhs]),
         }
     }
@@ -271,7 +340,7 @@ where
     pub(crate) fn or(self, rhs: Self) -> Self {
         match (self, rhs) {
             (Self::Yes, _) | (_, Self::Yes) => Self::Yes,
-            (Self::No(lhr), Self::No(rhr)) => Self::No(lhr),
+            (other, Self::No(reason)) | (Self::No(reason), other) => other,
             (Self::IfAny(mut lhs), Self::IfAny(ref mut rhs)) => {
                 lhs.append(rhs);
                 Self::IfAny(lhs)
@@ -319,3 +388,36 @@ where
     );
     result
 }
+
+pub enum Quantifier {
+    ThereExists,
+    ForAll,
+}
+
+impl Quantifier {
+    pub fn apply<R, I>(&self, iter: I) -> Answer<R>
+    where
+        R: layout::Ref,
+        I: IntoIterator<Item = Answer<R>>,
+    {
+        use std::ops::ControlFlow::{Break, Continue};
+
+        let (init, try_fold_f): (_, fn(_, _) -> _) = match self {
+            Self::ThereExists => {
+                (Answer::No(Reason::DstIsBitIncompatible), |accum: Answer<R>, next| {
+                    match accum.or(next) {
+                        Answer::Yes => Break(Answer::Yes),
+                        maybe => Continue(maybe),
+                    }
+                })
+            }
+            Self::ForAll => (Answer::Yes, |accum: Answer<R>, next| match accum.and(next) {
+                Answer::No(reason) => Break(Answer::No(reason)),
+                maybe => Continue(maybe),
+            }),
+        };
+
+        let (Continue(result) | Break(result)) = iter.into_iter().try_fold(init, try_fold_f);
+        result
+    }
+}