diff options
| author | Bryan Garza <1396101+bryangarza@users.noreply.github.com> | 2023-04-21 16:49:36 -0700 | 
|---|---|---|
| committer | Bryan Garza <1396101+bryangarza@users.noreply.github.com> | 2023-05-24 14:52:18 -0700 | 
| commit | 8f1cec8d8472c3ffacedd4783c64182a407c72df (patch) | |
| tree | 79f6bc12055896112e0b0898c13042411638847d /compiler/rustc_transmute/src | |
| parent | 97d328012b9ed9b7d481c40e84aa1f2c65b33ec8 (diff) | |
| download | rust-8f1cec8d8472c3ffacedd4783c64182a407c72df.tar.gz rust-8f1cec8d8472c3ffacedd4783c64182a407c72df.zip | |
Safe Transmute: Enable handling references, including recursive types
This patch enables support for references in Safe Transmute, by generating nested obligations during trait selection. Specifically, when we call `confirm_transmutability_candidate(...)`, we now recursively traverse the `rustc_transmute::Answer` tree and create obligations for all the `Answer` variants, some of which include multiple nested `Answer`s. Also, to handle recursive types, enable support for coinduction for the Safe Transmute trait (`BikeshedIntrinsicFrom`) by adding the `#[rustc_coinduction]` annotation. Also fix some small logic issues when reducing the `or` and `and` combinations in `rustc_transmute`, so that we don't end up with additional redundant `Answer`s in the tree. Co-authored-by: Jack Wrenn <jack@wrenn.fyi>
Diffstat (limited to 'compiler/rustc_transmute/src')
| -rw-r--r-- | compiler/rustc_transmute/src/layout/mod.rs | 37 | ||||
| -rw-r--r-- | compiler/rustc_transmute/src/layout/tree.rs | 11 | ||||
| -rw-r--r-- | compiler/rustc_transmute/src/lib.rs | 11 | ||||
| -rw-r--r-- | compiler/rustc_transmute/src/maybe_transmutable/mod.rs | 170 | 
4 files changed, 178 insertions, 51 deletions
| diff --git a/compiler/rustc_transmute/src/layout/mod.rs b/compiler/rustc_transmute/src/layout/mod.rs index f8d05bc89d2..b318447e581 100644 --- a/compiler/rustc_transmute/src/layout/mod.rs +++ b/compiler/rustc_transmute/src/layout/mod.rs @@ -30,33 +30,46 @@ impl fmt::Debug for Byte { } pub(crate) trait Def: Debug + Hash + Eq + PartialEq + Copy + Clone {} -pub trait Ref: Debug + Hash + Eq + PartialEq + Copy + Clone {} +pub trait Ref: Debug + Hash + Eq + PartialEq + Copy + Clone { + fn min_align(&self) -> usize { + 1 + } + + fn is_mutable(&self) -> bool { + false + } +} impl Def for ! {} impl Ref for ! {} #[cfg(feature = "rustc")] -pub(crate) mod rustc { +pub mod rustc { use rustc_middle::mir::Mutability; - use rustc_middle::ty; - use rustc_middle::ty::Region; - use rustc_middle::ty::Ty; + use rustc_middle::ty::{self, Ty}; /// A reference in the layout. #[derive(Debug, Hash, Eq, PartialEq, PartialOrd, Ord, Clone, Copy)] pub struct Ref<'tcx> { - lifetime: Region<'tcx>, - ty: Ty<'tcx>, - mutability: Mutability, + pub lifetime: ty::Region<'tcx>, + pub ty: Ty<'tcx>, + pub mutability: Mutability, + pub align: usize, } - impl<'tcx> super::Ref for Ref<'tcx> {} + impl<'tcx> super::Ref for Ref<'tcx> { + fn min_align(&self) -> usize { + self.align + } - impl<'tcx> Ref<'tcx> { - pub fn min_align(&self) -> usize { - todo!() + fn is_mutable(&self) -> bool { + match self.mutability { + Mutability::Mut => true, + Mutability::Not => false, + } } } + impl<'tcx> Ref<'tcx> {} /// A visibility node in the layout. #[derive(Debug, Hash, Eq, PartialEq, Clone, Copy)] diff --git a/compiler/rustc_transmute/src/layout/tree.rs b/compiler/rustc_transmute/src/layout/tree.rs index a6d88b1342a..ed9309b015d 100644 --- a/compiler/rustc_transmute/src/layout/tree.rs +++ b/compiler/rustc_transmute/src/layout/tree.rs @@ -365,6 +365,17 @@ pub(crate) mod rustc { } })) } + + ty::Ref(lifetime, ty, mutability) => { + let align = layout_of(tcx, *ty)?.align(); + Ok(Tree::Ref(Ref { + lifetime: *lifetime, + ty: *ty, + mutability: *mutability, + align, + })) + } + _ => Err(Err::Unspecified), } } diff --git a/compiler/rustc_transmute/src/lib.rs b/compiler/rustc_transmute/src/lib.rs index 77c0526e3aa..c4a99d9eb89 100644 --- a/compiler/rustc_transmute/src/lib.rs +++ b/compiler/rustc_transmute/src/lib.rs @@ -8,7 +8,7 @@ extern crate tracing; pub(crate) use rustc_data_structures::fx::{FxIndexMap as Map, FxIndexSet as Set}; -pub(crate) mod layout; +pub mod layout; pub(crate) mod maybe_transmutable; #[derive(Default)] @@ -21,10 +21,7 @@ pub struct Assume { /// The type encodes answers to the question: "Are these types transmutable?" #[derive(Debug, Hash, Eq, PartialEq, PartialOrd, Ord, Clone)] -pub enum Answer<R> -where - R: layout::Ref, -{ +pub enum Answer<R> { /// `Src` is transmutable into `Dst`. Yes, @@ -54,6 +51,10 @@ pub enum Reason { DstIsPrivate, /// `Dst` is larger than `Src`, and the excess bytes were not exclusively uninitialized. DstIsTooBig, + /// Src should have a stricter alignment than Dst, but it does not. + DstHasStricterAlignment, + /// Can't go from shared pointer to unique pointer + DstIsMoreUnique, } #[cfg(feature = "rustc")] diff --git a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs index 2e2fb90e71c..d1077488c79 100644 --- a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs +++ b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs @@ -1,13 +1,13 @@ -use crate::Map; -use crate::{Answer, Reason}; - +pub(crate) mod query_context; #[cfg(test)] mod tests; -mod query_context; -use query_context::QueryContext; +use crate::{ + layout::{self, dfa, Byte, Dfa, Nfa, Ref, Tree, Uninhabited}, + maybe_transmutable::query_context::QueryContext, + Answer, Map, Reason, +}; -use crate::layout::{self, dfa, Byte, Dfa, Nfa, Tree, Uninhabited}; pub(crate) struct MaybeTransmutableQuery<L, C> where C: QueryContext, @@ -53,6 +53,7 @@ where } } +// FIXME: Nix this cfg, so we can write unit tests independently of rustc #[cfg(feature = "rustc")] mod rustc { use super::*; @@ -77,12 +78,11 @@ mod rustc { match (src, dst) { // Answer `Yes` here, because 'unknown layout' and type errors will already // be reported by rustc. No need to spam the user with more errors. - (Err(Err::TypeError(_)), _) => Err(Answer::Yes), - (_, Err(Err::TypeError(_))) => Err(Answer::Yes), - (Err(Err::Unknown), _) => Err(Answer::Yes), - (_, Err(Err::Unknown)) => Err(Answer::Yes), - (Err(Err::Unspecified), _) => Err(Answer::No(Reason::SrcIsUnspecified)), - (_, Err(Err::Unspecified)) => Err(Answer::No(Reason::DstIsUnspecified)), + (Err(Err::TypeError(_)), _) | (_, Err(Err::TypeError(_))) => Err(Answer::Yes), + (Err(Err::Unknown), _) | (_, Err(Err::Unknown)) => Err(Answer::Yes), + (Err(Err::Unspecified), _) | (_, Err(Err::Unspecified)) => { + Err(Answer::No(Reason::SrcIsUnspecified)) + } (Ok(src), Ok(dst)) => Ok((src, dst)), } }); @@ -214,34 +214,99 @@ where Answer::No(Reason::DstIsTooBig) } } else { - let src_quantification = if self.assume.validity { + let src_quantifier = if self.assume.validity { // if the compiler may assume that the programmer is doing additional validity checks, // (e.g.: that `src != 3u8` when the destination type is `bool`) // then there must exist at least one transition out of `src_state` such that the transmute is viable... - there_exists + Quantifier::ThereExists } else { // if the compiler cannot assume that the programmer is doing additional validity checks, // then for all transitions out of `src_state`, such that the transmute is viable... - // then there must exist at least one transition out of `src_state` such that the transmute is viable... - for_all + // then there must exist at least one transition out of `dst_state` such that the transmute is viable... + Quantifier::ForAll }; - src_quantification( - self.src.bytes_from(src_state).unwrap_or(&Map::default()), - |(&src_validity, &src_state_prime)| { - if let Some(dst_state_prime) = self.dst.byte_from(dst_state, src_validity) { - self.answer_memo(cache, src_state_prime, dst_state_prime) - } else if let Some(dst_state_prime) = - self.dst.byte_from(dst_state, Byte::Uninit) - { - self.answer_memo(cache, src_state_prime, dst_state_prime) - } else { - Answer::No(Reason::DstIsBitIncompatible) - } - }, - ) + let bytes_answer = src_quantifier.apply( + // for each of the byte transitions out of the `src_state`... + self.src.bytes_from(src_state).unwrap_or(&Map::default()).into_iter().map( + |(&src_validity, &src_state_prime)| { + // ...try to find a matching transition out of `dst_state`. + if let Some(dst_state_prime) = + self.dst.byte_from(dst_state, src_validity) + { + self.answer_memo(cache, src_state_prime, dst_state_prime) + } else if let Some(dst_state_prime) = + // otherwise, see if `dst_state` has any outgoing `Uninit` transitions + // (any init byte is a valid uninit byte) + self.dst.byte_from(dst_state, Byte::Uninit) + { + self.answer_memo(cache, src_state_prime, dst_state_prime) + } else { + // otherwise, we've exhausted our options. + // the DFAs, from this point onwards, are bit-incompatible. + Answer::No(Reason::DstIsBitIncompatible) + } + }, + ), + ); + + // The below early returns reflect how this code would behave: + // if self.assume.validity { + // bytes_answer.or(refs_answer) + // } else { + // bytes_answer.and(refs_answer) + // } + // ...if `refs_answer` was computed lazily. The below early + // returns can be deleted without impacting the correctness of + // the algoritm; only its performance. + match bytes_answer { + Answer::No(..) if !self.assume.validity => return bytes_answer, + Answer::Yes if self.assume.validity => return bytes_answer, + _ => {} + }; + + let refs_answer = src_quantifier.apply( + // for each reference transition out of `src_state`... + self.src.refs_from(src_state).unwrap_or(&Map::default()).into_iter().map( + |(&src_ref, &src_state_prime)| { + // ...there exists a reference transition out of `dst_state`... + Quantifier::ThereExists.apply( + self.dst + .refs_from(dst_state) + .unwrap_or(&Map::default()) + .into_iter() + .map(|(&dst_ref, &dst_state_prime)| { + if !src_ref.is_mutable() && dst_ref.is_mutable() { + Answer::No(Reason::DstIsMoreUnique) + } else if !self.assume.alignment + && src_ref.min_align() < dst_ref.min_align() + { + Answer::No(Reason::DstHasStricterAlignment) + } else { + // ...such that `src` is transmutable into `dst`, if + // `src_ref` is transmutability into `dst_ref`. + Answer::IfTransmutable { src: src_ref, dst: dst_ref } + .and(self.answer_memo( + cache, + src_state_prime, + dst_state_prime, + )) + } + }), + ) + }, + ), + ); + + if self.assume.validity { + bytes_answer.or(refs_answer) + } else { + bytes_answer.and(refs_answer) + } }; - cache.insert((src_state, dst_state), answer.clone()); + if let Some(..) = cache.insert((src_state, dst_state), answer.clone()) { + panic!("failed to correctly cache transmutability") + } answer } } @@ -253,17 +318,21 @@ where { pub(crate) fn and(self, rhs: Self) -> Self { match (self, rhs) { - (Self::No(reason), _) | (_, Self::No(reason)) => Self::No(reason), - (Self::Yes, Self::Yes) => Self::Yes, + (_, Self::No(reason)) | (Self::No(reason), _) => Self::No(reason), + + (Self::Yes, other) | (other, Self::Yes) => other, + (Self::IfAll(mut lhs), Self::IfAll(ref mut rhs)) => { lhs.append(rhs); Self::IfAll(lhs) } + (constraint, Self::IfAll(mut constraints)) | (Self::IfAll(mut constraints), constraint) => { constraints.push(constraint); Self::IfAll(constraints) } + (lhs, rhs) => Self::IfAll(vec![lhs, rhs]), } } @@ -271,7 +340,7 @@ where pub(crate) fn or(self, rhs: Self) -> Self { match (self, rhs) { (Self::Yes, _) | (_, Self::Yes) => Self::Yes, - (Self::No(lhr), Self::No(rhr)) => Self::No(lhr), + (other, Self::No(reason)) | (Self::No(reason), other) => other, (Self::IfAny(mut lhs), Self::IfAny(ref mut rhs)) => { lhs.append(rhs); Self::IfAny(lhs) @@ -319,3 +388,36 @@ where ); result } + +pub enum Quantifier { + ThereExists, + ForAll, +} + +impl Quantifier { + pub fn apply<R, I>(&self, iter: I) -> Answer<R> + where + R: layout::Ref, + I: IntoIterator<Item = Answer<R>>, + { + use std::ops::ControlFlow::{Break, Continue}; + + let (init, try_fold_f): (_, fn(_, _) -> _) = match self { + Self::ThereExists => { + (Answer::No(Reason::DstIsBitIncompatible), |accum: Answer<R>, next| { + match accum.or(next) { + Answer::Yes => Break(Answer::Yes), + maybe => Continue(maybe), + } + }) + } + Self::ForAll => (Answer::Yes, |accum: Answer<R>, next| match accum.and(next) { + Answer::No(reason) => Break(Answer::No(reason)), + maybe => Continue(maybe), + }), + }; + + let (Continue(result) | Break(result)) = iter.into_iter().try_fold(init, try_fold_f); + result + } +} | 
