about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_mir_transform/src/check_alignment.rs105
-rw-r--r--tests/debuginfo/simple-struct.rs2
-rw-r--r--tests/ui/mir/alignment/addrof_alignment.rs (renamed from tests/ui/mir/addrof_alignment.rs)3
-rw-r--r--tests/ui/mir/alignment/i686-pc-windows-msvc.rs (renamed from tests/ui/mir/mir_alignment_check_i686-pc-windows-msvc.rs)4
-rw-r--r--tests/ui/mir/alignment/misaligned_lhs.rs (renamed from tests/ui/mir/mir_alignment_check.rs)4
-rw-r--r--tests/ui/mir/alignment/misaligned_rhs.rs13
-rw-r--r--tests/ui/mir/alignment/packed.rs29
-rw-r--r--tests/ui/mir/alignment/place_computation.rs16
-rw-r--r--tests/ui/mir/alignment/place_without_read.rs9
-rw-r--r--tests/ui/mir/alignment/two_pointers.rs15
10 files changed, 142 insertions, 58 deletions
diff --git a/compiler/rustc_mir_transform/src/check_alignment.rs b/compiler/rustc_mir_transform/src/check_alignment.rs
index 28765af20ad..42b2f18869c 100644
--- a/compiler/rustc_mir_transform/src/check_alignment.rs
+++ b/compiler/rustc_mir_transform/src/check_alignment.rs
@@ -1,13 +1,12 @@
 use crate::MirPass;
-use rustc_hir::def_id::DefId;
 use rustc_hir::lang_items::LangItem;
 use rustc_index::IndexVec;
 use rustc_middle::mir::*;
 use rustc_middle::mir::{
     interpret::Scalar,
-    visit::{PlaceContext, Visitor},
+    visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor},
 };
-use rustc_middle::ty::{Ty, TyCtxt, TypeAndMut};
+use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt, TypeAndMut};
 use rustc_session::Session;
 
 pub struct CheckAlignment;
@@ -30,7 +29,12 @@ impl<'tcx> MirPass<'tcx> for CheckAlignment {
 
         let basic_blocks = body.basic_blocks.as_mut();
         let local_decls = &mut body.local_decls;
+        let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
 
+        // This pass inserts new blocks. Each insertion changes the Location for all
+        // statements/blocks after. Iterating or visiting the MIR in order would require updating
+        // our current location after every insertion. By iterating backwards, we dodge this issue:
+        // The only Locations that an insertion changes have already been handled.
         for block in (0..basic_blocks.len()).rev() {
             let block = block.into();
             for statement_index in (0..basic_blocks[block].statements.len()).rev() {
@@ -38,22 +42,19 @@ impl<'tcx> MirPass<'tcx> for CheckAlignment {
                 let statement = &basic_blocks[block].statements[statement_index];
                 let source_info = statement.source_info;
 
-                let mut finder = PointerFinder {
-                    local_decls,
-                    tcx,
-                    pointers: Vec::new(),
-                    def_id: body.source.def_id(),
-                };
-                for (pointer, pointee_ty) in finder.find_pointers(statement) {
-                    debug!("Inserting alignment check for {:?}", pointer.ty(&*local_decls, tcx).ty);
+                let mut finder =
+                    PointerFinder { tcx, local_decls, param_env, pointers: Vec::new() };
+                finder.visit_statement(statement, location);
 
+                for (local, ty) in finder.pointers {
+                    debug!("Inserting alignment check for {:?}", ty);
                     let new_block = split_block(basic_blocks, location);
                     insert_alignment_check(
                         tcx,
                         local_decls,
                         &mut basic_blocks[block],
-                        pointer,
-                        pointee_ty,
+                        local,
+                        ty,
                         source_info,
                         new_block,
                     );
@@ -63,69 +64,71 @@ impl<'tcx> MirPass<'tcx> for CheckAlignment {
     }
 }
 
-impl<'tcx, 'a> PointerFinder<'tcx, 'a> {
-    fn find_pointers(&mut self, statement: &Statement<'tcx>) -> Vec<(Place<'tcx>, Ty<'tcx>)> {
-        self.pointers.clear();
-        self.visit_statement(statement, Location::START);
-        core::mem::take(&mut self.pointers)
-    }
-}
-
 struct PointerFinder<'tcx, 'a> {
-    local_decls: &'a mut LocalDecls<'tcx>,
     tcx: TyCtxt<'tcx>,
-    def_id: DefId,
+    local_decls: &'a mut LocalDecls<'tcx>,
+    param_env: ParamEnv<'tcx>,
     pointers: Vec<(Place<'tcx>, Ty<'tcx>)>,
 }
 
 impl<'tcx, 'a> Visitor<'tcx> for PointerFinder<'tcx, 'a> {
-    fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
-        if let Rvalue::AddressOf(..) = rvalue {
-            // Ignore dereferences inside of an AddressOf
-            return;
+    fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
+        // We want to only check reads and writes to Places, so we specifically exclude
+        // Borrows and AddressOf.
+        match context {
+            PlaceContext::MutatingUse(
+                MutatingUseContext::Store
+                | MutatingUseContext::AsmOutput
+                | MutatingUseContext::Call
+                | MutatingUseContext::Yield
+                | MutatingUseContext::Drop,
+            ) => {}
+            PlaceContext::NonMutatingUse(
+                NonMutatingUseContext::Copy | NonMutatingUseContext::Move,
+            ) => {}
+            _ => {
+                return;
+            }
         }
-        self.super_rvalue(rvalue, location);
-    }
 
-    fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) {
-        if let PlaceContext::NonUse(_) = context {
-            return;
-        }
         if !place.is_indirect() {
             return;
         }
 
+        // Since Deref projections must come first and only once, the pointer for an indirect place
+        // is the Local that the Place is based on.
         let pointer = Place::from(place.local);
-        let pointer_ty = pointer.ty(&*self.local_decls, self.tcx).ty;
+        let pointer_ty = self.local_decls[place.local].ty;
 
-        // We only want to check unsafe pointers
+        // We only want to check places based on unsafe pointers
         if !pointer_ty.is_unsafe_ptr() {
-            trace!("Indirect, but not an unsafe ptr, not checking {:?}", pointer_ty);
+            trace!("Indirect, but not based on an unsafe ptr, not checking {:?}", place);
             return;
         }
 
-        let Some(pointee) = pointer_ty.builtin_deref(true) else {
-            debug!("Indirect but no builtin deref: {:?}", pointer_ty);
+        let pointee_ty =
+            pointer_ty.builtin_deref(true).expect("no builtin_deref for an unsafe pointer").ty;
+        // Ideally we'd support this in the future, but for now we are limited to sized types.
+        if !pointee_ty.is_sized(self.tcx, self.param_env) {
+            debug!("Unsafe pointer, but pointee is not known to be sized: {:?}", pointer_ty);
             return;
-        };
-        let mut pointee_ty = pointee.ty;
-        if pointee_ty.is_array() || pointee_ty.is_slice() || pointee_ty.is_str() {
-            pointee_ty = pointee_ty.sequence_element_type(self.tcx);
         }
 
-        if !pointee_ty.is_sized(self.tcx, self.tcx.param_env_reveal_all_normalized(self.def_id)) {
-            debug!("Unsafe pointer, but unsized: {:?}", pointer_ty);
+        // Try to detect types we are sure have an alignment of 1 and skip the check
+        // We don't need to look for str and slices, we already rejected unsized types above
+        let element_ty = match pointee_ty.kind() {
+            ty::Array(ty, _) => *ty,
+            _ => pointee_ty,
+        };
+        if [self.tcx.types.bool, self.tcx.types.i8, self.tcx.types.u8].contains(&element_ty) {
+            debug!("Trivially aligned place type: {:?}", pointee_ty);
             return;
         }
 
-        if [self.tcx.types.bool, self.tcx.types.i8, self.tcx.types.u8, self.tcx.types.str_]
-            .contains(&pointee_ty)
-        {
-            debug!("Trivially aligned pointee type: {:?}", pointer_ty);
-            return;
-        }
+        // Ensure that this place is based on an aligned pointer.
+        self.pointers.push((pointer, pointee_ty));
 
-        self.pointers.push((pointer, pointee_ty))
+        self.super_place(place, context, location);
     }
 }
 
diff --git a/tests/debuginfo/simple-struct.rs b/tests/debuginfo/simple-struct.rs
index aa3cf023a71..fea8109223a 100644
--- a/tests/debuginfo/simple-struct.rs
+++ b/tests/debuginfo/simple-struct.rs
@@ -1,7 +1,7 @@
 // min-lldb-version: 310
 // ignore-gdb // Test temporarily ignored due to debuginfo tests being disabled, see PR 47155
 
-// compile-flags:-g
+// compile-flags: -g -Zmir-enable-passes=-CheckAlignment
 
 // === GDB TESTS ===================================================================================
 
diff --git a/tests/ui/mir/addrof_alignment.rs b/tests/ui/mir/alignment/addrof_alignment.rs
index 892638bfb92..f3423e97a8a 100644
--- a/tests/ui/mir/addrof_alignment.rs
+++ b/tests/ui/mir/alignment/addrof_alignment.rs
@@ -1,5 +1,4 @@
 // run-pass
-// ignore-wasm32-bare: No panic messages
 // compile-flags: -C debug-assertions
 
 struct Misalignment {
@@ -9,7 +8,7 @@ struct Misalignment {
 fn main() {
     let items: [Misalignment; 2] = [Misalignment { a: 0 }, Misalignment { a: 1 }];
     unsafe {
-        let ptr: *const Misalignment = items.as_ptr().cast::<u8>().add(1).cast::<Misalignment>();
+        let ptr: *const Misalignment = items.as_ptr().byte_add(1);
         let _ptr = core::ptr::addr_of!((*ptr).a);
     }
 }
diff --git a/tests/ui/mir/mir_alignment_check_i686-pc-windows-msvc.rs b/tests/ui/mir/alignment/i686-pc-windows-msvc.rs
index 56388c1047e..74ba1fde649 100644
--- a/tests/ui/mir/mir_alignment_check_i686-pc-windows-msvc.rs
+++ b/tests/ui/mir/alignment/i686-pc-windows-msvc.rs
@@ -11,9 +11,9 @@
 
 fn main() {
     let mut x = [0u64; 2];
-    let ptr: *mut u8 = x.as_mut_ptr().cast::<u8>();
+    let ptr = x.as_mut_ptr();
     unsafe {
-        let misaligned = ptr.add(4).cast::<u64>();
+        let misaligned = ptr.byte_add(4);
         assert!(misaligned.addr() % 8 != 0);
         assert!(misaligned.addr() % 4 == 0);
         *misaligned = 42;
diff --git a/tests/ui/mir/mir_alignment_check.rs b/tests/ui/mir/alignment/misaligned_lhs.rs
index d1bf3d46a7c..97644ba8e09 100644
--- a/tests/ui/mir/mir_alignment_check.rs
+++ b/tests/ui/mir/alignment/misaligned_lhs.rs
@@ -6,8 +6,8 @@
 
 fn main() {
     let mut x = [0u32; 2];
-    let ptr: *mut u8 = x.as_mut_ptr().cast::<u8>();
+    let ptr = x.as_mut_ptr();
     unsafe {
-        *(ptr.add(1).cast::<u32>()) = 42;
+        *(ptr.byte_add(1)) = 42;
     }
 }
diff --git a/tests/ui/mir/alignment/misaligned_rhs.rs b/tests/ui/mir/alignment/misaligned_rhs.rs
new file mode 100644
index 00000000000..8534bc71a3a
--- /dev/null
+++ b/tests/ui/mir/alignment/misaligned_rhs.rs
@@ -0,0 +1,13 @@
+// run-fail
+// ignore-wasm32-bare: No panic messages
+// ignore-i686-pc-windows-msvc: #112480
+// compile-flags: -C debug-assertions
+// error-pattern: misaligned pointer dereference: address must be a multiple of 0x4 but is
+
+fn main() {
+    let mut x = [0u32; 2];
+    let ptr = x.as_mut_ptr();
+    unsafe {
+        let _v = *(ptr.byte_add(1));
+    }
+}
diff --git a/tests/ui/mir/alignment/packed.rs b/tests/ui/mir/alignment/packed.rs
new file mode 100644
index 00000000000..754698591e3
--- /dev/null
+++ b/tests/ui/mir/alignment/packed.rs
@@ -0,0 +1,29 @@
+// run-pass
+// compile-flags: -C debug-assertions
+
+#![feature(strict_provenance, pointer_is_aligned)]
+
+#[repr(packed)]
+struct Misaligner {
+    _head: u8,
+    tail: u64,
+}
+
+fn main() {
+    let memory = [Misaligner { _head: 0, tail: 0}, Misaligner { _head: 0, tail: 0}];
+    // Test that we can use addr_of! to get the address of a packed member which according to its
+    // type is not aligned, but because it is a projection from a packed type is a valid place.
+    let ptr0 = std::ptr::addr_of!(memory[0].tail);
+    let ptr1 = std::ptr::addr_of!(memory[0].tail);
+    // Even if ptr0 happens to be aligned by chance, ptr1 is not.
+    assert!(!ptr0.is_aligned() || !ptr1.is_aligned());
+
+    // And also test that we can get the addr of a packed struct then do a member read from it.
+    unsafe {
+        let ptr = std::ptr::addr_of!(memory[0]);
+        let _tail = (*ptr).tail;
+
+        let ptr = std::ptr::addr_of!(memory[1]);
+        let _tail = (*ptr).tail;
+    }
+}
diff --git a/tests/ui/mir/alignment/place_computation.rs b/tests/ui/mir/alignment/place_computation.rs
new file mode 100644
index 00000000000..fdd4864250a
--- /dev/null
+++ b/tests/ui/mir/alignment/place_computation.rs
@@ -0,0 +1,16 @@
+// run-pass
+// compile-flags: -C debug-assertions
+
+#[repr(align(8))]
+struct Misalignment {
+    a: u8,
+}
+
+fn main() {
+    let mem = 0u64;
+    let ptr = &mem as *const u64 as *const Misalignment;
+    unsafe {
+        let ptr = ptr.byte_add(1);
+        let _ref: &u8 = &(*ptr).a;
+    }
+}
diff --git a/tests/ui/mir/alignment/place_without_read.rs b/tests/ui/mir/alignment/place_without_read.rs
new file mode 100644
index 00000000000..b4be7a50f61
--- /dev/null
+++ b/tests/ui/mir/alignment/place_without_read.rs
@@ -0,0 +1,9 @@
+// run-pass
+// compile-flags: -C debug-assertions
+
+fn main() {
+    let ptr = 1 as *const u16;
+    unsafe {
+        let _ = *ptr;
+    }
+}
diff --git a/tests/ui/mir/alignment/two_pointers.rs b/tests/ui/mir/alignment/two_pointers.rs
new file mode 100644
index 00000000000..29af21dffc1
--- /dev/null
+++ b/tests/ui/mir/alignment/two_pointers.rs
@@ -0,0 +1,15 @@
+// run-fail
+// ignore-wasm32-bare: No panic messages
+// ignore-i686-pc-windows-msvc: #112480
+// compile-flags: -C debug-assertions
+// error-pattern: misaligned pointer dereference: address must be a multiple of 0x4 but is
+
+fn main() {
+    let x = [0u32; 2];
+    let ptr = x.as_ptr();
+    let mut dest = 0u32;
+    let dest_ptr = &mut dest as *mut u32;
+    unsafe {
+        *dest_ptr = *(ptr.byte_add(1));
+    }
+}