about summary refs log tree commit diff
path: root/compiler/rustc_transmute/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_transmute/src')
-rw-r--r-- compiler/rustc_transmute/src/layout/dfa.rs | 433
-rw-r--r-- compiler/rustc_transmute/src/layout/mod.rs | 72
-rw-r--r-- compiler/rustc_transmute/src/layout/tree.rs | 8
-rw-r--r-- compiler/rustc_transmute/src/lib.rs | 3
-rw-r--r-- compiler/rustc_transmute/src/maybe_transmutable/mod.rs | 218
-rw-r--r-- compiler/rustc_transmute/src/maybe_transmutable/query_context.rs | 14
-rw-r--r-- compiler/rustc_transmute/src/maybe_transmutable/tests.rs | 183
7 files changed, 775 insertions, 156 deletions
diff --git a/compiler/rustc_transmute/src/layout/dfa.rs b/compiler/rustc_transmute/src/layout/dfa.rs
index bb909c54d2b..d1f58157b69 100644
--- a/compiler/rustc_transmute/src/layout/dfa.rs
+++ b/compiler/rustc_transmute/src/layout/dfa.rs
@@ -1,8 +1,9 @@
 use std::fmt;
+use std::ops::RangeInclusive;
 use std::sync::atomic::{AtomicU32, Ordering};
 
 use super::{Byte, Ref, Tree, Uninhabited};
-use crate::Map;
+use crate::{Map, Set};
 
 #[derive(PartialEq)]
 #[cfg_attr(test, derive(Clone))]
@@ -20,7 +21,7 @@ pub(crate) struct Transitions<R>
 where
     R: Ref,
 {
-    byte_transitions: Map<Byte, State>,
+    byte_transitions: EdgeSet<State>,
     ref_transitions: Map<R, State>,
 }
 
@@ -29,7 +30,7 @@ where
     R: Ref,
 {
     fn default() -> Self {
-        Self { byte_transitions: Map::default(), ref_transitions: Map::default() }
+        Self { byte_transitions: EdgeSet::empty(), ref_transitions: Map::default() }
     }
 }
 
@@ -56,15 +57,10 @@ where
 {
     #[cfg(test)]
     pub(crate) fn bool() -> Self {
-        let mut transitions: Map<State, Transitions<R>> = Map::default();
-        let start = State::new();
-        let accept = State::new();
-
-        transitions.entry(start).or_default().byte_transitions.insert(Byte::Init(0x00), accept);
-
-        transitions.entry(start).or_default().byte_transitions.insert(Byte::Init(0x01), accept);
-
-        Self { transitions, start, accept }
+        Self::from_transitions(|accept| Transitions {
+            byte_transitions: EdgeSet::new(Byte::new(0x00..=0x01), accept),
+            ref_transitions: Map::default(),
+        })
     }
 
     pub(crate) fn unit() -> Self {
@@ -76,23 +72,24 @@ where
     }
 
     pub(crate) fn from_byte(byte: Byte) -> Self {
-        let mut transitions: Map<State, Transitions<R>> = Map::default();
-        let start = State::new();
-        let accept = State::new();
-
-        transitions.entry(start).or_default().byte_transitions.insert(byte, accept);
-
-        Self { transitions, start, accept }
+        Self::from_transitions(|accept| Transitions {
+            byte_transitions: EdgeSet::new(byte, accept),
+            ref_transitions: Map::default(),
+        })
     }
 
     pub(crate) fn from_ref(r: R) -> Self {
-        let mut transitions: Map<State, Transitions<R>> = Map::default();
+        Self::from_transitions(|accept| Transitions {
+            byte_transitions: EdgeSet::empty(),
+            ref_transitions: [(r, accept)].into_iter().collect(),
+        })
+    }
+
+    fn from_transitions(f: impl FnOnce(State) -> Transitions<R>) -> Self {
         let start = State::new();
         let accept = State::new();
 
-        transitions.entry(start).or_default().ref_transitions.insert(r, accept);
-
-        Self { transitions, start, accept }
+        Self { transitions: [(start, f(accept))].into_iter().collect(), start, accept }
     }
 
     pub(crate) fn from_tree(tree: Tree<!, R>) -> Result<Self, Uninhabited> {
@@ -132,13 +129,16 @@ where
 
         for (source, transition) in other.transitions {
             let fix_state = |state| if state == other.start { self.accept } else { state };
-            let entry = transitions.entry(fix_state(source)).or_default();
-            for (edge, destination) in transition.byte_transitions {
-                entry.byte_transitions.insert(edge, fix_state(destination));
-            }
-            for (edge, destination) in transition.ref_transitions {
-                entry.ref_transitions.insert(edge, fix_state(destination));
-            }
+            let byte_transitions = transition.byte_transitions.map_states(&fix_state);
+            let ref_transitions = transition
+                .ref_transitions
+                .into_iter()
+                .map(|(r, state)| (r, fix_state(state)))
+                .collect();
+
+            let old = transitions
+                .insert(fix_state(source), Transitions { byte_transitions, ref_transitions });
+            assert!(old.is_none());
         }
 
         Self { transitions, start, accept }
@@ -170,67 +170,111 @@ where
 
         let start = mapped((Some(a.start), Some(b.start)));
         let mut transitions: Map<State, Transitions<R>> = Map::default();
-        let mut queue = vec![(Some(a.start), Some(b.start))];
         let empty_transitions = Transitions::default();
 
-        while let Some((a_src, b_src)) = queue.pop() {
+        struct WorkQueue {
+            queue: Vec<(Option<State>, Option<State>)>,
+            // Track all entries ever enqueued to avoid duplicating work. This
+            // gives us a guarantee that a given (a_state, b_state) pair will
+            // only ever be visited once.
+            enqueued: Set<(Option<State>, Option<State>)>,
+        }
+        impl WorkQueue {
+            fn enqueue(&mut self, a_state: Option<State>, b_state: Option<State>) {
+                if self.enqueued.insert((a_state, b_state)) {
+                    self.queue.push((a_state, b_state));
+                }
+            }
+        }
+        let mut queue = WorkQueue { queue: Vec::new(), enqueued: Set::default() };
+        queue.enqueue(Some(a.start), Some(b.start));
+
+        while let Some((a_src, b_src)) = queue.queue.pop() {
+            let src = mapped((a_src, b_src));
+            if src == accept {
+                // While it's possible to have a DFA whose accept state has
+                // out-edges, these do not affect the semantics of the DFA, and
+                // so there's no point in processing them. Continuing here also
+                // has the advantage of guaranteeing that we only ever process a
+                // given node in the output DFA once. In particular, with the
+                // exception of the accept state, we ensure that we only push a
+                // given node to the `queue` once. This allows the following
+                // code to assume that we're processing a node we've never
+                // processed before, which means we never need to merge two edge
+                // sets - we only ever need to construct a new edge set from
+                // whole cloth.
+                continue;
+            }
+
             let a_transitions =
                 a_src.and_then(|a_src| a.transitions.get(&a_src)).unwrap_or(&empty_transitions);
             let b_transitions =
                 b_src.and_then(|b_src| b.transitions.get(&b_src)).unwrap_or(&empty_transitions);
 
             let byte_transitions =
-                a_transitions.byte_transitions.keys().chain(b_transitions.byte_transitions.keys());
-
-            for byte_transition in byte_transitions {
-                let a_dst = a_transitions.byte_transitions.get(byte_transition).copied();
-                let b_dst = b_transitions.byte_transitions.get(byte_transition).copied();
+                a_transitions.byte_transitions.union(&b_transitions.byte_transitions);
 
+            let byte_transitions = byte_transitions.map_states(|(a_dst, b_dst)| {
                 assert!(a_dst.is_some() || b_dst.is_some());
 
-                let src = mapped((a_src, b_src));
-                let dst = mapped((a_dst, b_dst));
-
-                transitions.entry(src).or_default().byte_transitions.insert(*byte_transition, dst);
-
-                if !transitions.contains_key(&dst) {
-                    queue.push((a_dst, b_dst))
-                }
-            }
+                queue.enqueue(a_dst, b_dst);
+                mapped((a_dst, b_dst))
+            });
 
             let ref_transitions =
                 a_transitions.ref_transitions.keys().chain(b_transitions.ref_transitions.keys());
 
-            for ref_transition in ref_transitions {
-                let a_dst = a_transitions.ref_transitions.get(ref_transition).copied();
-                let b_dst = b_transitions.ref_transitions.get(ref_transition).copied();
+            let ref_transitions = ref_transitions
+                .map(|ref_transition| {
+                    let a_dst = a_transitions.ref_transitions.get(ref_transition).copied();
+                    let b_dst = b_transitions.ref_transitions.get(ref_transition).copied();
 
-                assert!(a_dst.is_some() || b_dst.is_some());
-
-                let src = mapped((a_src, b_src));
-                let dst = mapped((a_dst, b_dst));
+                    assert!(a_dst.is_some() || b_dst.is_some());
 
-                transitions.entry(src).or_default().ref_transitions.insert(*ref_transition, dst);
+                    queue.enqueue(a_dst, b_dst);
+                    (*ref_transition, mapped((a_dst, b_dst)))
+                })
+                .collect();
 
-                if !transitions.contains_key(&dst) {
-                    queue.push((a_dst, b_dst))
-                }
-            }
+            let old = transitions.insert(src, Transitions { byte_transitions, ref_transitions });
+            // See `if src == accept { ... }` above. The comment there explains
+            // why this assert is valid.
+            assert_eq!(old, None);
         }
 
         Self { transitions, start, accept }
     }
 
-    pub(crate) fn bytes_from(&self, start: State) -> Option<&Map<Byte, State>> {
-        Some(&self.transitions.get(&start)?.byte_transitions)
+    pub(crate) fn states_from(
+        &self,
+        state: State,
+        src_validity: RangeInclusive<u8>,
+    ) -> impl Iterator<Item = (Byte, State)> {
+        self.transitions
+            .get(&state)
+            .map(move |t| t.byte_transitions.states_from(src_validity))
+            .into_iter()
+            .flatten()
+    }
+
+    pub(crate) fn get_uninit_edge_dst(&self, state: State) -> Option<State> {
+        let transitions = self.transitions.get(&state)?;
+        transitions.byte_transitions.get_uninit_edge_dst()
     }
 
-    pub(crate) fn byte_from(&self, start: State, byte: Byte) -> Option<State> {
-        self.transitions.get(&start)?.byte_transitions.get(&byte).copied()
+    pub(crate) fn bytes_from(&self, start: State) -> impl Iterator<Item = (Byte, State)> {
+        self.transitions
+            .get(&start)
+            .into_iter()
+            .flat_map(|transitions| transitions.byte_transitions.iter())
     }
 
-    pub(crate) fn refs_from(&self, start: State) -> Option<&Map<R, State>> {
-        Some(&self.transitions.get(&start)?.ref_transitions)
+    pub(crate) fn refs_from(&self, start: State) -> impl Iterator<Item = (R, State)> {
+        self.transitions
+            .get(&start)
+            .into_iter()
+            .flat_map(|transitions| transitions.ref_transitions.iter())
+            .map(|(r, s)| (*r, *s))
     }
 
     #[cfg(test)]
@@ -241,15 +285,25 @@ where
     ) -> Self {
         let start = State(start);
         let accept = State(accept);
-        let mut transitions: Map<State, Transitions<R>> = Map::default();
+        let mut transitions: Map<State, Vec<(Byte, State)>> = Map::default();
 
-        for &(src, edge, dst) in edges {
-            let src = State(src);
-            let dst = State(dst);
-            let old = transitions.entry(src).or_default().byte_transitions.insert(edge.into(), dst);
-            assert!(old.is_none());
+        for (src, edge, dst) in edges.iter().copied() {
+            transitions.entry(State(src)).or_default().push((edge.into(), State(dst)));
         }
 
+        let transitions = transitions
+            .into_iter()
+            .map(|(src, edges)| {
+                (
+                    src,
+                    Transitions {
+                        byte_transitions: EdgeSet::from_edges(edges),
+                        ref_transitions: Map::default(),
+                    },
+                )
+            })
+            .collect();
+
         Self { start, accept, transitions }
     }
 }
@@ -277,3 +331,242 @@ where
         writeln!(f, "}}")
     }
 }
+
+use edge_set::EdgeSet;
+mod edge_set {
+    use std::cmp;
+
+    use run::*;
+    use smallvec::{SmallVec, smallvec};
+
+    use super::*;
+    mod run {
+        use std::ops::{Range, RangeInclusive};
+
+        use super::*;
+        use crate::layout::Byte;
+
+        /// A logical set of edges.
+        ///
+        /// A `Run` encodes one edge for every byte value in `start..=end`
+        /// pointing to `dst`.
+        #[derive(Eq, PartialEq, Copy, Clone, Debug)]
+        pub(super) struct Run<S> {
+            // `start` and `end` are both inclusive (ie, closed) bounds, as this
+            // is required in order to be able to store 0..=255. We provide
+            // setters and getters which operate on closed/open ranges, which
+            // are more intuitive and easier for performing offset math.
+            start: u8,
+            end: u8,
+            pub(super) dst: S,
+        }
+
+        impl<S> Run<S> {
+            pub(super) fn new(range: RangeInclusive<u8>, dst: S) -> Self {
+                Self { start: *range.start(), end: *range.end(), dst }
+            }
+
+            pub(super) fn from_inclusive_exclusive(range: Range<u16>, dst: S) -> Self {
+                Self {
+                    start: range.start.try_into().unwrap(),
+                    end: (range.end - 1).try_into().unwrap(),
+                    dst,
+                }
+            }
+
+            pub(super) fn contains(&self, idx: u16) -> bool {
+                idx >= u16::from(self.start) && idx <= u16::from(self.end)
+            }
+
+            pub(super) fn as_inclusive_exclusive(&self) -> (u16, u16) {
+                (u16::from(self.start), u16::from(self.end) + 1)
+            }
+
+            pub(super) fn as_byte(&self) -> Byte {
+                Byte::new(self.start..=self.end)
+            }
+
+            pub(super) fn map_state<SS>(self, f: impl FnOnce(S) -> SS) -> Run<SS> {
+                let Run { start, end, dst } = self;
+                Run { start, end, dst: f(dst) }
+            }
+
+            /// Produces a new `Run` whose lower bound is the greater of
+            /// `self`'s existing lower bound and `lower_bound`.
+            pub(super) fn clamp_lower(self, lower_bound: u8) -> Self {
+                let Run { start, end, dst } = self;
+                Run { start: cmp::max(start, lower_bound), end, dst }
+            }
+        }
+    }
+
+    /// The set of outbound byte edges associated with a DFA node (not including
+    /// reference edges).
+    #[derive(Eq, PartialEq, Clone, Debug)]
+    pub(super) struct EdgeSet<S = State> {
+        // A sequence of runs stored in ascending order. Since the graph is a
+        // DFA, these must be non-overlapping with one another.
+        runs: SmallVec<[Run<S>; 1]>,
+        // The edge labeled with the uninit byte, if any.
+        //
+        // FIXME(@joshlf): Make `State` a `NonZero` so that this is NPO'd.
+        uninit: Option<S>,
+    }
+
+    impl<S> EdgeSet<S> {
+        pub(crate) fn new(byte: Byte, dst: S) -> Self {
+            match byte.range() {
+                Some(range) => Self { runs: smallvec![Run::new(range, dst)], uninit: None },
+                None => Self { runs: SmallVec::new(), uninit: Some(dst) },
+            }
+        }
+
+        pub(crate) fn empty() -> Self {
+            Self { runs: SmallVec::new(), uninit: None }
+        }
+
+        #[cfg(test)]
+        pub(crate) fn from_edges(mut edges: Vec<(Byte, S)>) -> Self
+        where
+            S: Ord,
+        {
+            edges.sort();
+            Self {
+                runs: edges
+                    .into_iter()
+                    .map(|(byte, state)| Run::new(byte.range().unwrap(), state))
+                    .collect(),
+                uninit: None,
+            }
+        }
+
+        pub(crate) fn iter(&self) -> impl Iterator<Item = (Byte, S)>
+        where
+            S: Copy,
+        {
+            self.uninit
+                .map(|dst| (Byte::uninit(), dst))
+                .into_iter()
+                .chain(self.runs.iter().map(|run| (run.as_byte(), run.dst)))
+        }
+
+        pub(crate) fn states_from(
+            &self,
+            byte: RangeInclusive<u8>,
+        ) -> impl Iterator<Item = (Byte, S)>
+        where
+            S: Copy,
+        {
+            // FIXME(@joshlf): Optimize this. A manual scan over `self.runs` may
+            // permit us to more efficiently discard runs which will not be
+            // produced by this iterator.
+            self.iter().filter(move |(o, _)| Byte::new(byte.clone()).transmutable_into(&o))
+        }
+
+        pub(crate) fn get_uninit_edge_dst(&self) -> Option<S>
+        where
+            S: Copy,
+        {
+            self.uninit
+        }
+
+        pub(crate) fn map_states<SS>(self, mut f: impl FnMut(S) -> SS) -> EdgeSet<SS> {
+            EdgeSet {
+                // NOTE: It appears as though `<Vec<_> as
+                // IntoIterator>::IntoIter` and `std::iter::Map` both implement
+                // `TrustedLen`, which in turn means that this `.collect()`
+                // allocates the correct number of elements once up-front [1].
+                //
+                // [1] https://doc.rust-lang.org/1.85.0/src/alloc/vec/spec_from_iter_nested.rs.html#47
+                runs: self.runs.into_iter().map(|run| run.map_state(&mut f)).collect(),
+                uninit: self.uninit.map(f),
+            }
+        }
+
+        /// Unions two edge sets together.
+        ///
+        /// If `u = a.union(b)`, then for each byte value, `u` will have an edge
+        /// with that byte value and with the destination `(Some(_), None)`,
+        /// `(None, Some(_))`, or `(Some(_), Some(_))` depending on whether `a`,
+        /// `b`, or both have an edge with that byte value.
+        ///
+        /// If neither `a` nor `b` have an edge with a particular byte value,
+        /// then no edge with that value will be present in `u`.
+        pub(crate) fn union(&self, other: &Self) -> EdgeSet<(Option<S>, Option<S>)>
+        where
+            S: Copy,
+        {
+            let uninit = match (self.uninit, other.uninit) {
+                (None, None) => None,
+                (s, o) => Some((s, o)),
+            };
+
+            let mut runs = SmallVec::new();
+
+            // Iterate over `self.runs` and `other.runs` simultaneously,
+            // advancing `idx` as we go. At each step, we advance `idx` as far
+            // as we can without crossing a run boundary in either `self.runs`
+            // or `other.runs`.
+
+            // INVARIANT: `idx < s[0].end && idx < o[0].end`.
+            let (mut s, mut o) = (self.runs.as_slice(), other.runs.as_slice());
+            let mut idx = 0u16;
+            while let (Some((s_run, s_rest)), Some((o_run, o_rest))) =
+                (s.split_first(), o.split_first())
+            {
+                let (s_start, s_end) = s_run.as_inclusive_exclusive();
+                let (o_start, o_end) = o_run.as_inclusive_exclusive();
+
+                // Compute `end` as the end of the current run (which starts
+                // with `idx`).
+                let (end, dst) = match (s_run.contains(idx), o_run.contains(idx)) {
+                    // `idx` is in an existing run in both `s` and `o`, so `end`
+                    // is equal to the smallest of the two ends of those runs.
+                    (true, true) => (cmp::min(s_end, o_end), (Some(s_run.dst), Some(o_run.dst))),
+                    // `idx` is in an existing run in `s`, but not in any run in
+                    // `o`. `end` is either the end of the `s` run or the
+                    // beginning of the next `o` run, whichever comes first.
+                    (true, false) => (cmp::min(s_end, o_start), (Some(s_run.dst), None)),
+                    // The inverse of the previous case.
+                    (false, true) => (cmp::min(s_start, o_end), (None, Some(o_run.dst))),
+                    // `idx` is not in a run in either `s` or `o`, so advance it
+                    // to the beginning of the next run.
+                    (false, false) => {
+                        idx = cmp::min(s_start, o_start);
+                        continue;
+                    }
+                };
+
+                // FIXME(@joshlf): If this is contiguous with the previous run
+                // and has the same `dst`, just merge it into that run rather
+                // than adding a new one.
+                runs.push(Run::from_inclusive_exclusive(idx..end, dst));
+                idx = end;
+
+                if idx >= s_end {
+                    s = s_rest;
+                }
+                if idx >= o_end {
+                    o = o_rest;
+                }
+            }
+
+            // At this point, either `s` or `o` has been exhausted, so the
+            // remaining elements in the other slice are guaranteed to be
+            // non-overlapping. We can add all remaining runs to `runs` with no
+            // further processing.
+            if let Ok(idx) = u8::try_from(idx) {
+                let (slc, map) = if !s.is_empty() {
+                    let map: fn(_) -> _ = |st| (Some(st), None);
+                    (s, map)
+                } else {
+                    let map: fn(_) -> _ = |st| (None, Some(st));
+                    (o, map)
+                };
+                runs.extend(slc.iter().map(|run| run.clamp_lower(idx).map_state(map)));
+            }
+
+            EdgeSet { runs, uninit }
+        }
+    }
+}
diff --git a/compiler/rustc_transmute/src/layout/mod.rs b/compiler/rustc_transmute/src/layout/mod.rs
index c940f7c42a8..d555ea702a9 100644
--- a/compiler/rustc_transmute/src/layout/mod.rs
+++ b/compiler/rustc_transmute/src/layout/mod.rs
@@ -1,5 +1,6 @@
 use std::fmt::{self, Debug};
 use std::hash::Hash;
+use std::ops::RangeInclusive;
 
 pub(crate) mod tree;
 pub(crate) use tree::Tree;
@@ -10,18 +11,56 @@ pub(crate) use dfa::Dfa;
 #[derive(Debug)]
 pub(crate) struct Uninhabited;
 
-/// An instance of a byte is either initialized to a particular value, or uninitialized.
-#[derive(Hash, Eq, PartialEq, Clone, Copy)]
-pub(crate) enum Byte {
-    Uninit,
-    Init(u8),
+/// A range of byte values, or the uninit byte.
+#[derive(Hash, Eq, PartialEq, Ord, PartialOrd, Clone, Copy)]
+pub(crate) struct Byte {
+    // An inclusive-inclusive range. We use this instead of `RangeInclusive`
+    // because `RangeInclusive: !Copy`.
+    //
+    // `None` means uninit.
+    //
+    // FIXME(@joshlf): Optimize this representation. Some pairs of values (where
+    // `lo > hi`) are illegal, and we could use these to represent `None`.
+    range: Option<(u8, u8)>,
+}
+
+impl Byte {
+    fn new(range: RangeInclusive<u8>) -> Self {
+        Self { range: Some((*range.start(), *range.end())) }
+    }
+
+    fn from_val(val: u8) -> Self {
+        Self { range: Some((val, val)) }
+    }
+
+    pub(crate) fn uninit() -> Byte {
+        Byte { range: None }
+    }
+
+    /// Returns `None` if `self` is the uninit byte.
+    pub(crate) fn range(&self) -> Option<RangeInclusive<u8>> {
+        self.range.map(|(lo, hi)| lo..=hi)
+    }
+
+    /// Are any of the values in `self` transmutable into `other`?
+    ///
+    /// Note two special cases: An uninit byte is only transmutable into another
+    /// uninit byte. Any byte is transmutable into an uninit byte.
+    pub(crate) fn transmutable_into(&self, other: &Byte) -> bool {
+        match (self.range, other.range) {
+            (None, None) => true,
+            (None, Some(_)) => false,
+            (Some(_), None) => true,
+            (Some((slo, shi)), Some((olo, ohi))) => slo <= ohi && olo <= shi,
+        }
+    }
 }
 
 impl fmt::Debug for Byte {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        match &self {
-            Self::Uninit => f.write_str("??u8"),
-            Self::Init(b) => write!(f, "{b:#04x}u8"),
+        match self.range {
+            None => write!(f, "uninit"),
+            Some((lo, hi)) => write!(f, "{lo}..={hi}"),
         }
     }
 }
@@ -29,7 +68,7 @@ impl fmt::Debug for Byte {
 #[cfg(test)]
 impl From<u8> for Byte {
     fn from(src: u8) -> Self {
-        Self::Init(src)
+        Self::from_val(src)
     }
 }
 
@@ -62,6 +101,21 @@ impl Ref for ! {
     }
 }
 
+#[cfg(test)]
+impl<const N: usize> Ref for [(); N] {
+    fn min_align(&self) -> usize {
+        N
+    }
+
+    fn size(&self) -> usize {
+        N
+    }
+
+    fn is_mutable(&self) -> bool {
+        false
+    }
+}
+
 #[cfg(feature = "rustc")]
 pub mod rustc {
     use std::fmt::{self, Write};
diff --git a/compiler/rustc_transmute/src/layout/tree.rs b/compiler/rustc_transmute/src/layout/tree.rs
index 70ecc75403f..6a09be18ef9 100644
--- a/compiler/rustc_transmute/src/layout/tree.rs
+++ b/compiler/rustc_transmute/src/layout/tree.rs
@@ -54,22 +54,22 @@ where
 
     /// A `Tree` containing a single, uninitialized byte.
     pub(crate) fn uninit() -> Self {
-        Self::Byte(Byte::Uninit)
+        Self::Byte(Byte::uninit())
     }
 
     /// A `Tree` representing the layout of `bool`.
     pub(crate) fn bool() -> Self {
-        Self::from_bits(0x00).or(Self::from_bits(0x01))
+        Self::Byte(Byte::new(0x00..=0x01))
     }
 
     /// A `Tree` whose layout matches that of a `u8`.
     pub(crate) fn u8() -> Self {
-        Self::Alt((0u8..=255).map(Self::from_bits).collect())
+        Self::Byte(Byte::new(0x00..=0xFF))
     }
 
     /// A `Tree` whose layout accepts exactly the given bit pattern.
     pub(crate) fn from_bits(bits: u8) -> Self {
-        Self::Byte(Byte::Init(bits))
+        Self::Byte(Byte::from_val(bits))
     }
 
     /// A `Tree` whose layout is a number of the given width.
diff --git a/compiler/rustc_transmute/src/lib.rs b/compiler/rustc_transmute/src/lib.rs
index 76fa6ceabe7..ce18dad5517 100644
--- a/compiler/rustc_transmute/src/lib.rs
+++ b/compiler/rustc_transmute/src/lib.rs
@@ -1,8 +1,9 @@
 // tidy-alphabetical-start
+#![cfg_attr(test, feature(test))]
 #![feature(never_type)]
 // tidy-alphabetical-end
 
-pub(crate) use rustc_data_structures::fx::FxIndexMap as Map;
+pub(crate) use rustc_data_structures::fx::{FxIndexMap as Map, FxIndexSet as Set};
 
 pub mod layout;
 mod maybe_transmutable;
diff --git a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
index db0e1ab8e98..0a19cccc2ed 100644
--- a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
+++ b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
@@ -1,10 +1,14 @@
+use std::rc::Rc;
+use std::{cmp, iter};
+
+use itertools::Either;
 use tracing::{debug, instrument, trace};
 
 pub(crate) mod query_context;
 #[cfg(test)]
 mod tests;
 
-use crate::layout::{self, Byte, Def, Dfa, Ref, Tree, Uninhabited, dfa};
+use crate::layout::{self, Byte, Def, Dfa, Ref, Tree, dfa};
 use crate::maybe_transmutable::query_context::QueryContext;
 use crate::{Answer, Condition, Map, Reason};
 
@@ -111,7 +115,7 @@ where
         // the `src` type do not exist.
         let src = match Dfa::from_tree(src) {
             Ok(src) => src,
-            Err(Uninhabited) => return Answer::Yes,
+            Err(layout::Uninhabited) => return Answer::Yes,
         };
 
         // Convert `dst` from a tree-based representation to an DFA-based
@@ -122,7 +126,7 @@ where
         // free of safety invariants.
         let dst = match Dfa::from_tree(dst) {
             Ok(dst) => dst,
-            Err(Uninhabited) => return Answer::No(Reason::DstMayHaveSafetyInvariants),
+            Err(layout::Uninhabited) => return Answer::No(Reason::DstMayHaveSafetyInvariants),
         };
 
         MaybeTransmutableQuery { src, dst, assume, context }.answer()
@@ -174,8 +178,8 @@ where
                 // are able to safely transmute, even with truncation.
                 Answer::Yes
             } else if src_state == self.src.accept {
-                // extension: `size_of(Src) >= size_of(Dst)`
-                if let Some(dst_state_prime) = self.dst.byte_from(dst_state, Byte::Uninit) {
+                // extension: `size_of(Src) <= size_of(Dst)`
+                if let Some(dst_state_prime) = self.dst.get_uninit_edge_dst(dst_state) {
                     self.answer_memo(cache, src_state, dst_state_prime)
                 } else {
                     Answer::No(Reason::DstIsTooBig)
@@ -193,26 +197,120 @@ where
                     Quantifier::ForAll
                 };
 
+                let c = &core::cell::RefCell::new(&mut *cache);
                 let bytes_answer = src_quantifier.apply(
-                    // for each of the byte transitions out of the `src_state`...
-                    self.src.bytes_from(src_state).unwrap_or(&Map::default()).into_iter().map(
-                        |(&src_validity, &src_state_prime)| {
-                            // ...try to find a matching transition out of `dst_state`.
-                            if let Some(dst_state_prime) =
-                                self.dst.byte_from(dst_state, src_validity)
-                            {
-                                self.answer_memo(cache, src_state_prime, dst_state_prime)
-                            } else if let Some(dst_state_prime) =
-                                // otherwise, see if `dst_state` has any outgoing `Uninit` transitions
-                                // (any init byte is a valid uninit byte)
-                                self.dst.byte_from(dst_state, Byte::Uninit)
-                            {
-                                self.answer_memo(cache, src_state_prime, dst_state_prime)
-                            } else {
-                                // otherwise, we've exhausted our options.
-                                // the DFAs, from this point onwards, are bit-incompatible.
-                                Answer::No(Reason::DstIsBitIncompatible)
+                    // for each of the byte set transitions out of the `src_state`...
+                    self.src.bytes_from(src_state).flat_map(
+                        move |(src_validity, src_state_prime)| {
+                            // ...find all matching transitions out of `dst_state`.
+
+                            let Some(src_validity) = src_validity.range() else {
+                                // NOTE: We construct an iterator here rather
+                                // than just computing the value directly (via
+                                // `self.answer_memo`) so that, if the iterator
+                                // we produce from this branch is
+                                // short-circuited, we don't waste time
+                                // computing `self.answer_memo` unnecessarily.
+                                // That will specifically happen if
+                                // `src_quantifier == Quantifier::ThereExists`,
+                                // since we emit `Answer::Yes` first (before
+                                // chaining `answer_iter`).
+                                let answer_iter = if let Some(dst_state_prime) =
+                                    self.dst.get_uninit_edge_dst(dst_state)
+                                {
+                                    Either::Left(iter::once_with(move || {
+                                        let mut c = c.borrow_mut();
+                                        self.answer_memo(&mut *c, src_state_prime, dst_state_prime)
+                                    }))
+                                } else {
+                                    Either::Right(iter::once(Answer::No(
+                                        Reason::DstIsBitIncompatible,
+                                    )))
+                                };
+
+                                // When `answer == Answer::No(...)`, there are
+                                // two cases to consider:
+                                // - If `assume.validity`, then we should
+                                //   succeed because the user is responsible for
+                                //   ensuring that the *specific* byte value
+                                //   appearing at runtime is valid for the
+                                //   destination type. When `assume.validity`,
+                                //   `src_quantifier ==
+                                //   Quantifier::ThereExists`, so adding an
+                                //   `Answer::Yes` has the effect of ensuring
+                                //   that the "there exists" is always
+                                //   satisfied.
+                                // - If `!assume.validity`, then we should fail.
+                                //   In this case, `src_quantifier ==
+                                //   Quantifier::ForAll`, so adding an
+                                //   `Answer::Yes` has no effect.
+                                return Either::Left(iter::once(Answer::Yes).chain(answer_iter));
+                            };
+
+                            #[derive(Copy, Clone, Debug)]
+                            struct Accum {
+                                // The number of matching byte edges that we
+                                // have found in the destination so far.
+                                sum: usize,
+                                found_uninit: bool,
                             }
+
+                            let accum1 = Rc::new(std::cell::Cell::new(Accum {
+                                sum: 0,
+                                found_uninit: false,
+                            }));
+                            let accum2 = Rc::clone(&accum1);
+                            let sv = src_validity.clone();
+                            let update_accum = move |mut accum: Accum, dst_validity: Byte| {
+                                if let Some(dst_validity) = dst_validity.range() {
+                                    // Only add the part of `dst_validity` that
+                                    // overlaps with `src_validity`.
+                                    let start = cmp::max(*sv.start(), *dst_validity.start());
+                                    let end = cmp::min(*sv.end(), *dst_validity.end());
+
+                                    // We add 1 here to account for the fact
+                                    // that `end` is an inclusive bound.
+                                    accum.sum += 1 + usize::from(end.saturating_sub(start));
+                                } else {
+                                    accum.found_uninit = true;
+                                }
+                                accum
+                            };
+
+                            let answers = self
+                                .dst
+                                .states_from(dst_state, src_validity.clone())
+                                .map(move |(dst_validity, dst_state_prime)| {
+                                    let mut c = c.borrow_mut();
+                                    accum1.set(update_accum(accum1.get(), dst_validity));
+                                    let answer =
+                                        self.answer_memo(&mut *c, src_state_prime, dst_state_prime);
+                                    answer
+                                })
+                                .chain(
+                                    iter::once_with(move || {
+                                        let src_validity_len = usize::from(*src_validity.end())
+                                            - usize::from(*src_validity.start())
+                                            + 1;
+                                        let accum = accum2.get();
+
+                                        // If this condition is false, then
+                                        // there are some byte values in the
+                                        // source which have no corresponding
+                                        // transition in the destination DFA. In
+                                        // that case, we add a `No` to our list
+                                        // of answers. When
+                                        // `!self.assume.validity`, this will
+                                        // cause the query to fail.
+                                        if accum.found_uninit || accum.sum == src_validity_len {
+                                            None
+                                        } else {
+                                            Some(Answer::No(Reason::DstIsBitIncompatible))
+                                        }
+                                    })
+                                    .flatten(),
+                                );
+                            Either::Right(answers)
                         },
                     ),
                 );
@@ -235,48 +333,38 @@ where
 
                 let refs_answer = src_quantifier.apply(
                     // for each reference transition out of `src_state`...
-                    self.src.refs_from(src_state).unwrap_or(&Map::default()).into_iter().map(
-                        |(&src_ref, &src_state_prime)| {
-                            // ...there exists a reference transition out of `dst_state`...
-                            Quantifier::ThereExists.apply(
-                                self.dst
-                                    .refs_from(dst_state)
-                                    .unwrap_or(&Map::default())
-                                    .into_iter()
-                                    .map(|(&dst_ref, &dst_state_prime)| {
-                                        if !src_ref.is_mutable() && dst_ref.is_mutable() {
-                                            Answer::No(Reason::DstIsMoreUnique)
-                                        } else if !self.assume.alignment
-                                            && src_ref.min_align() < dst_ref.min_align()
-                                        {
-                                            Answer::No(Reason::DstHasStricterAlignment {
-                                                src_min_align: src_ref.min_align(),
-                                                dst_min_align: dst_ref.min_align(),
-                                            })
-                                        } else if dst_ref.size() > src_ref.size() {
-                                            Answer::No(Reason::DstRefIsTooBig {
-                                                src: src_ref,
-                                                dst: dst_ref,
-                                            })
-                                        } else {
-                                            // ...such that `src` is transmutable into `dst`, if
-                                            // `src_ref` is transmutability into `dst_ref`.
-                                            and(
-                                                Answer::If(Condition::IfTransmutable {
-                                                    src: src_ref,
-                                                    dst: dst_ref,
-                                                }),
-                                                self.answer_memo(
-                                                    cache,
-                                                    src_state_prime,
-                                                    dst_state_prime,
-                                                ),
-                                            )
-                                        }
-                                    }),
-                            )
-                        },
-                    ),
+                    self.src.refs_from(src_state).map(|(src_ref, src_state_prime)| {
+                        // ...there exists a reference transition out of `dst_state`...
+                        Quantifier::ThereExists.apply(self.dst.refs_from(dst_state).map(
+                            |(dst_ref, dst_state_prime)| {
+                                if !src_ref.is_mutable() && dst_ref.is_mutable() {
+                                    Answer::No(Reason::DstIsMoreUnique)
+                                } else if !self.assume.alignment
+                                    && src_ref.min_align() < dst_ref.min_align()
+                                {
+                                    Answer::No(Reason::DstHasStricterAlignment {
+                                        src_min_align: src_ref.min_align(),
+                                        dst_min_align: dst_ref.min_align(),
+                                    })
+                                } else if dst_ref.size() > src_ref.size() {
+                                    Answer::No(Reason::DstRefIsTooBig {
+                                        src: src_ref,
+                                        dst: dst_ref,
+                                    })
+                                } else {
+                                    // ...such that `src` is transmutable into `dst`, if
+                                    // `src_ref` is transmutable into `dst_ref`.
+                                    and(
+                                        Answer::If(Condition::IfTransmutable {
+                                            src: src_ref,
+                                            dst: dst_ref,
+                                        }),
+                                        self.answer_memo(cache, src_state_prime, dst_state_prime),
+                                    )
+                                }
+                            },
+                        ))
+                    }),
                 );
 
                 if self.assume.validity {
diff --git a/compiler/rustc_transmute/src/maybe_transmutable/query_context.rs b/compiler/rustc_transmute/src/maybe_transmutable/query_context.rs
index f8b59bdf326..214da101be3 100644
--- a/compiler/rustc_transmute/src/maybe_transmutable/query_context.rs
+++ b/compiler/rustc_transmute/src/maybe_transmutable/query_context.rs
@@ -8,9 +8,17 @@ pub(crate) trait QueryContext {
 
 #[cfg(test)]
 pub(crate) mod test {
+    use std::marker::PhantomData;
+
     use super::QueryContext;
 
-    pub(crate) struct UltraMinimal;
+    pub(crate) struct UltraMinimal<R = !>(PhantomData<R>);
+
+    impl<R> Default for UltraMinimal<R> {
+        fn default() -> Self {
+            Self(PhantomData)
+        }
+    }
 
     #[derive(Debug, Hash, Eq, PartialEq, Clone, Copy)]
     pub(crate) enum Def {
@@ -24,9 +32,9 @@ pub(crate) mod test {
         }
     }
 
-    impl QueryContext for UltraMinimal {
+    impl<R: crate::layout::Ref> QueryContext for UltraMinimal<R> {
         type Def = Def;
-        type Ref = !;
+        type Ref = R;
     }
 }
 
diff --git a/compiler/rustc_transmute/src/maybe_transmutable/tests.rs b/compiler/rustc_transmute/src/maybe_transmutable/tests.rs
index cc6a4dce17b..24e2a1acadd 100644
--- a/compiler/rustc_transmute/src/maybe_transmutable/tests.rs
+++ b/compiler/rustc_transmute/src/maybe_transmutable/tests.rs
@@ -1,3 +1,5 @@
+extern crate test;
+
 use itertools::Itertools;
 
 use super::query_context::test::{Def, UltraMinimal};
@@ -12,15 +14,25 @@ trait Representation {
 
 impl Representation for Tree {
     fn is_transmutable(src: Self, dst: Self, assume: Assume) -> Answer<!> {
-        crate::maybe_transmutable::MaybeTransmutableQuery::new(src, dst, assume, UltraMinimal)
-            .answer()
+        crate::maybe_transmutable::MaybeTransmutableQuery::new(
+            src,
+            dst,
+            assume,
+            UltraMinimal::default(),
+        )
+        .answer()
     }
 }
 
 impl Representation for Dfa {
     fn is_transmutable(src: Self, dst: Self, assume: Assume) -> Answer<!> {
-        crate::maybe_transmutable::MaybeTransmutableQuery::new(src, dst, assume, UltraMinimal)
-            .answer()
+        crate::maybe_transmutable::MaybeTransmutableQuery::new(
+            src,
+            dst,
+            assume,
+            UltraMinimal::default(),
+        )
+        .answer()
     }
 }
 
@@ -89,6 +101,36 @@ mod safety {
     }
 }
 
+mod size {
+    use super::*;
+
+    /// Extension (small -> large) must be rejected and truncation (large ->
+    /// small) accepted, whichever combination of assumptions is in effect.
+    #[test]
+    fn size() {
+        let narrow = Tree::number(1);
+        let wide = Tree::number(2);
+
+        // Each of the low four bits of `mask` toggles one `Assume` flag, so
+        // `0..16` enumerates all sixteen flag combinations.
+        for mask in 0u8..16 {
+            let assume = Assume {
+                alignment: (mask & 0b1000) != 0,
+                lifetimes: (mask & 0b0100) != 0,
+                safety: (mask & 0b0010) != 0,
+                validity: (mask & 0b0001) != 0,
+            };
+            // Growing the type is rejected: the destination is too big.
+            let extend = is_transmutable(&narrow, &wide, assume);
+            assert_eq!(extend, Answer::No(Reason::DstIsTooBig), "assume: {assume:?}");
+
+            // Shrinking the type is accepted.
+            let truncate = is_transmutable(&wide, &narrow, assume);
+            assert_eq!(truncate, Answer::Yes, "assume: {assume:?}");
+        }
+    }
+}
+
 mod bool {
     use super::*;
 
@@ -113,6 +155,27 @@ mod bool {
     }
 
     #[test]
+    fn transmute_u8() {
+        // Widening `bool` to `u8` always succeeds. Narrowing `u8` to `bool`
+        // succeeds only if the caller assumes validity; otherwise it is
+        // rejected as bit-incompatible.
+        let (boolean, byte) = (Tree::bool(), Tree::u8());
+
+        // (src, dst, assume.validity, expected answer)
+        let cases = [
+            (&boolean, &byte, false, Answer::Yes),
+            (&boolean, &byte, true, Answer::Yes),
+            (&byte, &boolean, false, Answer::No(Reason::DstIsBitIncompatible)),
+            (&byte, &boolean, true, Answer::Yes),
+        ];
+
+        for (src, dst, validity, expected) in cases {
+            let assume = Assume { validity, ..Assume::default() };
+            assert_eq!(is_transmutable(src, dst, assume), expected);
+        }
+    }
+
+    #[test]
     fn should_permit_validity_expansion_and_reject_contraction() {
         let b0 = layout::Tree::<Def, !>::from_bits(0);
         let b1 = layout::Tree::<Def, !>::from_bits(1);
@@ -175,6 +238,62 @@ mod bool {
     }
 }
 
+mod uninit {
+    use super::*;
+
+    /// `Tree::uninit()` is transmutable into `Tree::u8()` only when validity is
+    /// assumed; the reverse direction is accepted under every assumption.
+    // NOTE(review): the name `size` looks copied from `size::size`; what this
+    // test actually varies is validity, not size.
+    #[test]
+    fn size() {
+        let uninit = Tree::uninit();
+        let byte = Tree::u8();
+
+        // Each of the low four bits of `mask` toggles one `Assume` flag, so
+        // `0..16` enumerates all sixteen flag combinations.
+        for mask in 0u8..16 {
+            let validity = (mask & 0b0001) != 0;
+            let assume = Assume {
+                alignment: (mask & 0b1000) != 0,
+                lifetimes: (mask & 0b0100) != 0,
+                safety: (mask & 0b0010) != 0,
+                validity,
+            };
+
+            let init_from_uninit = is_transmutable(&uninit, &byte, assume);
+            let want =
+                if validity { Answer::Yes } else { Answer::No(Reason::DstIsBitIncompatible) };
+            assert_eq!(init_from_uninit, want, "assume: {assume:?}");
+            let uninit_from_init = is_transmutable(&byte, &uninit, assume);
+            assert_eq!(uninit_from_init, Answer::Yes, "assume: {assume:?}");
+        }
+    }
+}
+
+mod alt {
+    use super::*;
+    use crate::Answer;
+
+    /// A layout made of two alternative sequences is transmutable into itself.
+    #[test]
+    fn should_permit_identity_transmutation() {
+        type Tree = layout::Tree<Def, !>;
+
+        let layout = Tree::Alt(vec![
+            Tree::Seq(vec![Tree::from_bits(0), Tree::from_bits(0)]),
+            Tree::Seq(vec![Tree::bool(), Tree::from_bits(1)]),
+        ]);
+        let query = crate::maybe_transmutable::MaybeTransmutableQuery::new(
+            layout.clone(),
+            layout.clone(),
+            crate::Assume::default(),
+            UltraMinimal::default(),
+        );
+        assert_eq!(query.answer(), Answer::Yes, "layout:{:#?}", layout);
+    }
+}
+
 mod union {
     use super::*;
 
@@ -203,3 +322,59 @@ mod union {
         assert_eq!(is_transmutable(&t, &u, Assume::default()), Answer::Yes);
     }
 }
+
+mod r#ref {
+    use super::*;
+
+    /// A layout holding a reference is transmutable into itself on condition
+    /// that the reference is transmutable into itself.
+    #[test]
+    fn should_permit_identity_transmutation() {
+        type Tree = crate::layout::Tree<Def, [(); 1]>;
+        let layout = Tree::Seq(vec![Tree::from_bits(0), Tree::Ref([()])]);
+        let condition = crate::Condition::IfTransmutable { src: [()], dst: [()] };
+        let query = crate::maybe_transmutable::MaybeTransmutableQuery::new(
+            layout.clone(),
+            layout,
+            Assume::default(),
+            UltraMinimal::default(),
+        );
+        assert_eq!(query.answer(), Answer::If(condition));
+    }
+}
+
+mod benches {
+    use std::hint::black_box;
+
+    use test::Bencher;
+
+    use super::*;
+
+    /// Benchmarks converting the pruned `Tree::number(8)` layout into a `Dfa`.
+    #[bench]
+    fn bench_dfa_from_tree(b: &mut Bencher) {
+        let num = black_box(Tree::number(8).prune(&|_| false));
+        b.iter(|| {
+            let _ = black_box(Dfa::from_tree(num.clone()));
+        })
+    }
+
+    /// Benchmarks an identity-transmutability query on the `Dfa` of `Tree::number(8)`.
+    #[bench]
+    fn bench_transmute(b: &mut Bencher) {
+        let num = Tree::number(8).prune(&|_| false);
+        let dfa = black_box(Dfa::from_tree(num).unwrap());
+
+        b.iter(|| {
+            let answer = crate::maybe_transmutable::MaybeTransmutableQuery::new(
+                dfa.clone(),
+                dfa.clone(),
+                Assume::default(),
+                UltraMinimal::default(),
+            )
+            .answer();
+            // Use the `black_box` imported above, matching `bench_dfa_from_tree`.
+            assert_eq!(black_box(answer), Answer::Yes);
+        })
+    }
+}