about summary refs log tree commit diff
path: root/compiler/rustc_transmute/src/maybe_transmutable
diff options
context:
space:
mode:
authorJoshua Liebow-Feeser <hello@joshlf.com>2025-04-10 13:45:39 -0700
committerJoshua Liebow-Feeser <hello@joshlf.com>2025-04-23 11:45:00 -0700
commit4326a44e6f0859077b7789d42416b9291b0ff4d1 (patch)
tree713ed8f802314f672e8d51725b158eb1a6c1dc9e /compiler/rustc_transmute/src/maybe_transmutable
parentbe181dd75c83d72fcc95538e235768bc367b76b9 (diff)
downloadrust-4326a44e6f0859077b7789d42416b9291b0ff4d1.tar.gz
rust-4326a44e6f0859077b7789d42416b9291b0ff4d1.zip
transmutability: Mark edges by ranges, not values
In the `Tree` and `Dfa` representations of a type's layout, store byte
ranges rather than needing to separately store each byte value. This
permits us to, for example, represent a `u8` using a single 0..=255 edge
in the DFA rather than using 256 separate edges.

This leads to drastic performance improvements. For example, on the
author's 2024 MacBook Pro, the time to convert the `Tree` representation
of a `u64` to its equivalent DFA representation drops from ~8.5ms to
~1us, a reduction of ~8,500x. See `bench_dfa_from_tree`.

Similarly, the time to execute a transmutability query from `u64` to
`u64` drops from ~35us to ~1.7us, a reduction of ~20x. See
`bench_transmute`.
Diffstat (limited to 'compiler/rustc_transmute/src/maybe_transmutable')
-rw-r--r--compiler/rustc_transmute/src/maybe_transmutable/mod.rs218
-rw-r--r--compiler/rustc_transmute/src/maybe_transmutable/query_context.rs14
-rw-r--r--compiler/rustc_transmute/src/maybe_transmutable/tests.rs183
3 files changed, 343 insertions, 72 deletions
diff --git a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
index db0e1ab8e98..0a19cccc2ed 100644
--- a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
+++ b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
@@ -1,10 +1,14 @@
+use std::rc::Rc;
+use std::{cmp, iter};
+
+use itertools::Either;
 use tracing::{debug, instrument, trace};
 
 pub(crate) mod query_context;
 #[cfg(test)]
 mod tests;
 
-use crate::layout::{self, Byte, Def, Dfa, Ref, Tree, Uninhabited, dfa};
+use crate::layout::{self, Byte, Def, Dfa, Ref, Tree, dfa};
 use crate::maybe_transmutable::query_context::QueryContext;
 use crate::{Answer, Condition, Map, Reason};
 
@@ -111,7 +115,7 @@ where
         // the `src` type do not exist.
         let src = match Dfa::from_tree(src) {
             Ok(src) => src,
-            Err(Uninhabited) => return Answer::Yes,
+            Err(layout::Uninhabited) => return Answer::Yes,
         };
 
         // Convert `dst` from a tree-based representation to an DFA-based
@@ -122,7 +126,7 @@ where
         // free of safety invariants.
         let dst = match Dfa::from_tree(dst) {
             Ok(dst) => dst,
-            Err(Uninhabited) => return Answer::No(Reason::DstMayHaveSafetyInvariants),
+            Err(layout::Uninhabited) => return Answer::No(Reason::DstMayHaveSafetyInvariants),
         };
 
         MaybeTransmutableQuery { src, dst, assume, context }.answer()
@@ -174,8 +178,8 @@ where
                 // are able to safely transmute, even with truncation.
                 Answer::Yes
             } else if src_state == self.src.accept {
-                // extension: `size_of(Src) >= size_of(Dst)`
-                if let Some(dst_state_prime) = self.dst.byte_from(dst_state, Byte::Uninit) {
+                // extension: `size_of(Src) <= size_of(Dst)`
+                if let Some(dst_state_prime) = self.dst.get_uninit_edge_dst(dst_state) {
                     self.answer_memo(cache, src_state, dst_state_prime)
                 } else {
                     Answer::No(Reason::DstIsTooBig)
@@ -193,26 +197,120 @@ where
                     Quantifier::ForAll
                 };
 
+                let c = &core::cell::RefCell::new(&mut *cache);
                 let bytes_answer = src_quantifier.apply(
-                    // for each of the byte transitions out of the `src_state`...
-                    self.src.bytes_from(src_state).unwrap_or(&Map::default()).into_iter().map(
-                        |(&src_validity, &src_state_prime)| {
-                            // ...try to find a matching transition out of `dst_state`.
-                            if let Some(dst_state_prime) =
-                                self.dst.byte_from(dst_state, src_validity)
-                            {
-                                self.answer_memo(cache, src_state_prime, dst_state_prime)
-                            } else if let Some(dst_state_prime) =
-                                // otherwise, see if `dst_state` has any outgoing `Uninit` transitions
-                                // (any init byte is a valid uninit byte)
-                                self.dst.byte_from(dst_state, Byte::Uninit)
-                            {
-                                self.answer_memo(cache, src_state_prime, dst_state_prime)
-                            } else {
-                                // otherwise, we've exhausted our options.
-                                // the DFAs, from this point onwards, are bit-incompatible.
-                                Answer::No(Reason::DstIsBitIncompatible)
+                    // for each of the byte set transitions out of the `src_state`...
+                    self.src.bytes_from(src_state).flat_map(
+                        move |(src_validity, src_state_prime)| {
+                            // ...find all matching transitions out of `dst_state`.
+
+                            let Some(src_validity) = src_validity.range() else {
+                                // NOTE: We construct an iterator here rather
+                                // than just computing the value directly (via
+                                // `self.answer_memo`) so that, if the iterator
+                                // we produce from this branch is
+                                // short-circuited, we don't waste time
+                                // computing `self.answer_memo` unnecessarily.
+                                // That will specifically happen if
+                                // `src_quantifier == Quantifier::ThereExists`,
+                                // since we emit `Answer::Yes` first (before
+                                // chaining `answer_iter`).
+                                let answer_iter = if let Some(dst_state_prime) =
+                                    self.dst.get_uninit_edge_dst(dst_state)
+                                {
+                                    Either::Left(iter::once_with(move || {
+                                        let mut c = c.borrow_mut();
+                                        self.answer_memo(&mut *c, src_state_prime, dst_state_prime)
+                                    }))
+                                } else {
+                                    Either::Right(iter::once(Answer::No(
+                                        Reason::DstIsBitIncompatible,
+                                    )))
+                                };
+
+                                // When `answer == Answer::No(...)`, there are
+                                // two cases to consider:
+                                // - If `assume.validity`, then we should
+                                //   succeed because the user is responsible for
+                                //   ensuring that the *specific* byte value
+                                //   appearing at runtime is valid for the
+                                //   destination type. When `assume.validity`,
+                                //   `src_quantifier ==
+                                //   Quantifier::ThereExists`, so adding an
+                                //   `Answer::Yes` has the effect of ensuring
+                                //   that the "there exists" is always
+                                //   satisfied.
+                                // - If `!assume.validity`, then we should fail.
+                                //   In this case, `src_quantifier ==
+                                //   Quantifier::ForAll`, so adding an
+                                //   `Answer::Yes` has no effect.
+                                return Either::Left(iter::once(Answer::Yes).chain(answer_iter));
+                            };
+
+                            #[derive(Copy, Clone, Debug)]
+                            struct Accum {
+                                // The number of matching byte edges that we
+                                // have found in the destination so far.
+                                sum: usize,
+                                found_uninit: bool,
                             }
+
+                            let accum1 = Rc::new(std::cell::Cell::new(Accum {
+                                sum: 0,
+                                found_uninit: false,
+                            }));
+                            let accum2 = Rc::clone(&accum1);
+                            let sv = src_validity.clone();
+                            let update_accum = move |mut accum: Accum, dst_validity: Byte| {
+                                if let Some(dst_validity) = dst_validity.range() {
+                                    // Only add the part of `dst_validity` that
+                                    // overlaps with `src_validity`.
+                                    let start = cmp::max(*sv.start(), *dst_validity.start());
+                                    let end = cmp::min(*sv.end(), *dst_validity.end());
+
+                                    // We add 1 here to account for the fact
+                                    // that `end` is an inclusive bound.
+                                    accum.sum += 1 + usize::from(end.saturating_sub(start));
+                                } else {
+                                    accum.found_uninit = true;
+                                }
+                                accum
+                            };
+
+                            let answers = self
+                                .dst
+                                .states_from(dst_state, src_validity.clone())
+                                .map(move |(dst_validity, dst_state_prime)| {
+                                    let mut c = c.borrow_mut();
+                                    accum1.set(update_accum(accum1.get(), dst_validity));
+                                    let answer =
+                                        self.answer_memo(&mut *c, src_state_prime, dst_state_prime);
+                                    answer
+                                })
+                                .chain(
+                                    iter::once_with(move || {
+                                        let src_validity_len = usize::from(*src_validity.end())
+                                            - usize::from(*src_validity.start())
+                                            + 1;
+                                        let accum = accum2.get();
+
+                                        // If this condition is false, then
+                                        // there are some byte values in the
+                                        // source which have no corresponding
+                                        // transition in the destination DFA. In
+                                        // that case, we add a `No` to our list
+                                        // of answers. When
+                                        // `!self.assume.validity`, this will
+                                        // cause the query to fail.
+                                        if accum.found_uninit || accum.sum == src_validity_len {
+                                            None
+                                        } else {
+                                            Some(Answer::No(Reason::DstIsBitIncompatible))
+                                        }
+                                    })
+                                    .flatten(),
+                                );
+                            Either::Right(answers)
                         },
                     ),
                 );
@@ -235,48 +333,38 @@ where
 
                 let refs_answer = src_quantifier.apply(
                     // for each reference transition out of `src_state`...
-                    self.src.refs_from(src_state).unwrap_or(&Map::default()).into_iter().map(
-                        |(&src_ref, &src_state_prime)| {
-                            // ...there exists a reference transition out of `dst_state`...
-                            Quantifier::ThereExists.apply(
-                                self.dst
-                                    .refs_from(dst_state)
-                                    .unwrap_or(&Map::default())
-                                    .into_iter()
-                                    .map(|(&dst_ref, &dst_state_prime)| {
-                                        if !src_ref.is_mutable() && dst_ref.is_mutable() {
-                                            Answer::No(Reason::DstIsMoreUnique)
-                                        } else if !self.assume.alignment
-                                            && src_ref.min_align() < dst_ref.min_align()
-                                        {
-                                            Answer::No(Reason::DstHasStricterAlignment {
-                                                src_min_align: src_ref.min_align(),
-                                                dst_min_align: dst_ref.min_align(),
-                                            })
-                                        } else if dst_ref.size() > src_ref.size() {
-                                            Answer::No(Reason::DstRefIsTooBig {
-                                                src: src_ref,
-                                                dst: dst_ref,
-                                            })
-                                        } else {
-                                            // ...such that `src` is transmutable into `dst`, if
-                                            // `src_ref` is transmutability into `dst_ref`.
-                                            and(
-                                                Answer::If(Condition::IfTransmutable {
-                                                    src: src_ref,
-                                                    dst: dst_ref,
-                                                }),
-                                                self.answer_memo(
-                                                    cache,
-                                                    src_state_prime,
-                                                    dst_state_prime,
-                                                ),
-                                            )
-                                        }
-                                    }),
-                            )
-                        },
-                    ),
+                    self.src.refs_from(src_state).map(|(src_ref, src_state_prime)| {
+                        // ...there exists a reference transition out of `dst_state`...
+                        Quantifier::ThereExists.apply(self.dst.refs_from(dst_state).map(
+                            |(dst_ref, dst_state_prime)| {
+                                if !src_ref.is_mutable() && dst_ref.is_mutable() {
+                                    Answer::No(Reason::DstIsMoreUnique)
+                                } else if !self.assume.alignment
+                                    && src_ref.min_align() < dst_ref.min_align()
+                                {
+                                    Answer::No(Reason::DstHasStricterAlignment {
+                                        src_min_align: src_ref.min_align(),
+                                        dst_min_align: dst_ref.min_align(),
+                                    })
+                                } else if dst_ref.size() > src_ref.size() {
+                                    Answer::No(Reason::DstRefIsTooBig {
+                                        src: src_ref,
+                                        dst: dst_ref,
+                                    })
+                                } else {
+                                    // ...such that `src` is transmutable into `dst`, if
+                                    // `src_ref` is transmutability into `dst_ref`.
+                                    and(
+                                        Answer::If(Condition::IfTransmutable {
+                                            src: src_ref,
+                                            dst: dst_ref,
+                                        }),
+                                        self.answer_memo(cache, src_state_prime, dst_state_prime),
+                                    )
+                                }
+                            },
+                        ))
+                    }),
                 );
 
                 if self.assume.validity {
diff --git a/compiler/rustc_transmute/src/maybe_transmutable/query_context.rs b/compiler/rustc_transmute/src/maybe_transmutable/query_context.rs
index f8b59bdf326..214da101be3 100644
--- a/compiler/rustc_transmute/src/maybe_transmutable/query_context.rs
+++ b/compiler/rustc_transmute/src/maybe_transmutable/query_context.rs
@@ -8,9 +8,17 @@ pub(crate) trait QueryContext {
 
 #[cfg(test)]
 pub(crate) mod test {
+    use std::marker::PhantomData;
+
     use super::QueryContext;
 
-    pub(crate) struct UltraMinimal;
+    pub(crate) struct UltraMinimal<R = !>(PhantomData<R>);
+
+    impl<R> Default for UltraMinimal<R> {
+        fn default() -> Self {
+            Self(PhantomData)
+        }
+    }
 
     #[derive(Debug, Hash, Eq, PartialEq, Clone, Copy)]
     pub(crate) enum Def {
@@ -24,9 +32,9 @@ pub(crate) mod test {
         }
     }
 
-    impl QueryContext for UltraMinimal {
+    impl<R: crate::layout::Ref> QueryContext for UltraMinimal<R> {
         type Def = Def;
-        type Ref = !;
+        type Ref = R;
     }
 }
 
diff --git a/compiler/rustc_transmute/src/maybe_transmutable/tests.rs b/compiler/rustc_transmute/src/maybe_transmutable/tests.rs
index cc6a4dce17b..24e2a1acadd 100644
--- a/compiler/rustc_transmute/src/maybe_transmutable/tests.rs
+++ b/compiler/rustc_transmute/src/maybe_transmutable/tests.rs
@@ -1,3 +1,5 @@
+extern crate test;
+
 use itertools::Itertools;
 
 use super::query_context::test::{Def, UltraMinimal};
@@ -12,15 +14,25 @@ trait Representation {
 
 impl Representation for Tree {
     fn is_transmutable(src: Self, dst: Self, assume: Assume) -> Answer<!> {
-        crate::maybe_transmutable::MaybeTransmutableQuery::new(src, dst, assume, UltraMinimal)
-            .answer()
+        crate::maybe_transmutable::MaybeTransmutableQuery::new(
+            src,
+            dst,
+            assume,
+            UltraMinimal::default(),
+        )
+        .answer()
     }
 }
 
 impl Representation for Dfa {
     fn is_transmutable(src: Self, dst: Self, assume: Assume) -> Answer<!> {
-        crate::maybe_transmutable::MaybeTransmutableQuery::new(src, dst, assume, UltraMinimal)
-            .answer()
+        crate::maybe_transmutable::MaybeTransmutableQuery::new(
+            src,
+            dst,
+            assume,
+            UltraMinimal::default(),
+        )
+        .answer()
     }
 }
 
@@ -89,6 +101,36 @@ mod safety {
     }
 }
 
+mod size {
+    use super::*;
+
+    #[test]
+    fn size() {
+        let small = Tree::number(1);
+        let large = Tree::number(2);
+
+        for alignment in [false, true] {
+            for lifetimes in [false, true] {
+                for safety in [false, true] {
+                    for validity in [false, true] {
+                        let assume = Assume { alignment, lifetimes, safety, validity };
+                        assert_eq!(
+                            is_transmutable(&small, &large, assume),
+                            Answer::No(Reason::DstIsTooBig),
+                            "assume: {assume:?}"
+                        );
+                        assert_eq!(
+                            is_transmutable(&large, &small, assume),
+                            Answer::Yes,
+                            "assume: {assume:?}"
+                        );
+                    }
+                }
+            }
+        }
+    }
+}
+
 mod bool {
     use super::*;
 
@@ -113,6 +155,27 @@ mod bool {
     }
 
     #[test]
+    fn transmute_u8() {
+        let bool = &Tree::bool();
+        let u8 = &Tree::u8();
+        for (src, dst, assume_validity, answer) in [
+            (bool, u8, false, Answer::Yes),
+            (bool, u8, true, Answer::Yes),
+            (u8, bool, false, Answer::No(Reason::DstIsBitIncompatible)),
+            (u8, bool, true, Answer::Yes),
+        ] {
+            assert_eq!(
+                is_transmutable(
+                    src,
+                    dst,
+                    Assume { validity: assume_validity, ..Assume::default() }
+                ),
+                answer
+            );
+        }
+    }
+
+    #[test]
     fn should_permit_validity_expansion_and_reject_contraction() {
         let b0 = layout::Tree::<Def, !>::from_bits(0);
         let b1 = layout::Tree::<Def, !>::from_bits(1);
@@ -175,6 +238,62 @@ mod bool {
     }
 }
 
+mod uninit {
+    use super::*;
+
+    #[test]
+    fn size() {
+        let mu = Tree::uninit();
+        let u8 = Tree::u8();
+
+        for alignment in [false, true] {
+            for lifetimes in [false, true] {
+                for safety in [false, true] {
+                    for validity in [false, true] {
+                        let assume = Assume { alignment, lifetimes, safety, validity };
+
+                        let want = if validity {
+                            Answer::Yes
+                        } else {
+                            Answer::No(Reason::DstIsBitIncompatible)
+                        };
+
+                        assert_eq!(is_transmutable(&mu, &u8, assume), want, "assume: {assume:?}");
+                        assert_eq!(
+                            is_transmutable(&u8, &mu, assume),
+                            Answer::Yes,
+                            "assume: {assume:?}"
+                        );
+                    }
+                }
+            }
+        }
+    }
+}
+
+mod alt {
+    use super::*;
+    use crate::Answer;
+
+    #[test]
+    fn should_permit_identity_transmutation() {
+        type Tree = layout::Tree<Def, !>;
+
+        let x = Tree::Seq(vec![Tree::from_bits(0), Tree::from_bits(0)]);
+        let y = Tree::Seq(vec![Tree::bool(), Tree::from_bits(1)]);
+        let layout = Tree::Alt(vec![x, y]);
+
+        let answer = crate::maybe_transmutable::MaybeTransmutableQuery::new(
+            layout.clone(),
+            layout.clone(),
+            crate::Assume::default(),
+            UltraMinimal::default(),
+        )
+        .answer();
+        assert_eq!(answer, Answer::Yes, "layout:{:#?}", layout);
+    }
+}
+
 mod union {
     use super::*;
 
@@ -203,3 +322,59 @@ mod union {
         assert_eq!(is_transmutable(&t, &u, Assume::default()), Answer::Yes);
     }
 }
+
+mod r#ref {
+    use super::*;
+
+    #[test]
+    fn should_permit_identity_transmutation() {
+        type Tree = crate::layout::Tree<Def, [(); 1]>;
+
+        let layout = Tree::Seq(vec![Tree::from_bits(0), Tree::Ref([()])]);
+
+        let answer = crate::maybe_transmutable::MaybeTransmutableQuery::new(
+            layout.clone(),
+            layout,
+            Assume::default(),
+            UltraMinimal::default(),
+        )
+        .answer();
+        assert_eq!(answer, Answer::If(crate::Condition::IfTransmutable { src: [()], dst: [()] }));
+    }
+}
+
+mod benches {
+    use std::hint::black_box;
+
+    use test::Bencher;
+
+    use super::*;
+
+    #[bench]
+    fn bench_dfa_from_tree(b: &mut Bencher) {
+        let num = Tree::number(8).prune(&|_| false);
+        let num = black_box(num);
+
+        b.iter(|| {
+            let _ = black_box(Dfa::from_tree(num.clone()));
+        })
+    }
+
+    #[bench]
+    fn bench_transmute(b: &mut Bencher) {
+        let num = Tree::number(8).prune(&|_| false);
+        let dfa = black_box(Dfa::from_tree(num).unwrap());
+
+        b.iter(|| {
+            let answer = crate::maybe_transmutable::MaybeTransmutableQuery::new(
+                dfa.clone(),
+                dfa.clone(),
+                Assume::default(),
+                UltraMinimal::default(),
+            )
+            .answer();
+            let answer = std::hint::black_box(answer);
+            assert_eq!(answer, Answer::Yes);
+        })
+    }
+}