about summary refs log tree commit diff
diff options
context:
space:
mode:
authorJack Wrenn <jack@wrenn.fyi>2025-04-17 17:58:34 +0000
committerJack Wrenn <jack@wrenn.fyi>2025-04-20 03:06:59 +0000
commit957b5488a5fb3875006c06577d9049177ed971bc (patch)
tree71f4bde6f53fcf4bdf406073cc4247a7494fbb2a
parent883f9f72e87ccb6838d528d8158ea6323baacc65 (diff)
downloadrust-957b5488a5fb3875006c06577d9049177ed971bc.tar.gz
rust-957b5488a5fb3875006c06577d9049177ed971bc.zip
transmutability: remove NFA intermediate representation
Prior to this commit, the transmutability analysis used an intermediate
NFA representation of type layout. We then determinized this
representation into a DFA, upon which we ran the core transmutability
analysis. Unfortunately, determinizing NFAs is expensive. In this
commit, we avoid NFAs entirely by observing that Rust `union`s are the
only source of nondeterminism and that it is comparatively cheap to
compute the DFA union of DFAs.

We also implement Graphviz DOT debug formatting of DFAs.

Fixes rust-lang/project-safe-transmute#23
Fixes rust-lang/project-safe-transmute#24
-rw-r--r--compiler/rustc_transmute/src/layout/dfa.rs289
-rw-r--r--compiler/rustc_transmute/src/layout/mod.rs10
-rw-r--r--compiler/rustc_transmute/src/layout/nfa.rs169
-rw-r--r--compiler/rustc_transmute/src/lib.rs2
-rw-r--r--compiler/rustc_transmute/src/maybe_transmutable/mod.rs33
-rw-r--r--compiler/rustc_transmute/src/maybe_transmutable/tests.rs31
6 files changed, 239 insertions, 295 deletions
diff --git a/compiler/rustc_transmute/src/layout/dfa.rs b/compiler/rustc_transmute/src/layout/dfa.rs
index af568171f91..bb909c54d2b 100644
--- a/compiler/rustc_transmute/src/layout/dfa.rs
+++ b/compiler/rustc_transmute/src/layout/dfa.rs
@@ -1,19 +1,18 @@
 use std::fmt;
 use std::sync::atomic::{AtomicU32, Ordering};
 
-use tracing::instrument;
-
-use super::{Byte, Nfa, Ref, nfa};
+use super::{Byte, Ref, Tree, Uninhabited};
 use crate::Map;
 
-#[derive(PartialEq, Clone, Debug)]
+#[derive(PartialEq)]
+#[cfg_attr(test, derive(Clone))]
 pub(crate) struct Dfa<R>
 where
     R: Ref,
 {
     pub(crate) transitions: Map<State, Transitions<R>>,
     pub(crate) start: State,
-    pub(crate) accepting: State,
+    pub(crate) accept: State,
 }
 
 #[derive(PartialEq, Clone, Debug)]
@@ -34,35 +33,15 @@ where
     }
 }
 
-impl<R> Transitions<R>
-where
-    R: Ref,
-{
-    #[cfg(test)]
-    fn insert(&mut self, transition: Transition<R>, state: State) {
-        match transition {
-            Transition::Byte(b) => {
-                self.byte_transitions.insert(b, state);
-            }
-            Transition::Ref(r) => {
-                self.ref_transitions.insert(r, state);
-            }
-        }
-    }
-}
-
-/// The states in a `Nfa` represent byte offsets.
+/// The states in a [`Dfa`] represent byte offsets.
 #[derive(Hash, Eq, PartialEq, PartialOrd, Ord, Copy, Clone)]
-pub(crate) struct State(u32);
+pub(crate) struct State(pub(crate) u32);
 
-#[cfg(test)]
-#[derive(Hash, Eq, PartialEq, Clone, Copy)]
-pub(crate) enum Transition<R>
-where
-    R: Ref,
-{
-    Byte(Byte),
-    Ref(R),
+impl State {
+    pub(crate) fn new() -> Self {
+        static COUNTER: AtomicU32 = AtomicU32::new(0);
+        Self(COUNTER.fetch_add(1, Ordering::SeqCst))
+    }
 }
 
 impl fmt::Debug for State {
@@ -71,19 +50,6 @@ impl fmt::Debug for State {
     }
 }
 
-#[cfg(test)]
-impl<R> fmt::Debug for Transition<R>
-where
-    R: Ref,
-{
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        match &self {
-            Self::Byte(b) => b.fmt(f),
-            Self::Ref(r) => r.fmt(f),
-        }
-    }
-}
-
 impl<R> Dfa<R>
 where
     R: Ref,
@@ -92,60 +58,167 @@ where
     pub(crate) fn bool() -> Self {
         let mut transitions: Map<State, Transitions<R>> = Map::default();
         let start = State::new();
-        let accepting = State::new();
+        let accept = State::new();
 
-        transitions.entry(start).or_default().insert(Transition::Byte(Byte::Init(0x00)), accepting);
+        transitions.entry(start).or_default().byte_transitions.insert(Byte::Init(0x00), accept);
 
-        transitions.entry(start).or_default().insert(Transition::Byte(Byte::Init(0x01)), accepting);
+        transitions.entry(start).or_default().byte_transitions.insert(Byte::Init(0x01), accept);
 
-        Self { transitions, start, accepting }
+        Self { transitions, start, accept }
     }
 
-    #[instrument(level = "debug")]
-    pub(crate) fn from_nfa(nfa: Nfa<R>) -> Self {
-        let Nfa { transitions: nfa_transitions, start: nfa_start, accepting: nfa_accepting } = nfa;
+    pub(crate) fn unit() -> Self {
+        let transitions: Map<State, Transitions<R>> = Map::default();
+        let start = State::new();
+        let accept = start;
+
+        Self { transitions, start, accept }
+    }
 
-        let mut dfa_transitions: Map<State, Transitions<R>> = Map::default();
-        let mut nfa_to_dfa: Map<nfa::State, State> = Map::default();
-        let dfa_start = State::new();
-        nfa_to_dfa.insert(nfa_start, dfa_start);
+    pub(crate) fn from_byte(byte: Byte) -> Self {
+        let mut transitions: Map<State, Transitions<R>> = Map::default();
+        let start = State::new();
+        let accept = State::new();
 
-        let mut queue = vec![(nfa_start, dfa_start)];
+        transitions.entry(start).or_default().byte_transitions.insert(byte, accept);
 
-        while let Some((nfa_state, dfa_state)) = queue.pop() {
-            if nfa_state == nfa_accepting {
-                continue;
-            }
+        Self { transitions, start, accept }
+    }
 
-            for (nfa_transition, next_nfa_states) in nfa_transitions[&nfa_state].iter() {
-                let dfa_transitions =
-                    dfa_transitions.entry(dfa_state).or_insert_with(Default::default);
-
-                let mapped_state = next_nfa_states.iter().find_map(|x| nfa_to_dfa.get(x).copied());
-
-                let next_dfa_state = match nfa_transition {
-                    &nfa::Transition::Byte(b) => *dfa_transitions
-                        .byte_transitions
-                        .entry(b)
-                        .or_insert_with(|| mapped_state.unwrap_or_else(State::new)),
-                    &nfa::Transition::Ref(r) => *dfa_transitions
-                        .ref_transitions
-                        .entry(r)
-                        .or_insert_with(|| mapped_state.unwrap_or_else(State::new)),
-                };
-
-                for &next_nfa_state in next_nfa_states {
-                    nfa_to_dfa.entry(next_nfa_state).or_insert_with(|| {
-                        queue.push((next_nfa_state, next_dfa_state));
-                        next_dfa_state
-                    });
+    pub(crate) fn from_ref(r: R) -> Self {
+        let mut transitions: Map<State, Transitions<R>> = Map::default();
+        let start = State::new();
+        let accept = State::new();
+
+        transitions.entry(start).or_default().ref_transitions.insert(r, accept);
+
+        Self { transitions, start, accept }
+    }
+
+    pub(crate) fn from_tree(tree: Tree<!, R>) -> Result<Self, Uninhabited> {
+        Ok(match tree {
+            Tree::Byte(b) => Self::from_byte(b),
+            Tree::Ref(r) => Self::from_ref(r),
+            Tree::Alt(alts) => {
+                // Convert and filter the inhabited alternatives.
+                let mut alts = alts.into_iter().map(Self::from_tree).filter_map(Result::ok);
+                // If there are no alternatives, return `Uninhabited`.
+                let dfa = alts.next().ok_or(Uninhabited)?;
+                // Combine the remaining alternatives with `dfa`.
+                alts.fold(dfa, |dfa, alt| dfa.union(alt, State::new))
+            }
+            Tree::Seq(elts) => {
+                let mut dfa = Self::unit();
+                for elt in elts.into_iter().map(Self::from_tree) {
+                    dfa = dfa.concat(elt?);
                 }
+                dfa
             }
+        })
+    }
+
+    /// Concatenate two `Dfa`s.
+    pub(crate) fn concat(self, other: Self) -> Self {
+        if self.start == self.accept {
+            return other;
+        } else if other.start == other.accept {
+            return self;
         }
 
-        let dfa_accepting = nfa_to_dfa[&nfa_accepting];
+        let start = self.start;
+        let accept = other.accept;
+
+        let mut transitions: Map<State, Transitions<R>> = self.transitions;
 
-        Self { transitions: dfa_transitions, start: dfa_start, accepting: dfa_accepting }
+        for (source, transition) in other.transitions {
+            let fix_state = |state| if state == other.start { self.accept } else { state };
+            let entry = transitions.entry(fix_state(source)).or_default();
+            for (edge, destination) in transition.byte_transitions {
+                entry.byte_transitions.insert(edge, fix_state(destination));
+            }
+            for (edge, destination) in transition.ref_transitions {
+                entry.ref_transitions.insert(edge, fix_state(destination));
+            }
+        }
+
+        Self { transitions, start, accept }
+    }
+
+    /// Compute the union of two `Dfa`s.
+    pub(crate) fn union(self, other: Self, mut new_state: impl FnMut() -> State) -> Self {
+        // We implement `union` by lazily initializing a set of states
+        // corresponding to the product of states in `self` and `other`, and
+        // then add transitions between these states that correspond to where
+        // they exist between `self` and `other`.
+
+        let a = self;
+        let b = other;
+
+        let accept = new_state();
+
+        let mut mapping: Map<(Option<State>, Option<State>), State> = Map::default();
+
+        let mut mapped = |(a_state, b_state)| {
+            if Some(a.accept) == a_state || Some(b.accept) == b_state {
+                // If either `a_state` or `b_state` are accepting, map to a
+                // common `accept` state.
+                accept
+            } else {
+                *mapping.entry((a_state, b_state)).or_insert_with(&mut new_state)
+            }
+        };
+
+        let start = mapped((Some(a.start), Some(b.start)));
+        let mut transitions: Map<State, Transitions<R>> = Map::default();
+        let mut queue = vec![(Some(a.start), Some(b.start))];
+        let empty_transitions = Transitions::default();
+
+        while let Some((a_src, b_src)) = queue.pop() {
+            let a_transitions =
+                a_src.and_then(|a_src| a.transitions.get(&a_src)).unwrap_or(&empty_transitions);
+            let b_transitions =
+                b_src.and_then(|b_src| b.transitions.get(&b_src)).unwrap_or(&empty_transitions);
+
+            let byte_transitions =
+                a_transitions.byte_transitions.keys().chain(b_transitions.byte_transitions.keys());
+
+            for byte_transition in byte_transitions {
+                let a_dst = a_transitions.byte_transitions.get(byte_transition).copied();
+                let b_dst = b_transitions.byte_transitions.get(byte_transition).copied();
+
+                assert!(a_dst.is_some() || b_dst.is_some());
+
+                let src = mapped((a_src, b_src));
+                let dst = mapped((a_dst, b_dst));
+
+                transitions.entry(src).or_default().byte_transitions.insert(*byte_transition, dst);
+
+                if !transitions.contains_key(&dst) {
+                    queue.push((a_dst, b_dst))
+                }
+            }
+
+            let ref_transitions =
+                a_transitions.ref_transitions.keys().chain(b_transitions.ref_transitions.keys());
+
+            for ref_transition in ref_transitions {
+                let a_dst = a_transitions.ref_transitions.get(ref_transition).copied();
+                let b_dst = b_transitions.ref_transitions.get(ref_transition).copied();
+
+                assert!(a_dst.is_some() || b_dst.is_some());
+
+                let src = mapped((a_src, b_src));
+                let dst = mapped((a_dst, b_dst));
+
+                transitions.entry(src).or_default().ref_transitions.insert(*ref_transition, dst);
+
+                if !transitions.contains_key(&dst) {
+                    queue.push((a_dst, b_dst))
+                }
+            }
+        }
+
+        Self { transitions, start, accept }
     }
 
     pub(crate) fn bytes_from(&self, start: State) -> Option<&Map<Byte, State>> {
@@ -159,24 +232,48 @@ where
     pub(crate) fn refs_from(&self, start: State) -> Option<&Map<R, State>> {
         Some(&self.transitions.get(&start)?.ref_transitions)
     }
-}
 
-impl State {
-    pub(crate) fn new() -> Self {
-        static COUNTER: AtomicU32 = AtomicU32::new(0);
-        Self(COUNTER.fetch_add(1, Ordering::SeqCst))
+    #[cfg(test)]
+    pub(crate) fn from_edges<B: Copy + Into<Byte>>(
+        start: u32,
+        accept: u32,
+        edges: &[(u32, B, u32)],
+    ) -> Self {
+        let start = State(start);
+        let accept = State(accept);
+        let mut transitions: Map<State, Transitions<R>> = Map::default();
+
+        for &(src, edge, dst) in edges {
+            let src = State(src);
+            let dst = State(dst);
+            let old = transitions.entry(src).or_default().byte_transitions.insert(edge.into(), dst);
+            assert!(old.is_none());
+        }
+
+        Self { start, accept, transitions }
     }
 }
 
-#[cfg(test)]
-impl<R> From<nfa::Transition<R>> for Transition<R>
+/// Serialize the DFA using the Graphviz DOT format.
+impl<R> fmt::Debug for Dfa<R>
 where
     R: Ref,
 {
-    fn from(nfa_transition: nfa::Transition<R>) -> Self {
-        match nfa_transition {
-            nfa::Transition::Byte(byte) => Transition::Byte(byte),
-            nfa::Transition::Ref(r) => Transition::Ref(r),
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        writeln!(f, "digraph {{")?;
+        writeln!(f, "    {:?} [shape = doublecircle]", self.start)?;
+        writeln!(f, "    {:?} [shape = doublecircle]", self.accept)?;
+
+        for (src, transitions) in self.transitions.iter() {
+            for (t, dst) in transitions.byte_transitions.iter() {
+                writeln!(f, "    {src:?} -> {dst:?} [label=\"{t:?}\"]")?;
+            }
+
+            for (t, dst) in transitions.ref_transitions.iter() {
+                writeln!(f, "    {src:?} -> {dst:?} [label=\"{t:?}\"]")?;
+            }
         }
+
+        writeln!(f, "}}")
     }
 }
diff --git a/compiler/rustc_transmute/src/layout/mod.rs b/compiler/rustc_transmute/src/layout/mod.rs
index c4c01a8fac3..c940f7c42a8 100644
--- a/compiler/rustc_transmute/src/layout/mod.rs
+++ b/compiler/rustc_transmute/src/layout/mod.rs
@@ -4,9 +4,6 @@ use std::hash::Hash;
 pub(crate) mod tree;
 pub(crate) use tree::Tree;
 
-pub(crate) mod nfa;
-pub(crate) use nfa::Nfa;
-
 pub(crate) mod dfa;
 pub(crate) use dfa::Dfa;
 
@@ -29,6 +26,13 @@ impl fmt::Debug for Byte {
     }
 }
 
+#[cfg(test)]
+impl From<u8> for Byte {
+    fn from(src: u8) -> Self {
+        Self::Init(src)
+    }
+}
+
 pub(crate) trait Def: Debug + Hash + Eq + PartialEq + Copy + Clone {
     fn has_safety_invariants(&self) -> bool;
 }
diff --git a/compiler/rustc_transmute/src/layout/nfa.rs b/compiler/rustc_transmute/src/layout/nfa.rs
deleted file mode 100644
index 9c21fd94f03..00000000000
--- a/compiler/rustc_transmute/src/layout/nfa.rs
+++ /dev/null
@@ -1,169 +0,0 @@
-use std::fmt;
-use std::sync::atomic::{AtomicU32, Ordering};
-
-use super::{Byte, Ref, Tree, Uninhabited};
-use crate::{Map, Set};
-
-/// A non-deterministic finite automaton (NFA) that represents the layout of a type.
-/// The transmutability of two given types is computed by comparing their `Nfa`s.
-#[derive(PartialEq, Debug)]
-pub(crate) struct Nfa<R>
-where
-    R: Ref,
-{
-    pub(crate) transitions: Map<State, Map<Transition<R>, Set<State>>>,
-    pub(crate) start: State,
-    pub(crate) accepting: State,
-}
-
-/// The states in a `Nfa` represent byte offsets.
-#[derive(Hash, Eq, PartialEq, PartialOrd, Ord, Copy, Clone)]
-pub(crate) struct State(u32);
-
-/// The transitions between states in a `Nfa` reflect bit validity.
-#[derive(Hash, Eq, PartialEq, Clone, Copy)]
-pub(crate) enum Transition<R>
-where
-    R: Ref,
-{
-    Byte(Byte),
-    Ref(R),
-}
-
-impl fmt::Debug for State {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "S_{}", self.0)
-    }
-}
-
-impl<R> fmt::Debug for Transition<R>
-where
-    R: Ref,
-{
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        match &self {
-            Self::Byte(b) => b.fmt(f),
-            Self::Ref(r) => r.fmt(f),
-        }
-    }
-}
-
-impl<R> Nfa<R>
-where
-    R: Ref,
-{
-    pub(crate) fn unit() -> Self {
-        let transitions: Map<State, Map<Transition<R>, Set<State>>> = Map::default();
-        let start = State::new();
-        let accepting = start;
-
-        Nfa { transitions, start, accepting }
-    }
-
-    pub(crate) fn from_byte(byte: Byte) -> Self {
-        let mut transitions: Map<State, Map<Transition<R>, Set<State>>> = Map::default();
-        let start = State::new();
-        let accepting = State::new();
-
-        let source = transitions.entry(start).or_default();
-        let edge = source.entry(Transition::Byte(byte)).or_default();
-        edge.insert(accepting);
-
-        Nfa { transitions, start, accepting }
-    }
-
-    pub(crate) fn from_ref(r: R) -> Self {
-        let mut transitions: Map<State, Map<Transition<R>, Set<State>>> = Map::default();
-        let start = State::new();
-        let accepting = State::new();
-
-        let source = transitions.entry(start).or_default();
-        let edge = source.entry(Transition::Ref(r)).or_default();
-        edge.insert(accepting);
-
-        Nfa { transitions, start, accepting }
-    }
-
-    pub(crate) fn from_tree(tree: Tree<!, R>) -> Result<Self, Uninhabited> {
-        Ok(match tree {
-            Tree::Byte(b) => Self::from_byte(b),
-            Tree::Ref(r) => Self::from_ref(r),
-            Tree::Alt(alts) => {
-                let mut alts = alts.into_iter().map(Self::from_tree);
-                let mut nfa = alts.next().ok_or(Uninhabited)??;
-                for alt in alts {
-                    nfa = nfa.union(alt?);
-                }
-                nfa
-            }
-            Tree::Seq(elts) => {
-                let mut nfa = Self::unit();
-                for elt in elts.into_iter().map(Self::from_tree) {
-                    nfa = nfa.concat(elt?);
-                }
-                nfa
-            }
-        })
-    }
-
-    /// Concatenate two `Nfa`s.
-    pub(crate) fn concat(self, other: Self) -> Self {
-        if self.start == self.accepting {
-            return other;
-        } else if other.start == other.accepting {
-            return self;
-        }
-
-        let start = self.start;
-        let accepting = other.accepting;
-
-        let mut transitions: Map<State, Map<Transition<R>, Set<State>>> = self.transitions;
-
-        for (source, transition) in other.transitions {
-            let fix_state = |state| if state == other.start { self.accepting } else { state };
-            let entry = transitions.entry(fix_state(source)).or_default();
-            for (edge, destinations) in transition {
-                let entry = entry.entry(edge).or_default();
-                for destination in destinations {
-                    entry.insert(fix_state(destination));
-                }
-            }
-        }
-
-        Self { transitions, start, accepting }
-    }
-
-    /// Compute the union of two `Nfa`s.
-    pub(crate) fn union(self, other: Self) -> Self {
-        let start = self.start;
-        let accepting = self.accepting;
-
-        let mut transitions: Map<State, Map<Transition<R>, Set<State>>> = self.transitions.clone();
-
-        for (&(mut source), transition) in other.transitions.iter() {
-            // if source is starting state of `other`, replace with starting state of `self`
-            if source == other.start {
-                source = self.start;
-            }
-            let entry = transitions.entry(source).or_default();
-            for (edge, destinations) in transition {
-                let entry = entry.entry(*edge).or_default();
-                for &(mut destination) in destinations {
-                    // if dest is accepting state of `other`, replace with accepting state of `self`
-                    if destination == other.accepting {
-                        destination = self.accepting;
-                    }
-                    entry.insert(destination);
-                }
-            }
-        }
-        Self { transitions, start, accepting }
-    }
-}
-
-impl State {
-    pub(crate) fn new() -> Self {
-        static COUNTER: AtomicU32 = AtomicU32::new(0);
-        Self(COUNTER.fetch_add(1, Ordering::SeqCst))
-    }
-}
diff --git a/compiler/rustc_transmute/src/lib.rs b/compiler/rustc_transmute/src/lib.rs
index 00928137d29..76fa6ceabe7 100644
--- a/compiler/rustc_transmute/src/lib.rs
+++ b/compiler/rustc_transmute/src/lib.rs
@@ -2,7 +2,7 @@
 #![feature(never_type)]
 // tidy-alphabetical-end
 
-pub(crate) use rustc_data_structures::fx::{FxIndexMap as Map, FxIndexSet as Set};
+pub(crate) use rustc_data_structures::fx::FxIndexMap as Map;
 
 pub mod layout;
 mod maybe_transmutable;
diff --git a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
index 63fabc9c83d..db0e1ab8e98 100644
--- a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
+++ b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
@@ -4,7 +4,7 @@ pub(crate) mod query_context;
 #[cfg(test)]
 mod tests;
 
-use crate::layout::{self, Byte, Def, Dfa, Nfa, Ref, Tree, Uninhabited, dfa};
+use crate::layout::{self, Byte, Def, Dfa, Ref, Tree, Uninhabited, dfa};
 use crate::maybe_transmutable::query_context::QueryContext;
 use crate::{Answer, Condition, Map, Reason};
 
@@ -73,7 +73,7 @@ where
     /// Answers whether a `Tree` is transmutable into another `Tree`.
     ///
     /// This method begins by de-def'ing `src` and `dst`, and prunes private paths from `dst`,
-    /// then converts `src` and `dst` to `Nfa`s, and computes an answer using those NFAs.
+    /// then converts `src` and `dst` to `Dfa`s, and computes an answer using those DFAs.
     #[inline(always)]
     #[instrument(level = "debug", skip(self), fields(src = ?self.src, dst = ?self.dst))]
     pub(crate) fn answer(self) -> Answer<<C as QueryContext>::Ref> {
@@ -105,22 +105,22 @@ where
 
         trace!(?dst, "pruned dst");
 
-        // Convert `src` from a tree-based representation to an NFA-based
+        // Convert `src` from a tree-based representation to an DFA-based
         // representation. If the conversion fails because `src` is uninhabited,
         // conclude that the transmutation is acceptable, because instances of
         // the `src` type do not exist.
-        let src = match Nfa::from_tree(src) {
+        let src = match Dfa::from_tree(src) {
             Ok(src) => src,
             Err(Uninhabited) => return Answer::Yes,
         };
 
-        // Convert `dst` from a tree-based representation to an NFA-based
+        // Convert `dst` from a tree-based representation to an DFA-based
         // representation. If the conversion fails because `src` is uninhabited,
         // conclude that the transmutation is unacceptable. Valid instances of
         // the `dst` type do not exist, either because it's genuinely
         // uninhabited, or because there are no branches of the tree that are
         // free of safety invariants.
-        let dst = match Nfa::from_tree(dst) {
+        let dst = match Dfa::from_tree(dst) {
             Ok(dst) => dst,
             Err(Uninhabited) => return Answer::No(Reason::DstMayHaveSafetyInvariants),
         };
@@ -129,23 +129,6 @@ where
     }
 }
 
-impl<C> MaybeTransmutableQuery<Nfa<<C as QueryContext>::Ref>, C>
-where
-    C: QueryContext,
-{
-    /// Answers whether a `Nfa` is transmutable into another `Nfa`.
-    ///
-    /// This method converts `src` and `dst` to DFAs, then computes an answer using those DFAs.
-    #[inline(always)]
-    #[instrument(level = "debug", skip(self), fields(src = ?self.src, dst = ?self.dst))]
-    pub(crate) fn answer(self) -> Answer<<C as QueryContext>::Ref> {
-        let Self { src, dst, assume, context } = self;
-        let src = Dfa::from_nfa(src);
-        let dst = Dfa::from_nfa(dst);
-        MaybeTransmutableQuery { src, dst, assume, context }.answer()
-    }
-}
-
 impl<C> MaybeTransmutableQuery<Dfa<<C as QueryContext>::Ref>, C>
 where
     C: QueryContext,
@@ -173,7 +156,7 @@ where
                 src_transitions_len = self.src.transitions.len(),
                 dst_transitions_len = self.dst.transitions.len()
             );
-            let answer = if dst_state == self.dst.accepting {
+            let answer = if dst_state == self.dst.accept {
                 // truncation: `size_of(Src) >= size_of(Dst)`
                 //
                 // Why is truncation OK to do? Because even though the Src is bigger, all we care about
@@ -190,7 +173,7 @@ where
                 // that none of the actually-used data can introduce an invalid state for Dst's type, we
                 // are able to safely transmute, even with truncation.
                 Answer::Yes
-            } else if src_state == self.src.accepting {
+            } else if src_state == self.src.accept {
                 // extension: `size_of(Src) >= size_of(Dst)`
                 if let Some(dst_state_prime) = self.dst.byte_from(dst_state, Byte::Uninit) {
                     self.answer_memo(cache, src_state, dst_state_prime)
diff --git a/compiler/rustc_transmute/src/maybe_transmutable/tests.rs b/compiler/rustc_transmute/src/maybe_transmutable/tests.rs
index 69a6b1b77f4..cc6a4dce17b 100644
--- a/compiler/rustc_transmute/src/maybe_transmutable/tests.rs
+++ b/compiler/rustc_transmute/src/maybe_transmutable/tests.rs
@@ -126,7 +126,7 @@ mod bool {
 
         let into_set = |alts: Vec<_>| {
             #[cfg(feature = "rustc")]
-            let mut set = crate::Set::default();
+            let mut set = rustc_data_structures::fx::FxIndexSet::default();
             #[cfg(not(feature = "rustc"))]
             let mut set = std::collections::HashSet::new();
             set.extend(alts);
@@ -174,3 +174,32 @@ mod bool {
         }
     }
 }
+
+mod union {
+    use super::*;
+
+    #[test]
+    fn union() {
+        let [a, b, c, d] = [0, 1, 2, 3];
+        let s = Dfa::from_edges(a, d, &[(a, 0, b), (b, 0, d), (a, 1, c), (c, 1, d)]);
+
+        let t = Dfa::from_edges(a, c, &[(a, 1, b), (b, 0, c)]);
+
+        let mut ctr = 0;
+        let new_state = || {
+            let state = crate::layout::dfa::State(ctr);
+            ctr += 1;
+            state
+        };
+
+        let u = s.clone().union(t.clone(), new_state);
+
+        let expected_u =
+            Dfa::from_edges(b, a, &[(b, 0, c), (b, 1, d), (d, 1, a), (d, 0, a), (c, 0, a)]);
+
+        assert_eq!(u, expected_u);
+
+        assert_eq!(is_transmutable(&s, &u, Assume::default()), Answer::Yes);
+        assert_eq!(is_transmutable(&t, &u, Assume::default()), Answer::Yes);
+    }
+}