diff options
| author | Joshua Liebow-Feeser <hello@joshlf.com> | 2025-04-10 13:45:39 -0700 |
|---|---|---|
| committer | Joshua Liebow-Feeser <hello@joshlf.com> | 2025-04-23 11:45:00 -0700 |
| commit | 4326a44e6f0859077b7789d42416b9291b0ff4d1 (patch) | |
| tree | 713ed8f802314f672e8d51725b158eb1a6c1dc9e /compiler/rustc_transmute/src/maybe_transmutable | |
| parent | be181dd75c83d72fcc95538e235768bc367b76b9 (diff) | |
| download | rust-4326a44e6f0859077b7789d42416b9291b0ff4d1.tar.gz rust-4326a44e6f0859077b7789d42416b9291b0ff4d1.zip | |
transmutability: Mark edges by ranges, not values
In the `Tree` and `Dfa` representations of a type's layout, store byte ranges rather than needing to separately store each byte value. This permits us to, for example, represent a `u8` using a single 0..=255 edge in the DFA rather than using 256 separate edges. This leads to drastic performance improvements. For example, on the author's 2024 MacBook Pro, the time to convert the `Tree` representation of a `u64` to its equivalent DFA representation drops from ~8.5ms to ~1us, a reduction of ~8,500x. See `bench_dfa_from_tree`. Similarly, the time to execute a transmutability query from `u64` to `u64` drops from ~35us to ~1.7us, a reduction of ~20x. See `bench_transmute`.
Diffstat (limited to 'compiler/rustc_transmute/src/maybe_transmutable')
3 files changed, 343 insertions, 72 deletions
diff --git a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs index db0e1ab8e98..0a19cccc2ed 100644 --- a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs +++ b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs @@ -1,10 +1,14 @@ +use std::rc::Rc; +use std::{cmp, iter}; + +use itertools::Either; use tracing::{debug, instrument, trace}; pub(crate) mod query_context; #[cfg(test)] mod tests; -use crate::layout::{self, Byte, Def, Dfa, Ref, Tree, Uninhabited, dfa}; +use crate::layout::{self, Byte, Def, Dfa, Ref, Tree, dfa}; use crate::maybe_transmutable::query_context::QueryContext; use crate::{Answer, Condition, Map, Reason}; @@ -111,7 +115,7 @@ where // the `src` type do not exist. let src = match Dfa::from_tree(src) { Ok(src) => src, - Err(Uninhabited) => return Answer::Yes, + Err(layout::Uninhabited) => return Answer::Yes, }; // Convert `dst` from a tree-based representation to an DFA-based @@ -122,7 +126,7 @@ where // free of safety invariants. let dst = match Dfa::from_tree(dst) { Ok(dst) => dst, - Err(Uninhabited) => return Answer::No(Reason::DstMayHaveSafetyInvariants), + Err(layout::Uninhabited) => return Answer::No(Reason::DstMayHaveSafetyInvariants), }; MaybeTransmutableQuery { src, dst, assume, context }.answer() @@ -174,8 +178,8 @@ where // are able to safely transmute, even with truncation. Answer::Yes } else if src_state == self.src.accept { - // extension: `size_of(Src) >= size_of(Dst)` - if let Some(dst_state_prime) = self.dst.byte_from(dst_state, Byte::Uninit) { + // extension: `size_of(Src) <= size_of(Dst)` + if let Some(dst_state_prime) = self.dst.get_uninit_edge_dst(dst_state) { self.answer_memo(cache, src_state, dst_state_prime) } else { Answer::No(Reason::DstIsTooBig) @@ -193,26 +197,120 @@ where Quantifier::ForAll }; + let c = &core::cell::RefCell::new(&mut *cache); let bytes_answer = src_quantifier.apply( - // for each of the byte transitions out of the `src_state`... - self.src.bytes_from(src_state).unwrap_or(&Map::default()).into_iter().map( - |(&src_validity, &src_state_prime)| { - // ...try to find a matching transition out of `dst_state`. - if let Some(dst_state_prime) = - self.dst.byte_from(dst_state, src_validity) - { - self.answer_memo(cache, src_state_prime, dst_state_prime) - } else if let Some(dst_state_prime) = - // otherwise, see if `dst_state` has any outgoing `Uninit` transitions - // (any init byte is a valid uninit byte) - self.dst.byte_from(dst_state, Byte::Uninit) - { - self.answer_memo(cache, src_state_prime, dst_state_prime) - } else { - // otherwise, we've exhausted our options. - // the DFAs, from this point onwards, are bit-incompatible. - Answer::No(Reason::DstIsBitIncompatible) + // for each of the byte set transitions out of the `src_state`... + self.src.bytes_from(src_state).flat_map( + move |(src_validity, src_state_prime)| { + // ...find all matching transitions out of `dst_state`. + + let Some(src_validity) = src_validity.range() else { + // NOTE: We construct an iterator here rather + // than just computing the value directly (via + // `self.answer_memo`) so that, if the iterator + // we produce from this branch is + // short-circuited, we don't waste time + // computing `self.answer_memo` unnecessarily. + // That will specifically happen if + // `src_quantifier == Quantifier::ThereExists`, + // since we emit `Answer::Yes` first (before + // chaining `answer_iter`). + let answer_iter = if let Some(dst_state_prime) = + self.dst.get_uninit_edge_dst(dst_state) + { + Either::Left(iter::once_with(move || { + let mut c = c.borrow_mut(); + self.answer_memo(&mut *c, src_state_prime, dst_state_prime) + })) + } else { + Either::Right(iter::once(Answer::No( + Reason::DstIsBitIncompatible, + ))) + }; + + // When `answer == Answer::No(...)`, there are + // two cases to consider: + // - If `assume.validity`, then we should + // succeed because the user is responsible for + // ensuring that the *specific* byte value + // appearing at runtime is valid for the + // destination type. When `assume.validity`, + // `src_quantifier == + // Quantifier::ThereExists`, so adding an + // `Answer::Yes` has the effect of ensuring + // that the "there exists" is always + // satisfied. + // - If `!assume.validity`, then we should fail. + // In this case, `src_quantifier == + // Quantifier::ForAll`, so adding an + // `Answer::Yes` has no effect. + return Either::Left(iter::once(Answer::Yes).chain(answer_iter)); + }; + + #[derive(Copy, Clone, Debug)] + struct Accum { + // The number of matching byte edges that we + // have found in the destination so far. + sum: usize, + found_uninit: bool, } + + let accum1 = Rc::new(std::cell::Cell::new(Accum { + sum: 0, + found_uninit: false, + })); + let accum2 = Rc::clone(&accum1); + let sv = src_validity.clone(); + let update_accum = move |mut accum: Accum, dst_validity: Byte| { + if let Some(dst_validity) = dst_validity.range() { + // Only add the part of `dst_validity` that + // overlaps with `src_validity`. + let start = cmp::max(*sv.start(), *dst_validity.start()); + let end = cmp::min(*sv.end(), *dst_validity.end()); + + // We add 1 here to account for the fact + // that `end` is an inclusive bound. + accum.sum += 1 + usize::from(end.saturating_sub(start)); + } else { + accum.found_uninit = true; + } + accum + }; + + let answers = self + .dst + .states_from(dst_state, src_validity.clone()) + .map(move |(dst_validity, dst_state_prime)| { + let mut c = c.borrow_mut(); + accum1.set(update_accum(accum1.get(), dst_validity)); + let answer = + self.answer_memo(&mut *c, src_state_prime, dst_state_prime); + answer + }) + .chain( + iter::once_with(move || { + let src_validity_len = usize::from(*src_validity.end()) + - usize::from(*src_validity.start()) + + 1; + let accum = accum2.get(); + + // If this condition is false, then + // there are some byte values in the + // source which have no corresponding + // transition in the destination DFA. In + // that case, we add a `No` to our list + // of answers. When + // `!self.assume.validity`, this will + // cause the query to fail. + if accum.found_uninit || accum.sum == src_validity_len { + None + } else { + Some(Answer::No(Reason::DstIsBitIncompatible)) + } + }) + .flatten(), + ); + Either::Right(answers) }, ), ); @@ -235,48 +333,38 @@ where let refs_answer = src_quantifier.apply( // for each reference transition out of `src_state`... - self.src.refs_from(src_state).unwrap_or(&Map::default()).into_iter().map( - |(&src_ref, &src_state_prime)| { - // ...there exists a reference transition out of `dst_state`... - Quantifier::ThereExists.apply( - self.dst - .refs_from(dst_state) - .unwrap_or(&Map::default()) - .into_iter() - .map(|(&dst_ref, &dst_state_prime)| { - if !src_ref.is_mutable() && dst_ref.is_mutable() { - Answer::No(Reason::DstIsMoreUnique) - } else if !self.assume.alignment - && src_ref.min_align() < dst_ref.min_align() - { - Answer::No(Reason::DstHasStricterAlignment { - src_min_align: src_ref.min_align(), - dst_min_align: dst_ref.min_align(), - }) - } else if dst_ref.size() > src_ref.size() { - Answer::No(Reason::DstRefIsTooBig { - src: src_ref, - dst: dst_ref, - }) - } else { - // ...such that `src` is transmutable into `dst`, if - // `src_ref` is transmutability into `dst_ref`. - and( - Answer::If(Condition::IfTransmutable { - src: src_ref, - dst: dst_ref, - }), - self.answer_memo( - cache, - src_state_prime, - dst_state_prime, - ), - ) - } - }), - ) - }, - ), + self.src.refs_from(src_state).map(|(src_ref, src_state_prime)| { + // ...there exists a reference transition out of `dst_state`... + Quantifier::ThereExists.apply(self.dst.refs_from(dst_state).map( + |(dst_ref, dst_state_prime)| { + if !src_ref.is_mutable() && dst_ref.is_mutable() { + Answer::No(Reason::DstIsMoreUnique) + } else if !self.assume.alignment + && src_ref.min_align() < dst_ref.min_align() + { + Answer::No(Reason::DstHasStricterAlignment { + src_min_align: src_ref.min_align(), + dst_min_align: dst_ref.min_align(), + }) + } else if dst_ref.size() > src_ref.size() { + Answer::No(Reason::DstRefIsTooBig { + src: src_ref, + dst: dst_ref, + }) + } else { + // ...such that `src` is transmutable into `dst`, if + // `src_ref` is transmutability into `dst_ref`. + and( + Answer::If(Condition::IfTransmutable { + src: src_ref, + dst: dst_ref, + }), + self.answer_memo(cache, src_state_prime, dst_state_prime), + ) + } + }, + )) + }), ); if self.assume.validity { diff --git a/compiler/rustc_transmute/src/maybe_transmutable/query_context.rs b/compiler/rustc_transmute/src/maybe_transmutable/query_context.rs index f8b59bdf326..214da101be3 100644 --- a/compiler/rustc_transmute/src/maybe_transmutable/query_context.rs +++ b/compiler/rustc_transmute/src/maybe_transmutable/query_context.rs @@ -8,9 +8,17 @@ pub(crate) trait QueryContext { #[cfg(test)] pub(crate) mod test { + use std::marker::PhantomData; + use super::QueryContext; - pub(crate) struct UltraMinimal; + pub(crate) struct UltraMinimal<R = !>(PhantomData<R>); + + impl<R> Default for UltraMinimal<R> { + fn default() -> Self { + Self(PhantomData) + } + } #[derive(Debug, Hash, Eq, PartialEq, Clone, Copy)] pub(crate) enum Def { @@ -24,9 +32,9 @@ pub(crate) mod test { } } - impl QueryContext for UltraMinimal { + impl<R: crate::layout::Ref> QueryContext for UltraMinimal<R> { type Def = Def; - type Ref = !; + type Ref = R; } } diff --git a/compiler/rustc_transmute/src/maybe_transmutable/tests.rs b/compiler/rustc_transmute/src/maybe_transmutable/tests.rs index cc6a4dce17b..24e2a1acadd 100644 --- a/compiler/rustc_transmute/src/maybe_transmutable/tests.rs +++ b/compiler/rustc_transmute/src/maybe_transmutable/tests.rs @@ -1,3 +1,5 @@ +extern crate test; + use itertools::Itertools; use super::query_context::test::{Def, UltraMinimal}; @@ -12,15 +14,25 @@ trait Representation { impl Representation for Tree { fn is_transmutable(src: Self, dst: Self, assume: Assume) -> Answer<!> { - crate::maybe_transmutable::MaybeTransmutableQuery::new(src, dst, assume, UltraMinimal) - .answer() + crate::maybe_transmutable::MaybeTransmutableQuery::new( + src, + dst, + assume, + UltraMinimal::default(), + ) + .answer() } } impl Representation for Dfa { fn is_transmutable(src: Self, dst: Self, assume: Assume) -> Answer<!> { - crate::maybe_transmutable::MaybeTransmutableQuery::new(src, dst, assume, UltraMinimal) - .answer() + crate::maybe_transmutable::MaybeTransmutableQuery::new( + src, + dst, + assume, + UltraMinimal::default(), + ) + .answer() } } @@ -89,6 +101,36 @@ mod safety { } } +mod size { + use super::*; + + #[test] + fn size() { + let small = Tree::number(1); + let large = Tree::number(2); + + for alignment in [false, true] { + for lifetimes in [false, true] { + for safety in [false, true] { + for validity in [false, true] { + let assume = Assume { alignment, lifetimes, safety, validity }; + assert_eq!( + is_transmutable(&small, &large, assume), + Answer::No(Reason::DstIsTooBig), + "assume: {assume:?}" + ); + assert_eq!( + is_transmutable(&large, &small, assume), + Answer::Yes, + "assume: {assume:?}" + ); + } + } + } + } + } +} + mod bool { use super::*; @@ -113,6 +155,27 @@ mod bool { } #[test] + fn transmute_u8() { + let bool = &Tree::bool(); + let u8 = &Tree::u8(); + for (src, dst, assume_validity, answer) in [ + (bool, u8, false, Answer::Yes), + (bool, u8, true, Answer::Yes), + (u8, bool, false, Answer::No(Reason::DstIsBitIncompatible)), + (u8, bool, true, Answer::Yes), + ] { + assert_eq!( + is_transmutable( + src, + dst, + Assume { validity: assume_validity, ..Assume::default() } + ), + answer + ); + } + } + + #[test] fn should_permit_validity_expansion_and_reject_contraction() { let b0 = layout::Tree::<Def, !>::from_bits(0); let b1 = layout::Tree::<Def, !>::from_bits(1); @@ -175,6 +238,62 @@ mod bool { } } +mod uninit { + use super::*; + + #[test] + fn size() { + let mu = Tree::uninit(); + let u8 = Tree::u8(); + + for alignment in [false, true] { + for lifetimes in [false, true] { + for safety in [false, true] { + for validity in [false, true] { + let assume = Assume { alignment, lifetimes, safety, validity }; + + let want = if validity { + Answer::Yes + } else { + Answer::No(Reason::DstIsBitIncompatible) + }; + + assert_eq!(is_transmutable(&mu, &u8, assume), want, "assume: {assume:?}"); + assert_eq!( + is_transmutable(&u8, &mu, assume), + Answer::Yes, + "assume: {assume:?}" + ); + } + } + } + } + } +} + +mod alt { + use super::*; + use crate::Answer; + + #[test] + fn should_permit_identity_transmutation() { + type Tree = layout::Tree<Def, !>; + + let x = Tree::Seq(vec![Tree::from_bits(0), Tree::from_bits(0)]); + let y = Tree::Seq(vec![Tree::bool(), Tree::from_bits(1)]); + let layout = Tree::Alt(vec![x, y]); + + let answer = crate::maybe_transmutable::MaybeTransmutableQuery::new( + layout.clone(), + layout.clone(), + crate::Assume::default(), + UltraMinimal::default(), + ) + .answer(); + assert_eq!(answer, Answer::Yes, "layout:{:#?}", layout); + } +} + mod union { use super::*; @@ -203,3 +322,59 @@ mod union { assert_eq!(is_transmutable(&t, &u, Assume::default()), Answer::Yes); } } + +mod r#ref { + use super::*; + + #[test] + fn should_permit_identity_transmutation() { + type Tree = crate::layout::Tree<Def, [(); 1]>; + + let layout = Tree::Seq(vec![Tree::from_bits(0), Tree::Ref([()])]); + + let answer = crate::maybe_transmutable::MaybeTransmutableQuery::new( + layout.clone(), + layout, + Assume::default(), + UltraMinimal::default(), + ) + .answer(); + assert_eq!(answer, Answer::If(crate::Condition::IfTransmutable { src: [()], dst: [()] })); + } +} + +mod benches { + use std::hint::black_box; + + use test::Bencher; + + use super::*; + + #[bench] + fn bench_dfa_from_tree(b: &mut Bencher) { + let num = Tree::number(8).prune(&|_| false); + let num = black_box(num); + + b.iter(|| { + let _ = black_box(Dfa::from_tree(num.clone())); + }) + } + + #[bench] + fn bench_transmute(b: &mut Bencher) { + let num = Tree::number(8).prune(&|_| false); + let dfa = black_box(Dfa::from_tree(num).unwrap()); + + b.iter(|| { + let answer = crate::maybe_transmutable::MaybeTransmutableQuery::new( + dfa.clone(), + dfa.clone(), + Assume::default(), + UltraMinimal::default(), + ) + .answer(); + let answer = std::hint::black_box(answer); + assert_eq!(answer, Answer::Yes); + }) + } +} |
