about summary refs log tree commit diff
path: root/compiler/rustc_transmute
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_transmute')
-rw-r--r--compiler/rustc_transmute/Cargo.toml6
-rw-r--r--compiler/rustc_transmute/src/layout/dfa.rs301
-rw-r--r--compiler/rustc_transmute/src/layout/mod.rs61
-rw-r--r--compiler/rustc_transmute/src/maybe_transmutable/mod.rs346
-rw-r--r--compiler/rustc_transmute/src/maybe_transmutable/tests.rs25
5 files changed, 284 insertions, 455 deletions
diff --git a/compiler/rustc_transmute/Cargo.toml b/compiler/rustc_transmute/Cargo.toml
index 0250cc0ea07..246b66d3d03 100644
--- a/compiler/rustc_transmute/Cargo.toml
+++ b/compiler/rustc_transmute/Cargo.toml
@@ -5,7 +5,6 @@ edition = "2024"
 
 [dependencies]
 # tidy-alphabetical-start
-itertools = "0.12"
 rustc_abi = { path = "../rustc_abi", optional = true }
 rustc_data_structures = { path = "../rustc_data_structures" }
 rustc_hir = { path = "../rustc_hir", optional = true }
@@ -15,6 +14,11 @@ smallvec = "1.8.1"
 tracing = "0.1"
 # tidy-alphabetical-end
 
+[dev-dependencies]
+# tidy-alphabetical-start
+itertools = "0.12"
+# tidy-alphabetical-end
+
 [features]
 rustc = [
     "dep:rustc_abi",
diff --git a/compiler/rustc_transmute/src/layout/dfa.rs b/compiler/rustc_transmute/src/layout/dfa.rs
index d1f58157b69..05afa28db31 100644
--- a/compiler/rustc_transmute/src/layout/dfa.rs
+++ b/compiler/rustc_transmute/src/layout/dfa.rs
@@ -1,5 +1,5 @@
 use std::fmt;
-use std::ops::RangeInclusive;
+use std::iter::Peekable;
 use std::sync::atomic::{AtomicU32, Ordering};
 
 use super::{Byte, Ref, Tree, Uninhabited};
@@ -211,15 +211,15 @@ where
             let b_transitions =
                 b_src.and_then(|b_src| b.transitions.get(&b_src)).unwrap_or(&empty_transitions);
 
-            let byte_transitions =
-                a_transitions.byte_transitions.union(&b_transitions.byte_transitions);
-
-            let byte_transitions = byte_transitions.map_states(|(a_dst, b_dst)| {
-                assert!(a_dst.is_some() || b_dst.is_some());
+            let byte_transitions = a_transitions.byte_transitions.union(
+                &b_transitions.byte_transitions,
+                |a_dst, b_dst| {
+                    assert!(a_dst.is_some() || b_dst.is_some());
 
-                queue.enqueue(a_dst, b_dst);
-                mapped((a_dst, b_dst))
-            });
+                    queue.enqueue(a_dst, b_dst);
+                    mapped((a_dst, b_dst))
+                },
+            );
 
             let ref_transitions =
                 a_transitions.ref_transitions.keys().chain(b_transitions.ref_transitions.keys());
@@ -245,18 +245,6 @@ where
         Self { transitions, start, accept }
     }
 
-    pub(crate) fn states_from(
-        &self,
-        state: State,
-        src_validity: RangeInclusive<u8>,
-    ) -> impl Iterator<Item = (Byte, State)> {
-        self.transitions
-            .get(&state)
-            .map(move |t| t.byte_transitions.states_from(src_validity))
-            .into_iter()
-            .flatten()
-    }
-
     pub(crate) fn get_uninit_edge_dst(&self, state: State) -> Option<State> {
         let transitions = self.transitions.get(&state)?;
         transitions.byte_transitions.get_uninit_edge_dst()
@@ -334,95 +322,31 @@ where
 
 use edge_set::EdgeSet;
 mod edge_set {
-    use std::cmp;
-
-    use run::*;
-    use smallvec::{SmallVec, smallvec};
+    use smallvec::SmallVec;
 
     use super::*;
-    mod run {
-        use std::ops::{Range, RangeInclusive};
-
-        use super::*;
-        use crate::layout::Byte;
-
-        /// A logical set of edges.
-        ///
-        /// A `Run` encodes one edge for every byte value in `start..=end`
-        /// pointing to `dst`.
-        #[derive(Eq, PartialEq, Copy, Clone, Debug)]
-        pub(super) struct Run<S> {
-            // `start` and `end` are both inclusive (ie, closed) bounds, as this
-            // is required in order to be able to store 0..=255. We provide
-            // setters and getters which operate on closed/open ranges, which
-            // are more intuitive and easier for performing offset math.
-            start: u8,
-            end: u8,
-            pub(super) dst: S,
-        }
-
-        impl<S> Run<S> {
-            pub(super) fn new(range: RangeInclusive<u8>, dst: S) -> Self {
-                Self { start: *range.start(), end: *range.end(), dst }
-            }
-
-            pub(super) fn from_inclusive_exclusive(range: Range<u16>, dst: S) -> Self {
-                Self {
-                    start: range.start.try_into().unwrap(),
-                    end: (range.end - 1).try_into().unwrap(),
-                    dst,
-                }
-            }
-
-            pub(super) fn contains(&self, idx: u16) -> bool {
-                idx >= u16::from(self.start) && idx <= u16::from(self.end)
-            }
-
-            pub(super) fn as_inclusive_exclusive(&self) -> (u16, u16) {
-                (u16::from(self.start), u16::from(self.end) + 1)
-            }
-
-            pub(super) fn as_byte(&self) -> Byte {
-                Byte::new(self.start..=self.end)
-            }
 
-            pub(super) fn map_state<SS>(self, f: impl FnOnce(S) -> SS) -> Run<SS> {
-                let Run { start, end, dst } = self;
-                Run { start, end, dst: f(dst) }
-            }
-
-            /// Produces a new `Run` whose lower bound is the greater of
-            /// `self`'s existing lower bound and `lower_bound`.
-            pub(super) fn clamp_lower(self, lower_bound: u8) -> Self {
-                let Run { start, end, dst } = self;
-                Run { start: cmp::max(start, lower_bound), end, dst }
-            }
-        }
-    }
-
-    /// The set of outbound byte edges associated with a DFA node (not including
-    /// reference edges).
+    /// The set of outbound byte edges associated with a DFA node.
     #[derive(Eq, PartialEq, Clone, Debug)]
     pub(super) struct EdgeSet<S = State> {
-        // A sequence of runs stored in ascending order. Since the graph is a
-        // DFA, these must be non-overlapping with one another.
-        runs: SmallVec<[Run<S>; 1]>,
-        // The edge labeled with the uninit byte, if any.
+        // A sequence of byte edges with contiguous byte values and a common
+        // destination is stored as a single run.
         //
-        // FIXME(@joshlf): Make `State` a `NonZero` so that this is NPO'd.
-        uninit: Option<S>,
+        // Runs are non-empty, non-overlapping, and stored in ascending order.
+        runs: SmallVec<[(Byte, S); 1]>,
     }
 
     impl<S> EdgeSet<S> {
-        pub(crate) fn new(byte: Byte, dst: S) -> Self {
-            match byte.range() {
-                Some(range) => Self { runs: smallvec![Run::new(range, dst)], uninit: None },
-                None => Self { runs: SmallVec::new(), uninit: Some(dst) },
+        pub(crate) fn new(range: Byte, dst: S) -> Self {
+            let mut this = Self { runs: SmallVec::new() };
+            if !range.is_empty() {
+                this.runs.push((range, dst));
             }
+            this
         }
 
         pub(crate) fn empty() -> Self {
-            Self { runs: SmallVec::new(), uninit: None }
+            Self { runs: SmallVec::new() }
         }
 
         #[cfg(test)]
@@ -431,43 +355,23 @@ mod edge_set {
             S: Ord,
         {
             edges.sort();
-            Self {
-                runs: edges
-                    .into_iter()
-                    .map(|(byte, state)| Run::new(byte.range().unwrap(), state))
-                    .collect(),
-                uninit: None,
-            }
+            Self { runs: edges.into() }
         }
 
         pub(crate) fn iter(&self) -> impl Iterator<Item = (Byte, S)>
         where
             S: Copy,
         {
-            self.uninit
-                .map(|dst| (Byte::uninit(), dst))
-                .into_iter()
-                .chain(self.runs.iter().map(|run| (run.as_byte(), run.dst)))
-        }
-
-        pub(crate) fn states_from(
-            &self,
-            byte: RangeInclusive<u8>,
-        ) -> impl Iterator<Item = (Byte, S)>
-        where
-            S: Copy,
-        {
-            // FIXME(@joshlf): Optimize this. A manual scan over `self.runs` may
-            // permit us to more efficiently discard runs which will not be
-            // produced by this iterator.
-            self.iter().filter(move |(o, _)| Byte::new(byte.clone()).transmutable_into(&o))
+            self.runs.iter().copied()
         }
 
         pub(crate) fn get_uninit_edge_dst(&self) -> Option<S>
         where
             S: Copy,
         {
-            self.uninit
+            // Uninit is ordered last.
+            let &(range, dst) = self.runs.last()?;
+            if range.contains_uninit() { Some(dst) } else { None }
         }
 
         pub(crate) fn map_states<SS>(self, mut f: impl FnMut(S) -> SS) -> EdgeSet<SS> {
@@ -478,95 +382,106 @@ mod edge_set {
                 // allocates the correct number of elements once up-front [1].
                 //
                 // [1] https://doc.rust-lang.org/1.85.0/src/alloc/vec/spec_from_iter_nested.rs.html#47
-                runs: self.runs.into_iter().map(|run| run.map_state(&mut f)).collect(),
-                uninit: self.uninit.map(f),
+                runs: self.runs.into_iter().map(|(b, s)| (b, f(s))).collect(),
             }
         }
 
         /// Unions two edge sets together.
         ///
         /// If `u = a.union(b)`, then for each byte value, `u` will have an edge
-        /// with that byte value and with the destination `(Some(_), None)`,
-        /// `(None, Some(_))`, or `(Some(_), Some(_))` depending on whether `a`,
+        /// with that byte value and with the destination `join(Some(_), None)`,
+        /// `join(None, Some(_))`, or `join(Some(_), Some(_))` depending on whether `a`,
         /// `b`, or both have an edge with that byte value.
         ///
         /// If neither `a` nor `b` have an edge with a particular byte value,
         /// then no edge with that value will be present in `u`.
-        pub(crate) fn union(&self, other: &Self) -> EdgeSet<(Option<S>, Option<S>)>
+        pub(crate) fn union(
+            &self,
+            other: &Self,
+            mut join: impl FnMut(Option<S>, Option<S>) -> S,
+        ) -> EdgeSet<S>
         where
             S: Copy,
         {
-            let uninit = match (self.uninit, other.uninit) {
-                (None, None) => None,
-                (s, o) => Some((s, o)),
-            };
-
-            let mut runs = SmallVec::new();
-
-            // Iterate over `self.runs` and `other.runs` simultaneously,
-            // advancing `idx` as we go. At each step, we advance `idx` as far
-            // as we can without crossing a run boundary in either `self.runs`
-            // or `other.runs`.
-
-            // INVARIANT: `idx < s[0].end && idx < o[0].end`.
-            let (mut s, mut o) = (self.runs.as_slice(), other.runs.as_slice());
-            let mut idx = 0u16;
-            while let (Some((s_run, s_rest)), Some((o_run, o_rest))) =
-                (s.split_first(), o.split_first())
-            {
-                let (s_start, s_end) = s_run.as_inclusive_exclusive();
-                let (o_start, o_end) = o_run.as_inclusive_exclusive();
-
-                // Compute `end` as the end of the current run (which starts
-                // with `idx`).
-                let (end, dst) = match (s_run.contains(idx), o_run.contains(idx)) {
-                    // `idx` is in an existing run in both `s` and `o`, so `end`
-                    // is equal to the smallest of the two ends of those runs.
-                    (true, true) => (cmp::min(s_end, o_end), (Some(s_run.dst), Some(o_run.dst))),
-                    // `idx` is in an existing run in `s`, but not in any run in
-                    // `o`. `end` is either the end of the `s` run or the
-                    // beginning of the next `o` run, whichever comes first.
-                    (true, false) => (cmp::min(s_end, o_start), (Some(s_run.dst), None)),
-                    // The inverse of the previous case.
-                    (false, true) => (cmp::min(s_start, o_end), (None, Some(o_run.dst))),
-                    // `idx` is not in a run in either `s` or `o`, so advance it
-                    // to the beginning of the next run.
-                    (false, false) => {
-                        idx = cmp::min(s_start, o_start);
-                        continue;
-                    }
-                };
+            let xs = self.runs.iter().copied();
+            let ys = other.runs.iter().copied();
+            // FIXME(@joshlf): Merge contiguous runs with common destination.
+            EdgeSet { runs: union(xs, ys).map(|(range, (x, y))| (range, join(x, y))).collect() }
+        }
+    }
+}
+
+/// Merges two sorted sequences into one sorted sequence.
+pub(crate) fn union<S: Copy, X: Iterator<Item = (Byte, S)>, Y: Iterator<Item = (Byte, S)>>(
+    xs: X,
+    ys: Y,
+) -> UnionIter<X, Y> {
+    UnionIter { xs: xs.peekable(), ys: ys.peekable() }
+}
+
+pub(crate) struct UnionIter<X: Iterator, Y: Iterator> {
+    xs: Peekable<X>,
+    ys: Peekable<Y>,
+}
+
+// FIXME(jswrenn) we'd likely benefit from specializing try_fold here.
+impl<S: Copy, X: Iterator<Item = (Byte, S)>, Y: Iterator<Item = (Byte, S)>> Iterator
+    for UnionIter<X, Y>
+{
+    type Item = (Byte, (Option<S>, Option<S>));
 
-                // FIXME(@joshlf): If this is contiguous with the previous run
-                // and has the same `dst`, just merge it into that run rather
-                // than adding a new one.
-                runs.push(Run::from_inclusive_exclusive(idx..end, dst));
-                idx = end;
+    fn next(&mut self) -> Option<Self::Item> {
+        use std::cmp::{self, Ordering};
 
-                if idx >= s_end {
-                    s = s_rest;
+        let ret;
+        match (self.xs.peek_mut(), self.ys.peek_mut()) {
+            (None, None) => {
+                ret = None;
+            }
+            (Some(x), None) => {
+                ret = Some((x.0, (Some(x.1), None)));
+                self.xs.next();
+            }
+            (None, Some(y)) => {
+                ret = Some((y.0, (None, Some(y.1))));
+                self.ys.next();
+            }
+            (Some(x), Some(y)) => {
+                let start;
+                let end;
+                let dst;
+                match x.0.start.cmp(&y.0.start) {
+                    Ordering::Less => {
+                        start = x.0.start;
+                        end = cmp::min(x.0.end, y.0.start);
+                        dst = (Some(x.1), None);
+                    }
+                    Ordering::Greater => {
+                        start = y.0.start;
+                        end = cmp::min(x.0.start, y.0.end);
+                        dst = (None, Some(y.1));
+                    }
+                    Ordering::Equal => {
+                        start = x.0.start;
+                        end = cmp::min(x.0.end, y.0.end);
+                        dst = (Some(x.1), Some(y.1));
+                    }
                 }
-                if idx >= o_end {
-                    o = o_rest;
+                ret = Some((Byte { start, end }, dst));
+                if start == x.0.start {
+                    x.0.start = end;
+                }
+                if start == y.0.start {
+                    y.0.start = end;
+                }
+                if x.0.is_empty() {
+                    self.xs.next();
+                }
+                if y.0.is_empty() {
+                    self.ys.next();
                 }
             }
-
-            // At this point, either `s` or `o` have been exhausted, so the
-            // remaining elements in the other slice are guaranteed to be
-            // non-overlapping. We can add all remaining runs to `runs` with no
-            // further processing.
-            if let Ok(idx) = u8::try_from(idx) {
-                let (slc, map) = if !s.is_empty() {
-                    let map: fn(_) -> _ = |st| (Some(st), None);
-                    (s, map)
-                } else {
-                    let map: fn(_) -> _ = |st| (None, Some(st));
-                    (o, map)
-                };
-                runs.extend(slc.iter().map(|run| run.clamp_lower(idx).map_state(map)));
-            }
-
-            EdgeSet { runs, uninit }
         }
+        ret
     }
 }
diff --git a/compiler/rustc_transmute/src/layout/mod.rs b/compiler/rustc_transmute/src/layout/mod.rs
index 4d5f630ae22..c08bf440734 100644
--- a/compiler/rustc_transmute/src/layout/mod.rs
+++ b/compiler/rustc_transmute/src/layout/mod.rs
@@ -6,61 +6,61 @@ pub(crate) mod tree;
 pub(crate) use tree::Tree;
 
 pub(crate) mod dfa;
-pub(crate) use dfa::Dfa;
+pub(crate) use dfa::{Dfa, union};
 
 #[derive(Debug)]
 pub(crate) struct Uninhabited;
 
-/// A range of byte values, or the uninit byte.
+/// A range of byte values (including an uninit byte value).
 #[derive(Hash, Eq, PartialEq, Ord, PartialOrd, Clone, Copy)]
 pub(crate) struct Byte {
-    // An inclusive-inclusive range. We use this instead of `RangeInclusive`
-    // because `RangeInclusive: !Copy`.
+    // An inclusive-exclusive range. We use this instead of `Range` because `Range: !Copy`.
     //
-    // `None` means uninit.
-    //
-    // FIXME(@joshlf): Optimize this representation. Some pairs of values (where
-    // `lo > hi`) are illegal, and we could use these to represent `None`.
-    range: Option<(u8, u8)>,
+    // Uninit byte value is represented by 256.
+    pub(crate) start: u16,
+    pub(crate) end: u16,
 }
 
 impl Byte {
+    const UNINIT: u16 = 256;
+
+    #[inline]
     fn new(range: RangeInclusive<u8>) -> Self {
-        Self { range: Some((*range.start(), *range.end())) }
+        let start: u16 = (*range.start()).into();
+        let end: u16 = (*range.end()).into();
+        Byte { start, end: end + 1 }
     }
 
+    #[inline]
     fn from_val(val: u8) -> Self {
-        Self { range: Some((val, val)) }
+        let val: u16 = val.into();
+        Byte { start: val, end: val + 1 }
     }
 
-    pub(crate) fn uninit() -> Byte {
-        Byte { range: None }
+    #[inline]
+    fn uninit() -> Byte {
+        Byte { start: 0, end: Self::UNINIT + 1 }
     }
 
-    /// Returns `None` if `self` is the uninit byte.
-    pub(crate) fn range(&self) -> Option<RangeInclusive<u8>> {
-        self.range.map(|(lo, hi)| lo..=hi)
+    #[inline]
+    fn is_empty(&self) -> bool {
+        self.start == self.end
     }
 
-    /// Are any of the values in `self` transmutable into `other`?
-    ///
-    /// Note two special cases: An uninit byte is only transmutable into another
-    /// uninit byte. Any byte is transmutable into an uninit byte.
-    pub(crate) fn transmutable_into(&self, other: &Byte) -> bool {
-        match (self.range, other.range) {
-            (None, None) => true,
-            (None, Some(_)) => false,
-            (Some(_), None) => true,
-            (Some((slo, shi)), Some((olo, ohi))) => slo <= ohi && olo <= shi,
-        }
+    #[inline]
+    fn contains_uninit(&self) -> bool {
+        self.start <= Self::UNINIT && Self::UNINIT < self.end
     }
 }
 
 impl fmt::Debug for Byte {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        match self.range {
-            None => write!(f, "uninit"),
-            Some((lo, hi)) => write!(f, "{lo}..={hi}"),
+        if self.start == Self::UNINIT && self.end == Self::UNINIT + 1 {
+            write!(f, "uninit")
+        } else if self.start <= Self::UNINIT && self.end == Self::UNINIT + 1 {
+            write!(f, "{}..{}|uninit", self.start, self.end - 1)
+        } else {
+            write!(f, "{}..{}", self.start, self.end)
         }
     }
 }
@@ -72,6 +72,7 @@ impl From<RangeInclusive<u8>> for Byte {
 }
 
 impl From<u8> for Byte {
+    #[inline]
     fn from(src: u8) -> Self {
         Self::from_val(src)
     }
diff --git a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
index 0a19cccc2ed..f76abe50ed3 100644
--- a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
+++ b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
@@ -1,14 +1,11 @@
-use std::rc::Rc;
-use std::{cmp, iter};
-
-use itertools::Either;
+use rustc_data_structures::stack::ensure_sufficient_stack;
 use tracing::{debug, instrument, trace};
 
 pub(crate) mod query_context;
 #[cfg(test)]
 mod tests;
 
-use crate::layout::{self, Byte, Def, Dfa, Ref, Tree, dfa};
+use crate::layout::{self, Def, Dfa, Ref, Tree, dfa, union};
 use crate::maybe_transmutable::query_context::QueryContext;
 use crate::{Answer, Condition, Map, Reason};
 
@@ -153,230 +150,135 @@ where
         if let Some(answer) = cache.get(&(src_state, dst_state)) {
             answer.clone()
         } else {
-            debug!(?src_state, ?dst_state);
-            debug!(src = ?self.src);
-            debug!(dst = ?self.dst);
-            debug!(
-                src_transitions_len = self.src.transitions.len(),
-                dst_transitions_len = self.dst.transitions.len()
-            );
-            let answer = if dst_state == self.dst.accept {
-                // truncation: `size_of(Src) >= size_of(Dst)`
-                //
-                // Why is truncation OK to do? Because even though the Src is bigger, all we care about
-                // is whether we have enough data for the Dst to be valid in accordance with what its
-                // type dictates.
-                // For example, in a u8 to `()` transmutation, we have enough data available from the u8
-                // to transmute it to a `()` (though in this case does `()` really need any data to
-                // begin with? It doesn't). Same thing with u8 to fieldless struct.
-                // Now then, why is something like u8 to bool not allowed? That is not because the bool
-                // is smaller in size, but rather because those 2 bits that we are re-interpreting from
-                // the u8 could introduce invalid states for the bool type.
-                //
-                // So, if it's possible to transmute to a smaller Dst by truncating, and we can guarantee
-                // that none of the actually-used data can introduce an invalid state for Dst's type, we
-                // are able to safely transmute, even with truncation.
-                Answer::Yes
-            } else if src_state == self.src.accept {
-                // extension: `size_of(Src) <= size_of(Dst)`
-                if let Some(dst_state_prime) = self.dst.get_uninit_edge_dst(dst_state) {
-                    self.answer_memo(cache, src_state, dst_state_prime)
-                } else {
-                    Answer::No(Reason::DstIsTooBig)
-                }
+            let answer = ensure_sufficient_stack(|| self.answer_impl(cache, src_state, dst_state));
+            if let Some(..) = cache.insert((src_state, dst_state), answer.clone()) {
+                panic!("failed to correctly cache transmutability")
+            }
+            answer
+        }
+    }
+
+    fn answer_impl(
+        &self,
+        cache: &mut Map<(dfa::State, dfa::State), Answer<<C as QueryContext>::Ref>>,
+        src_state: dfa::State,
+        dst_state: dfa::State,
+    ) -> Answer<<C as QueryContext>::Ref> {
+        debug!(?src_state, ?dst_state);
+        debug!(src = ?self.src);
+        debug!(dst = ?self.dst);
+        debug!(
+            src_transitions_len = self.src.transitions.len(),
+            dst_transitions_len = self.dst.transitions.len()
+        );
+        if dst_state == self.dst.accept {
+            // truncation: `size_of(Src) >= size_of(Dst)`
+            //
+            // Why is truncation OK to do? Because even though the Src is bigger, all we care about
+            // is whether we have enough data for the Dst to be valid in accordance with what its
+            // type dictates.
+            // For example, in a u8 to `()` transmutation, we have enough data available from the u8
+            // to transmute it to a `()` (though in this case does `()` really need any data to
+            // begin with? It doesn't). Same thing with u8 to fieldless struct.
+            // Now then, why is something like u8 to bool not allowed? That is not because the bool
+            // is smaller in size, but rather because those 2 bits that we are re-interpreting from
+            // the u8 could introduce invalid states for the bool type.
+            //
+            // So, if it's possible to transmute to a smaller Dst by truncating, and we can guarantee
+            // that none of the actually-used data can introduce an invalid state for Dst's type, we
+            // are able to safely transmute, even with truncation.
+            Answer::Yes
+        } else if src_state == self.src.accept {
+            // extension: `size_of(Src) <= size_of(Dst)`
+            if let Some(dst_state_prime) = self.dst.get_uninit_edge_dst(dst_state) {
+                self.answer_memo(cache, src_state, dst_state_prime)
+            } else {
+                Answer::No(Reason::DstIsTooBig)
+            }
+        } else {
+            let src_quantifier = if self.assume.validity {
+                // if the compiler may assume that the programmer is doing additional validity checks,
+                // (e.g.: that `src != 3u8` when the destination type is `bool`)
+                // then there must exist at least one transition out of `src_state` such that the transmute is viable...
+                Quantifier::ThereExists
             } else {
-                let src_quantifier = if self.assume.validity {
-                    // if the compiler may assume that the programmer is doing additional validity checks,
-                    // (e.g.: that `src != 3u8` when the destination type is `bool`)
-                    // then there must exist at least one transition out of `src_state` such that the transmute is viable...
-                    Quantifier::ThereExists
-                } else {
-                    // if the compiler cannot assume that the programmer is doing additional validity checks,
-                    // then for all transitions out of `src_state`, such that the transmute is viable...
-                    // then there must exist at least one transition out of `dst_state` such that the transmute is viable...
-                    Quantifier::ForAll
-                };
-
-                let c = &core::cell::RefCell::new(&mut *cache);
-                let bytes_answer = src_quantifier.apply(
-                    // for each of the byte set transitions out of the `src_state`...
-                    self.src.bytes_from(src_state).flat_map(
-                        move |(src_validity, src_state_prime)| {
-                            // ...find all matching transitions out of `dst_state`.
-
-                            let Some(src_validity) = src_validity.range() else {
-                                // NOTE: We construct an iterator here rather
-                                // than just computing the value directly (via
-                                // `self.answer_memo`) so that, if the iterator
-                                // we produce from this branch is
-                                // short-circuited, we don't waste time
-                                // computing `self.answer_memo` unnecessarily.
-                                // That will specifically happen if
-                                // `src_quantifier == Quantifier::ThereExists`,
-                                // since we emit `Answer::Yes` first (before
-                                // chaining `answer_iter`).
-                                let answer_iter = if let Some(dst_state_prime) =
-                                    self.dst.get_uninit_edge_dst(dst_state)
-                                {
-                                    Either::Left(iter::once_with(move || {
-                                        let mut c = c.borrow_mut();
-                                        self.answer_memo(&mut *c, src_state_prime, dst_state_prime)
-                                    }))
-                                } else {
-                                    Either::Right(iter::once(Answer::No(
-                                        Reason::DstIsBitIncompatible,
-                                    )))
-                                };
-
-                                // When `answer == Answer::No(...)`, there are
-                                // two cases to consider:
-                                // - If `assume.validity`, then we should
-                                //   succeed because the user is responsible for
-                                //   ensuring that the *specific* byte value
-                                //   appearing at runtime is valid for the
-                                //   destination type. When `assume.validity`,
-                                //   `src_quantifier ==
-                                //   Quantifier::ThereExists`, so adding an
-                                //   `Answer::Yes` has the effect of ensuring
-                                //   that the "there exists" is always
-                                //   satisfied.
-                                // - If `!assume.validity`, then we should fail.
-                                //   In this case, `src_quantifier ==
-                                //   Quantifier::ForAll`, so adding an
-                                //   `Answer::Yes` has no effect.
-                                return Either::Left(iter::once(Answer::Yes).chain(answer_iter));
-                            };
-
-                            #[derive(Copy, Clone, Debug)]
-                            struct Accum {
-                                // The number of matching byte edges that we
-                                // have found in the destination so far.
-                                sum: usize,
-                                found_uninit: bool,
+                // if the compiler cannot assume that the programmer is doing additional validity checks,
+                // then for all transitions out of `src_state`, such that the transmute is viable...
+                // then there must exist at least one transition out of `dst_state` such that the transmute is viable...
+                Quantifier::ForAll
+            };
+
+            let bytes_answer = src_quantifier.apply(
+                union(self.src.bytes_from(src_state), self.dst.bytes_from(dst_state)).filter_map(
+                    |(_range, (src_state_prime, dst_state_prime))| {
+                        match (src_state_prime, dst_state_prime) {
+                            // No matching transitions in `src`. Skip.
+                            (None, _) => None,
+                            // No matching transitions in `dst`. Fail.
+                            (Some(_), None) => Some(Answer::No(Reason::DstIsBitIncompatible)),
+                            // Matching transitions. Continue with successor states.
+                            (Some(src_state_prime), Some(dst_state_prime)) => {
+                                Some(self.answer_memo(cache, src_state_prime, dst_state_prime))
                             }
+                        }
+                    },
+                ),
+            );
 
-                            let accum1 = Rc::new(std::cell::Cell::new(Accum {
-                                sum: 0,
-                                found_uninit: false,
-                            }));
-                            let accum2 = Rc::clone(&accum1);
-                            let sv = src_validity.clone();
-                            let update_accum = move |mut accum: Accum, dst_validity: Byte| {
-                                if let Some(dst_validity) = dst_validity.range() {
-                                    // Only add the part of `dst_validity` that
-                                    // overlaps with `src_validity`.
-                                    let start = cmp::max(*sv.start(), *dst_validity.start());
-                                    let end = cmp::min(*sv.end(), *dst_validity.end());
-
-                                    // We add 1 here to account for the fact
-                                    // that `end` is an inclusive bound.
-                                    accum.sum += 1 + usize::from(end.saturating_sub(start));
-                                } else {
-                                    accum.found_uninit = true;
-                                }
-                                accum
-                            };
-
-                            let answers = self
-                                .dst
-                                .states_from(dst_state, src_validity.clone())
-                                .map(move |(dst_validity, dst_state_prime)| {
-                                    let mut c = c.borrow_mut();
-                                    accum1.set(update_accum(accum1.get(), dst_validity));
-                                    let answer =
-                                        self.answer_memo(&mut *c, src_state_prime, dst_state_prime);
-                                    answer
+            // The below early returns reflect how this code would behave:
+            //   if self.assume.validity {
+            //       or(bytes_answer, refs_answer)
+            //   } else {
+            //       and(bytes_answer, refs_answer)
+            //   }
+            // ...if `refs_answer` was computed lazily. The below early
+            // returns can be deleted without impacting the correctness of
+            // the algorithm; only its performance.
+            debug!(?bytes_answer);
+            match bytes_answer {
+                Answer::No(_) if !self.assume.validity => return bytes_answer,
+                Answer::Yes if self.assume.validity => return bytes_answer,
+                _ => {}
+            };
+
+            let refs_answer = src_quantifier.apply(
+                // for each reference transition out of `src_state`...
+                self.src.refs_from(src_state).map(|(src_ref, src_state_prime)| {
+                    // ...there exists a reference transition out of `dst_state`...
+                    Quantifier::ThereExists.apply(self.dst.refs_from(dst_state).map(
+                        |(dst_ref, dst_state_prime)| {
+                            if !src_ref.is_mutable() && dst_ref.is_mutable() {
+                                Answer::No(Reason::DstIsMoreUnique)
+                            } else if !self.assume.alignment
+                                && src_ref.min_align() < dst_ref.min_align()
+                            {
+                                Answer::No(Reason::DstHasStricterAlignment {
+                                    src_min_align: src_ref.min_align(),
+                                    dst_min_align: dst_ref.min_align(),
                                 })
-                                .chain(
-                                    iter::once_with(move || {
-                                        let src_validity_len = usize::from(*src_validity.end())
-                                            - usize::from(*src_validity.start())
-                                            + 1;
-                                        let accum = accum2.get();
-
-                                        // If this condition is false, then
-                                        // there are some byte values in the
-                                        // source which have no corresponding
-                                        // transition in the destination DFA. In
-                                        // that case, we add a `No` to our list
-                                        // of answers. When
-                                        // `!self.assume.validity`, this will
-                                        // cause the query to fail.
-                                        if accum.found_uninit || accum.sum == src_validity_len {
-                                            None
-                                        } else {
-                                            Some(Answer::No(Reason::DstIsBitIncompatible))
-                                        }
-                                    })
-                                    .flatten(),
-                                );
-                            Either::Right(answers)
-                        },
-                    ),
-                );
-
-                // The below early returns reflect how this code would behave:
-                //   if self.assume.validity {
-                //       or(bytes_answer, refs_answer)
-                //   } else {
-                //       and(bytes_answer, refs_answer)
-                //   }
-                // ...if `refs_answer` was computed lazily. The below early
-                // returns can be deleted without impacting the correctness of
-                // the algorithm; only its performance.
-                debug!(?bytes_answer);
-                match bytes_answer {
-                    Answer::No(_) if !self.assume.validity => return bytes_answer,
-                    Answer::Yes if self.assume.validity => return bytes_answer,
-                    _ => {}
-                };
-
-                let refs_answer = src_quantifier.apply(
-                    // for each reference transition out of `src_state`...
-                    self.src.refs_from(src_state).map(|(src_ref, src_state_prime)| {
-                        // ...there exists a reference transition out of `dst_state`...
-                        Quantifier::ThereExists.apply(self.dst.refs_from(dst_state).map(
-                            |(dst_ref, dst_state_prime)| {
-                                if !src_ref.is_mutable() && dst_ref.is_mutable() {
-                                    Answer::No(Reason::DstIsMoreUnique)
-                                } else if !self.assume.alignment
-                                    && src_ref.min_align() < dst_ref.min_align()
-                                {
-                                    Answer::No(Reason::DstHasStricterAlignment {
-                                        src_min_align: src_ref.min_align(),
-                                        dst_min_align: dst_ref.min_align(),
-                                    })
-                                } else if dst_ref.size() > src_ref.size() {
-                                    Answer::No(Reason::DstRefIsTooBig {
+                            } else if dst_ref.size() > src_ref.size() {
+                                Answer::No(Reason::DstRefIsTooBig { src: src_ref, dst: dst_ref })
+                            } else {
+                                // ...such that `src` is transmutable into `dst`, if
+                                // `src_ref` is transmutability into `dst_ref`.
+                                and(
+                                    Answer::If(Condition::IfTransmutable {
                                         src: src_ref,
                                         dst: dst_ref,
-                                    })
-                                } else {
-                                    // ...such that `src` is transmutable into `dst`, if
-                                    // `src_ref` is transmutability into `dst_ref`.
-                                    and(
-                                        Answer::If(Condition::IfTransmutable {
-                                            src: src_ref,
-                                            dst: dst_ref,
-                                        }),
-                                        self.answer_memo(cache, src_state_prime, dst_state_prime),
-                                    )
-                                }
-                            },
-                        ))
-                    }),
-                );
-
-                if self.assume.validity {
-                    or(bytes_answer, refs_answer)
-                } else {
-                    and(bytes_answer, refs_answer)
-                }
-            };
-            if let Some(..) = cache.insert((src_state, dst_state), answer.clone()) {
-                panic!("failed to correctly cache transmutability")
+                                    }),
+                                    self.answer_memo(cache, src_state_prime, dst_state_prime),
+                                )
+                            }
+                        },
+                    ))
+                }),
+            );
+
+            if self.assume.validity {
+                or(bytes_answer, refs_answer)
+            } else {
+                and(bytes_answer, refs_answer)
             }
-            answer
         }
     }
 }
diff --git a/compiler/rustc_transmute/src/maybe_transmutable/tests.rs b/compiler/rustc_transmute/src/maybe_transmutable/tests.rs
index 992fcb7cc4c..fbb4639dbd6 100644
--- a/compiler/rustc_transmute/src/maybe_transmutable/tests.rs
+++ b/compiler/rustc_transmute/src/maybe_transmutable/tests.rs
@@ -400,16 +400,23 @@ mod r#ref {
     fn should_permit_identity_transmutation() {
         type Tree = crate::layout::Tree<Def, [(); 1]>;
 
-        let layout = Tree::Seq(vec![Tree::byte(0x00), Tree::Ref([()])]);
+        for validity in [false, true] {
+            let layout = Tree::Seq(vec![Tree::byte(0x00), Tree::Ref([()])]);
 
-        let answer = crate::maybe_transmutable::MaybeTransmutableQuery::new(
-            layout.clone(),
-            layout,
-            Assume::default(),
-            UltraMinimal::default(),
-        )
-        .answer();
-        assert_eq!(answer, Answer::If(crate::Condition::IfTransmutable { src: [()], dst: [()] }));
+            let assume = Assume { validity, ..Assume::default() };
+
+            let answer = crate::maybe_transmutable::MaybeTransmutableQuery::new(
+                layout.clone(),
+                layout,
+                assume,
+                UltraMinimal::default(),
+            )
+            .answer();
+            assert_eq!(
+                answer,
+                Answer::If(crate::Condition::IfTransmutable { src: [()], dst: [()] })
+            );
+        }
     }
 }