diff options
| author | Bastian Kersting <bkersting@google.com> | 2024-12-16 14:42:49 +0000 |
|---|---|---|
| committer | Bastian Kersting <bkersting@google.com> | 2025-01-30 10:08:37 +0000 |
| commit | 851322b74db9ac91a1b9d206c5f80fc51a97f7c1 (patch) | |
| tree | aabf2e13e79c3531464352c606558c30d44ee708 | |
| parent | e6f12c8b7d8d5c821c32fb2739ea01d1d874c58a (diff) | |
| download | rust-851322b74db9ac91a1b9d206c5f80fc51a97f7c1.tar.gz rust-851322b74db9ac91a1b9d206c5f80fc51a97f7c1.zip | |
Refactor `PointerFinder` into a separate module
This also parameterize the "excluded pointee types" and exposes a general method for inserting checks on pointers. This is a preparation for adding a NullCheck that makes use of the same code.
| -rw-r--r-- | compiler/rustc_mir_transform/src/check_alignment.rs | 198 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/check_pointers.rs | 233 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/lib.rs | 1 |
3 files changed, 272 insertions, 160 deletions
diff --git a/compiler/rustc_mir_transform/src/check_alignment.rs b/compiler/rustc_mir_transform/src/check_alignment.rs index d7e22c12394..ca5564e447a 100644 --- a/compiler/rustc_mir_transform/src/check_alignment.rs +++ b/compiler/rustc_mir_transform/src/check_alignment.rs @@ -1,11 +1,10 @@ -use rustc_hir::lang_items::LangItem; use rustc_index::IndexVec; use rustc_middle::mir::interpret::Scalar; -use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor}; use rustc_middle::mir::*; -use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_middle::ty::{Ty, TyCtxt}; use rustc_session::Session; -use tracing::{debug, trace}; + +use crate::check_pointers::{BorrowCheckMode, PointerCheck, check_pointers}; pub(super) struct CheckAlignment; @@ -19,46 +18,19 @@ impl<'tcx> crate::MirPass<'tcx> for CheckAlignment { } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - // This pass emits new panics. If for whatever reason we do not have a panic - // implementation, running this pass may cause otherwise-valid code to not compile. - if tcx.lang_items().get(LangItem::PanicImpl).is_none() { - return; - } - - let typing_env = body.typing_env(tcx); - let basic_blocks = body.basic_blocks.as_mut(); - let local_decls = &mut body.local_decls; - - // This pass inserts new blocks. Each insertion changes the Location for all - // statements/blocks after. Iterating or visiting the MIR in order would require updating - // our current location after every insertion. By iterating backwards, we dodge this issue: - // The only Locations that an insertion changes have already been handled. - for block in (0..basic_blocks.len()).rev() { - let block = block.into(); - for statement_index in (0..basic_blocks[block].statements.len()).rev() { - let location = Location { block, statement_index }; - let statement = &basic_blocks[block].statements[statement_index]; - let source_info = statement.source_info; - - let mut finder = - PointerFinder { tcx, local_decls, typing_env, pointers: Vec::new() }; - finder.visit_statement(statement, location); - - for (local, ty) in finder.pointers { - debug!("Inserting alignment check for {:?}", ty); - let new_block = split_block(basic_blocks, location); - insert_alignment_check( - tcx, - local_decls, - &mut basic_blocks[block], - local, - ty, - source_info, - new_block, - ); - } - } - } + // Skip trivially aligned place types. + let excluded_pointees = [tcx.types.bool, tcx.types.i8, tcx.types.u8]; + + // We have to exclude borrows here: in `&x.field`, the exact + // requirement is that the final reference must be aligned, but + // `check_pointers` would check that `x` is aligned, which would be wrong. + check_pointers( + tcx, + body, + &excluded_pointees, + insert_alignment_check, + BorrowCheckMode::ExcludeBorrows, + ); } fn is_required(&self) -> bool { @@ -66,119 +38,33 @@ impl<'tcx> crate::MirPass<'tcx> for CheckAlignment { } } -struct PointerFinder<'a, 'tcx> { - tcx: TyCtxt<'tcx>, - local_decls: &'a mut LocalDecls<'tcx>, - typing_env: ty::TypingEnv<'tcx>, - pointers: Vec<(Place<'tcx>, Ty<'tcx>)>, -} - -impl<'a, 'tcx> Visitor<'tcx> for PointerFinder<'a, 'tcx> { - fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) { - // We want to only check reads and writes to Places, so we specifically exclude - // Borrow and RawBorrow. - match context { - PlaceContext::MutatingUse( - MutatingUseContext::Store - | MutatingUseContext::AsmOutput - | MutatingUseContext::Call - | MutatingUseContext::Yield - | MutatingUseContext::Drop, - ) => {} - PlaceContext::NonMutatingUse( - NonMutatingUseContext::Copy | NonMutatingUseContext::Move, - ) => {} - _ => { - return; - } - } - - if !place.is_indirect() { - return; - } - - // Since Deref projections must come first and only once, the pointer for an indirect place - // is the Local that the Place is based on. - let pointer = Place::from(place.local); - let pointer_ty = self.local_decls[place.local].ty; - - // We only want to check places based on unsafe pointers - if !pointer_ty.is_unsafe_ptr() { - trace!("Indirect, but not based on an unsafe ptr, not checking {:?}", place); - return; - } - - let pointee_ty = - pointer_ty.builtin_deref(true).expect("no builtin_deref for an unsafe pointer"); - // Ideally we'd support this in the future, but for now we are limited to sized types. - if !pointee_ty.is_sized(self.tcx, self.typing_env) { - debug!("Unsafe pointer, but pointee is not known to be sized: {:?}", pointer_ty); - return; - } - - // Try to detect types we are sure have an alignment of 1 and skip the check - // We don't need to look for str and slices, we already rejected unsized types above - let element_ty = match pointee_ty.kind() { - ty::Array(ty, _) => *ty, - _ => pointee_ty, - }; - if [self.tcx.types.bool, self.tcx.types.i8, self.tcx.types.u8].contains(&element_ty) { - debug!("Trivially aligned place type: {:?}", pointee_ty); - return; - } - - // Ensure that this place is based on an aligned pointer. - self.pointers.push((pointer, pointee_ty)); - - self.super_place(place, context, location); - } -} - -fn split_block( - basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>, - location: Location, -) -> BasicBlock { - let block_data = &mut basic_blocks[location.block]; - - // Drain every statement after this one and move the current terminator to a new basic block - let new_block = BasicBlockData { - statements: block_data.statements.split_off(location.statement_index), - terminator: block_data.terminator.take(), - is_cleanup: block_data.is_cleanup, - }; - - basic_blocks.push(new_block) -} - +/// Inserts the actual alignment check's logic. Returns a +/// [AssertKind::MisalignedPointerDereference] on failure. fn insert_alignment_check<'tcx>( tcx: TyCtxt<'tcx>, - local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, - block_data: &mut BasicBlockData<'tcx>, pointer: Place<'tcx>, pointee_ty: Ty<'tcx>, + local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, + stmts: &mut Vec<Statement<'tcx>>, source_info: SourceInfo, - new_block: BasicBlock, -) { - // Cast the pointer to a *const () +) -> PointerCheck<'tcx> { + // Cast the pointer to a *const (). let const_raw_ptr = Ty::new_imm_ptr(tcx, tcx.types.unit); let rvalue = Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(pointer), const_raw_ptr); let thin_ptr = local_decls.push(LocalDecl::with_source_info(const_raw_ptr, source_info)).into(); - block_data - .statements + stmts .push(Statement { source_info, kind: StatementKind::Assign(Box::new((thin_ptr, rvalue))) }); - // Transmute the pointer to a usize (equivalent to `ptr.addr()`) + // Transmute the pointer to a usize (equivalent to `ptr.addr()`). let rvalue = Rvalue::Cast(CastKind::Transmute, Operand::Copy(thin_ptr), tcx.types.usize); let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); - block_data - .statements - .push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) }); + stmts.push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) }); // Get the alignment of the pointee let alignment = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); let rvalue = Rvalue::NullaryOp(NullOp::AlignOf, pointee_ty); - block_data.statements.push(Statement { + stmts.push(Statement { source_info, kind: StatementKind::Assign(Box::new((alignment, rvalue))), }); @@ -191,7 +77,7 @@ fn insert_alignment_check<'tcx>( user_ty: None, const_: Const::Val(ConstValue::Scalar(Scalar::from_target_usize(1, &tcx)), tcx.types.usize), })); - block_data.statements.push(Statement { + stmts.push(Statement { source_info, kind: StatementKind::Assign(Box::new(( alignment_mask, @@ -202,7 +88,7 @@ fn insert_alignment_check<'tcx>( // BitAnd the alignment mask with the pointer let alignment_bits = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); - block_data.statements.push(Statement { + stmts.push(Statement { source_info, kind: StatementKind::Assign(Box::new(( alignment_bits, @@ -220,7 +106,7 @@ fn insert_alignment_check<'tcx>( user_ty: None, const_: Const::Val(ConstValue::Scalar(Scalar::from_target_usize(0, &tcx)), tcx.types.usize), })); - block_data.statements.push(Statement { + stmts.push(Statement { source_info, kind: StatementKind::Assign(Box::new(( is_ok, @@ -228,21 +114,13 @@ fn insert_alignment_check<'tcx>( ))), }); - // Set this block's terminator to our assert, continuing to new_block if we pass - block_data.terminator = Some(Terminator { - source_info, - kind: TerminatorKind::Assert { - cond: Operand::Copy(is_ok), - expected: true, - target: new_block, - msg: Box::new(AssertKind::MisalignedPointerDereference { - required: Operand::Copy(alignment), - found: Operand::Copy(addr), - }), - // This calls panic_misaligned_pointer_dereference, which is #[rustc_nounwind]. - // We never want to insert an unwind into unsafe code, because unwinding could - // make a failing UB check turn into much worse UB when we start unwinding. - unwind: UnwindAction::Unreachable, - }, - }); + // Emit a check that asserts on the alignment and otherwise triggers a + // AssertKind::MisalignedPointerDereference. + PointerCheck { + cond: Operand::Copy(is_ok), + assert_kind: Box::new(AssertKind::MisalignedPointerDereference { + required: Operand::Copy(alignment), + found: Operand::Copy(addr), + }), + } } diff --git a/compiler/rustc_mir_transform/src/check_pointers.rs b/compiler/rustc_mir_transform/src/check_pointers.rs new file mode 100644 index 00000000000..95d3ce05000 --- /dev/null +++ b/compiler/rustc_mir_transform/src/check_pointers.rs @@ -0,0 +1,233 @@ +use rustc_hir::lang_items::LangItem; +use rustc_index::IndexVec; +use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor}; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, Ty, TyCtxt}; +use tracing::{debug, trace}; + +/// Details of a pointer check, the condition on which we decide whether to +/// fail the assert and an [AssertKind] that defines the behavior on failure. +pub(crate) struct PointerCheck<'tcx> { + pub(crate) cond: Operand<'tcx>, + pub(crate) assert_kind: Box<AssertKind<Operand<'tcx>>>, +} + +/// Indicates whether we insert the checks for borrow places of a raw pointer. +/// Concretely places with [MutatingUseContext::Borrow] or +/// [NonMutatingUseContext::SharedBorrow]. +#[derive(Copy, Clone)] +pub(crate) enum BorrowCheckMode { + ExcludeBorrows, +} + +/// Utility for adding a check for read/write on every sized, raw pointer. +/// +/// Visits every read/write access to a [Sized], raw pointer and inserts a +/// new basic block directly before the pointer access. (Read/write accesses +/// are determined by the `PlaceContext` of the MIR visitor.) Then calls +/// `on_finding` to insert the actual logic for a pointer check (e.g. check for +/// alignment). A check can choose to be inserted for (mutable) borrows of +/// raw pointers via the `borrow_check_mode` parameter. +/// +/// This utility takes care of the right order of blocks, the only thing a +/// caller must do in `on_finding` is: +/// - Append [Statement]s to `stmts`. +/// - Append [LocalDecl]s to `local_decls`. +/// - Return a [PointerCheck] that contains the condition and an [AssertKind]. +/// The AssertKind must be a panic with `#[rustc_nounwind]`. The condition +/// should always return the boolean `is_ok`, so evaluate to true in case of +/// success and fail the check otherwise. +/// This utility will insert a terminator block that asserts on the condition +/// and panics on failure. +pub(crate) fn check_pointers<'a, 'tcx, F>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + excluded_pointees: &'a [Ty<'tcx>], + on_finding: F, + borrow_check_mode: BorrowCheckMode, +) where + F: Fn( + /* tcx: */ TyCtxt<'tcx>, + /* pointer: */ Place<'tcx>, + /* pointee_ty: */ Ty<'tcx>, + /* local_decls: */ &mut IndexVec<Local, LocalDecl<'tcx>>, + /* stmts: */ &mut Vec<Statement<'tcx>>, + /* source_info: */ SourceInfo, + ) -> PointerCheck<'tcx>, +{ + // This pass emits new panics. If for whatever reason we do not have a panic + // implementation, running this pass may cause otherwise-valid code to not compile. + if tcx.lang_items().get(LangItem::PanicImpl).is_none() { + return; + } + + let typing_env = body.typing_env(tcx); + let basic_blocks = body.basic_blocks.as_mut(); + let local_decls = &mut body.local_decls; + + // This operation inserts new blocks. Each insertion changes the Location for all + // statements/blocks after. Iterating or visiting the MIR in order would require updating + // our current location after every insertion. By iterating backwards, we dodge this issue: + // The only Locations that an insertion changes have already been handled. + for block in (0..basic_blocks.len()).rev() { + let block = block.into(); + for statement_index in (0..basic_blocks[block].statements.len()).rev() { + let location = Location { block, statement_index }; + let statement = &basic_blocks[block].statements[statement_index]; + let source_info = statement.source_info; + + let mut finder = PointerFinder::new( + tcx, + local_decls, + typing_env, + excluded_pointees, + borrow_check_mode, + ); + finder.visit_statement(statement, location); + + for (local, ty) in finder.into_found_pointers() { + debug!("Inserting check for {:?}", ty); + let new_block = split_block(basic_blocks, location); + + // Invoke `on_finding` which appends to `local_decls` and the + // blocks statements. It returns information about the assert + // we're performing in the Terminator. + let block_data = &mut basic_blocks[block]; + let pointer_check = on_finding( + tcx, + local, + ty, + local_decls, + &mut block_data.statements, + source_info, + ); + block_data.terminator = Some(Terminator { + source_info, + kind: TerminatorKind::Assert { + cond: pointer_check.cond, + expected: true, + target: new_block, + msg: pointer_check.assert_kind, + // This calls a panic function associated with the pointer check, which + // is #[rustc_nounwind]. We never want to insert an unwind into unsafe + // code, because unwinding could make a failing UB check turn into much + // worse UB when we start unwinding. + unwind: UnwindAction::Unreachable, + }, + }); + } + } + } +} + +struct PointerFinder<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + local_decls: &'a mut LocalDecls<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + pointers: Vec<(Place<'tcx>, Ty<'tcx>)>, + excluded_pointees: &'a [Ty<'tcx>], + borrow_check_mode: BorrowCheckMode, +} + +impl<'a, 'tcx> PointerFinder<'a, 'tcx> { + fn new( + tcx: TyCtxt<'tcx>, + local_decls: &'a mut LocalDecls<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + excluded_pointees: &'a [Ty<'tcx>], + borrow_check_mode: BorrowCheckMode, + ) -> Self { + PointerFinder { + tcx, + local_decls, + typing_env, + excluded_pointees, + pointers: Vec::new(), + borrow_check_mode, + } + } + + fn into_found_pointers(self) -> Vec<(Place<'tcx>, Ty<'tcx>)> { + self.pointers + } + + /// Whether or not we should visit a [Place] with [PlaceContext]. + /// + /// We generally only visit Reads/Writes to a place and only Borrows if + /// requested. + fn should_visit_place(&self, context: PlaceContext) -> bool { + match context { + PlaceContext::MutatingUse( + MutatingUseContext::Store + | MutatingUseContext::Call + | MutatingUseContext::Yield + | MutatingUseContext::Drop, + ) => true, + PlaceContext::NonMutatingUse( + NonMutatingUseContext::Copy | NonMutatingUseContext::Move, + ) => true, + PlaceContext::MutatingUse(MutatingUseContext::Borrow) + | PlaceContext::NonMutatingUse(NonMutatingUseContext::SharedBorrow) => { + !matches!(self.borrow_check_mode, BorrowCheckMode::ExcludeBorrows) + } + _ => false, + } + } +} + +impl<'a, 'tcx> Visitor<'tcx> for PointerFinder<'a, 'tcx> { + fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) { + if !self.should_visit_place(context) || !place.is_indirect() { + return; + } + + // Since Deref projections must come first and only once, the pointer for an indirect place + // is the Local that the Place is based on. + let pointer = Place::from(place.local); + let pointer_ty = self.local_decls[place.local].ty; + + // We only want to check places based on raw pointers + if !pointer_ty.is_unsafe_ptr() { + trace!("Indirect, but not based on an raw ptr, not checking {:?}", place); + return; + } + + let pointee_ty = + pointer_ty.builtin_deref(true).expect("no builtin_deref for an raw pointer"); + // Ideally we'd support this in the future, but for now we are limited to sized types. + if !pointee_ty.is_sized(self.tcx, self.typing_env) { + trace!("Raw pointer, but pointee is not known to be sized: {:?}", pointer_ty); + return; + } + + // We don't need to look for slices, we already rejected unsized types above. + let element_ty = match pointee_ty.kind() { + ty::Array(ty, _) => *ty, + _ => pointee_ty, + }; + if self.excluded_pointees.contains(&element_ty) { + trace!("Skipping pointer for type: {:?}", pointee_ty); + return; + } + + self.pointers.push((pointer, pointee_ty)); + + self.super_place(place, context, location); + } +} + +fn split_block( + basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>, + location: Location, +) -> BasicBlock { + let block_data = &mut basic_blocks[location.block]; + + // Drain every statement after this one and move the current terminator to a new basic block. + let new_block = BasicBlockData { + statements: block_data.statements.split_off(location.statement_index), + terminator: block_data.terminator.take(), + is_cleanup: block_data.is_cleanup, + }; + + basic_blocks.push(new_block) +} diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 2dc55e3614e..0df9ce0c3d3 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -44,6 +44,7 @@ use std::sync::LazyLock; use pass_manager::{self as pm, Lint, MirLint, MirPass, WithMinOptLevel}; +mod check_pointers; mod cost_checker; mod cross_crate_inline; mod deduce_param_attrs; |
