about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src/check_alignment.rs
diff options
context:
space:
mode:
authorBen Kimock <kimockb@gmail.com>2022-11-10 11:37:28 -0500
committerBen Kimock <kimockb@gmail.com>2023-03-23 18:23:06 -0400
commit8ccf53332e2ab70fa4efed5716ddcbb61e98dac2 (patch)
tree632b4dfa8f9fcdc21d09b8fa93312d0f6a306b00 /compiler/rustc_mir_transform/src/check_alignment.rs
parente2163008763c326ec4003e07b8e6eef0c98f6204 (diff)
downloadrust-8ccf53332e2ab70fa4efed5716ddcbb61e98dac2.tar.gz
rust-8ccf53332e2ab70fa4efed5716ddcbb61e98dac2.zip
A MIR transform that checks pointers are aligned
Diffstat (limited to 'compiler/rustc_mir_transform/src/check_alignment.rs')
-rw-r--r--compiler/rustc_mir_transform/src/check_alignment.rs220
1 files changed, 220 insertions, 0 deletions
diff --git a/compiler/rustc_mir_transform/src/check_alignment.rs b/compiler/rustc_mir_transform/src/check_alignment.rs
new file mode 100644
index 00000000000..d654c973e02
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/check_alignment.rs
@@ -0,0 +1,220 @@
+use crate::MirPass;
+use rustc_hir::def_id::DefId;
+use rustc_index::vec::IndexVec;
+use rustc_middle::mir::*;
+use rustc_middle::mir::{
+    interpret::{ConstValue, Scalar},
+    visit::{PlaceContext, Visitor},
+};
+use rustc_middle::ty::{Ty, TyCtxt, TypeAndMut};
+use rustc_session::Session;
+
+pub struct CheckAlignment;
+
+impl<'tcx> MirPass<'tcx> for CheckAlignment {
+    fn is_enabled(&self, sess: &Session) -> bool {
+        sess.opts.debug_assertions
+    }
+
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        let basic_blocks = body.basic_blocks.as_mut();
+        let local_decls = &mut body.local_decls;
+
+        for block in (0..basic_blocks.len()).rev() {
+            let block = block.into();
+            for statement_index in (0..basic_blocks[block].statements.len()).rev() {
+                let location = Location { block, statement_index };
+                let statement = &basic_blocks[block].statements[statement_index];
+                let source_info = statement.source_info;
+
+                let mut finder = PointerFinder {
+                    local_decls,
+                    tcx,
+                    pointers: Vec::new(),
+                    def_id: body.source.def_id(),
+                };
+                for (pointer, pointee_ty) in finder.find_pointers(statement) {
+                    debug!("Inserting alignment check for {:?}", pointer.ty(&*local_decls, tcx).ty);
+
+                    let new_block = split_block(basic_blocks, location);
+                    insert_alignment_check(
+                        tcx,
+                        local_decls,
+                        &mut basic_blocks[block],
+                        pointer,
+                        pointee_ty,
+                        source_info,
+                        new_block,
+                    );
+                }
+            }
+        }
+    }
+}
+
+impl<'tcx, 'a> PointerFinder<'tcx, 'a> {
+    fn find_pointers(&mut self, statement: &Statement<'tcx>) -> Vec<(Place<'tcx>, Ty<'tcx>)> {
+        self.pointers.clear();
+        self.visit_statement(statement, Location::START);
+        core::mem::take(&mut self.pointers)
+    }
+}
+
+struct PointerFinder<'tcx, 'a> {
+    local_decls: &'a mut LocalDecls<'tcx>,
+    tcx: TyCtxt<'tcx>,
+    def_id: DefId,
+    pointers: Vec<(Place<'tcx>, Ty<'tcx>)>,
+}
+
+impl<'tcx, 'a> Visitor<'tcx> for PointerFinder<'tcx, 'a> {
+    fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) {
+        if let PlaceContext::NonUse(_) = context {
+            return;
+        }
+        if !place.is_indirect() {
+            return;
+        }
+
+        let pointer = Place::from(place.local);
+        let pointer_ty = pointer.ty(&*self.local_decls, self.tcx).ty;
+
+        // We only want to check unsafe pointers
+        if !pointer_ty.is_unsafe_ptr() {
+            trace!("Indirect, but not an unsafe ptr, not checking {:?}", pointer_ty);
+            return;
+        }
+
+        let Some(pointee) = pointer_ty.builtin_deref(true) else {
+            debug!("Indirect but no builtin deref: {:?}", pointer_ty);
+            return;
+        };
+        let mut pointee_ty = pointee.ty;
+        if pointee_ty.is_array() || pointee_ty.is_slice() || pointee_ty.is_str() {
+            pointee_ty = pointee_ty.sequence_element_type(self.tcx);
+        }
+
+        if !pointee_ty.is_sized(self.tcx, self.tcx.param_env_reveal_all_normalized(self.def_id)) {
+            debug!("Unsafe pointer, but unsized: {:?}", pointer_ty);
+            return;
+        }
+
+        self.pointers.push((pointer, pointee_ty))
+    }
+}
+
+fn split_block(
+    basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>,
+    location: Location,
+) -> BasicBlock {
+    let block_data = &mut basic_blocks[location.block];
+
+    // Drain every statement after this one and move the current terminator to a new basic block
+    let new_block = BasicBlockData {
+        statements: block_data.statements.drain(location.statement_index..).collect(),
+        terminator: block_data.terminator.take(),
+        is_cleanup: block_data.is_cleanup,
+    };
+
+    basic_blocks.push(new_block)
+}
+
+fn insert_alignment_check<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    local_decls: &mut LocalDecls<'tcx>,
+    block_data: &mut BasicBlockData<'tcx>,
+    pointer: Place<'tcx>,
+    pointee_ty: Ty<'tcx>,
+    source_info: SourceInfo,
+    new_block: BasicBlock,
+) {
+    // Cast the pointer to a *const ()
+    let const_raw_ptr = tcx.mk_ptr(TypeAndMut { ty: tcx.types.unit, mutbl: Mutability::Not });
+    let rvalue = Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(pointer), const_raw_ptr);
+    let thin_ptr = local_decls.push(LocalDecl::with_source_info(const_raw_ptr, source_info)).into();
+    block_data
+        .statements
+        .push(Statement { source_info, kind: StatementKind::Assign(Box::new((thin_ptr, rvalue))) });
+
+    // Cast the pointer to a usize
+    let rvalue = Rvalue::Cast(CastKind::Transmute, Operand::Copy(thin_ptr), tcx.types.usize);
+    let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
+    block_data
+        .statements
+        .push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) });
+
+    // Get the alignment of the pointee
+    let alignment =
+        local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
+    let rvalue = Rvalue::NullaryOp(NullOp::AlignOf, pointee_ty);
+    block_data.statements.push(Statement {
+        source_info,
+        kind: StatementKind::Assign(Box::new((alignment, rvalue))),
+    });
+
+    // Subtract 1 from the alignment to get the alignment mask
+    let alignment_mask =
+        local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
+    let one = Operand::Constant(Box::new(Constant {
+        span: source_info.span,
+        user_ty: None,
+        literal: ConstantKind::Val(
+            ConstValue::Scalar(Scalar::from_target_usize(1, &tcx)),
+            tcx.types.usize,
+        ),
+    }));
+    block_data.statements.push(Statement {
+        source_info,
+        kind: StatementKind::Assign(Box::new((
+            alignment_mask,
+            Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(alignment), one))),
+        ))),
+    });
+
+    // BitAnd the alignment mask with the pointer
+    let alignment_bits =
+        local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
+    block_data.statements.push(Statement {
+        source_info,
+        kind: StatementKind::Assign(Box::new((
+            alignment_bits,
+            Rvalue::BinaryOp(
+                BinOp::BitAnd,
+                Box::new((Operand::Copy(addr), Operand::Copy(alignment_mask))),
+            ),
+        ))),
+    });
+
+    // Check if the alignment bits are all zero
+    let is_ok = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
+    let zero = Operand::Constant(Box::new(Constant {
+        span: source_info.span,
+        user_ty: None,
+        literal: ConstantKind::Val(
+            ConstValue::Scalar(Scalar::from_target_usize(0, &tcx)),
+            tcx.types.usize,
+        ),
+    }));
+    block_data.statements.push(Statement {
+        source_info,
+        kind: StatementKind::Assign(Box::new((
+            is_ok,
+            Rvalue::BinaryOp(BinOp::Eq, Box::new((Operand::Copy(alignment_bits), zero.clone()))),
+        ))),
+    });
+
+    // Set this block's terminator to our assert, continuing to new_block if we pass
+    block_data.terminator = Some(Terminator {
+        source_info,
+        kind: TerminatorKind::Assert {
+            cond: Operand::Copy(is_ok),
+            expected: true,
+            target: new_block,
+            msg: AssertKind::MisalignedPointerDereference {
+                required: Operand::Copy(alignment),
+                found: Operand::Copy(addr),
+            },
+            cleanup: None,
+        },
+    });
+}