about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2025-04-29 05:36:44 +0000
committerbors <bors@rust-lang.org>2025-04-29 05:36:44 +0000
commit4c83e55e2d88ff93155be2784b9f64b91b870e99 (patch)
treea30af0c932392e54d81d8a992b629ad043a8b663 /compiler/rustc_mir_transform/src
parent1b8ab72680f36e783af84c1a3c4f8508572bd9f9 (diff)
parent7082fa27a792e34a9474db396ccf3958143211af (diff)
downloadrust-4c83e55e2d88ff93155be2784b9f64b91b870e99.tar.gz
rust-4c83e55e2d88ff93155be2784b9f64b91b870e99.zip
Auto merge of #137940 - 1c3t3a:alignment-borrows-check, r=saethlin
Extend the alignment check to borrows

The current alignment check does not include checks for creating misaligned references from raw pointers, which is now added in this patch.

When inserting the check we need to be careful with references to field projections (e.g. `&(*ptr).a`), in which case the resulting reference must be aligned according to the field type and not the type of the pointer.

r? `@saethlin`

cc `@RalfJung,` after our discussion in #134424
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/check_alignment.rs10
-rw-r--r--compiler/rustc_mir_transform/src/check_null.rs10
-rw-r--r--compiler/rustc_mir_transform/src/check_pointers.rs60
3 files changed, 48 insertions, 32 deletions
diff --git a/compiler/rustc_mir_transform/src/check_alignment.rs b/compiler/rustc_mir_transform/src/check_alignment.rs
index 5115583f37c..8f88613b79f 100644
--- a/compiler/rustc_mir_transform/src/check_alignment.rs
+++ b/compiler/rustc_mir_transform/src/check_alignment.rs
@@ -6,7 +6,7 @@ use rustc_middle::mir::*;
 use rustc_middle::ty::{Ty, TyCtxt};
 use rustc_session::Session;
 
-use crate::check_pointers::{BorrowCheckMode, PointerCheck, check_pointers};
+use crate::check_pointers::{BorrowedFieldProjectionMode, PointerCheck, check_pointers};
 
 pub(super) struct CheckAlignment;
 
@@ -19,15 +19,15 @@ impl<'tcx> crate::MirPass<'tcx> for CheckAlignment {
         // Skip trivially aligned place types.
         let excluded_pointees = [tcx.types.bool, tcx.types.i8, tcx.types.u8];
 
-        // We have to exclude borrows here: in `&x.field`, the exact
-        // requirement is that the final reference must be aligned, but
-        // `check_pointers` would check that `x` is aligned, which would be wrong.
+        // When checking the alignment of references to field projections (`&(*ptr).a`),
+        // we need to make sure that the reference is aligned according to the field type
+        // and not to the pointer type.
         check_pointers(
             tcx,
             body,
             &excluded_pointees,
             insert_alignment_check,
-            BorrowCheckMode::ExcludeBorrows,
+            BorrowedFieldProjectionMode::FollowProjections,
         );
     }
 
diff --git a/compiler/rustc_mir_transform/src/check_null.rs b/compiler/rustc_mir_transform/src/check_null.rs
index 543e1845e65..ad74e335bd9 100644
--- a/compiler/rustc_mir_transform/src/check_null.rs
+++ b/compiler/rustc_mir_transform/src/check_null.rs
@@ -4,7 +4,7 @@ use rustc_middle::mir::*;
 use rustc_middle::ty::{Ty, TyCtxt};
 use rustc_session::Session;
 
-use crate::check_pointers::{BorrowCheckMode, PointerCheck, check_pointers};
+use crate::check_pointers::{BorrowedFieldProjectionMode, PointerCheck, check_pointers};
 
 pub(super) struct CheckNull;
 
@@ -14,7 +14,13 @@ impl<'tcx> crate::MirPass<'tcx> for CheckNull {
     }
 
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
-        check_pointers(tcx, body, &[], insert_null_check, BorrowCheckMode::IncludeBorrows);
+        check_pointers(
+            tcx,
+            body,
+            &[],
+            insert_null_check,
+            BorrowedFieldProjectionMode::NoFollowProjections,
+        );
     }
 
     fn is_required(&self) -> bool {
diff --git a/compiler/rustc_mir_transform/src/check_pointers.rs b/compiler/rustc_mir_transform/src/check_pointers.rs
index 2d04b621935..bf94f1aad24 100644
--- a/compiler/rustc_mir_transform/src/check_pointers.rs
+++ b/compiler/rustc_mir_transform/src/check_pointers.rs
@@ -12,13 +12,13 @@ pub(crate) struct PointerCheck<'tcx> {
     pub(crate) assert_kind: Box<AssertKind<Operand<'tcx>>>,
 }
 
-/// Indicates whether we insert the checks for borrow places of a raw pointer.
-/// Concretely places with [MutatingUseContext::Borrow] or
-/// [NonMutatingUseContext::SharedBorrow].
+/// When checking for borrows of field projections (`&(*ptr).a`), we might want
+/// to check for the field type (type of `.a` in the example). This enum defines
+/// the variations (pass the pointer [Ty] or the field [Ty]).
 #[derive(Copy, Clone)]
-pub(crate) enum BorrowCheckMode {
-    IncludeBorrows,
-    ExcludeBorrows,
+pub(crate) enum BorrowedFieldProjectionMode {
+    FollowProjections,
+    NoFollowProjections,
 }
 
 /// Utility for adding a check for read/write on every sized, raw pointer.
@@ -27,8 +27,8 @@ pub(crate) enum BorrowCheckMode {
 /// new basic block directly before the pointer access. (Read/write accesses
 /// are determined by the `PlaceContext` of the MIR visitor.) Then calls
 /// `on_finding` to insert the actual logic for a pointer check (e.g. check for
-/// alignment). A check can choose to be inserted for (mutable) borrows of
-/// raw pointers via the `borrow_check_mode` parameter.
+/// alignment). A check can choose to follow borrows of field projections via
+/// the `field_projection_mode` parameter.
 ///
 /// This utility takes care of the right order of blocks, the only thing a
 /// caller must do in `on_finding` is:
@@ -45,7 +45,7 @@ pub(crate) fn check_pointers<'tcx, F>(
     body: &mut Body<'tcx>,
     excluded_pointees: &[Ty<'tcx>],
     on_finding: F,
-    borrow_check_mode: BorrowCheckMode,
+    field_projection_mode: BorrowedFieldProjectionMode,
 ) where
     F: Fn(
         /* tcx: */ TyCtxt<'tcx>,
@@ -82,7 +82,7 @@ pub(crate) fn check_pointers<'tcx, F>(
                 local_decls,
                 typing_env,
                 excluded_pointees,
-                borrow_check_mode,
+                field_projection_mode,
             );
             finder.visit_statement(statement, location);
 
@@ -128,7 +128,7 @@ struct PointerFinder<'a, 'tcx> {
     typing_env: ty::TypingEnv<'tcx>,
     pointers: Vec<(Place<'tcx>, Ty<'tcx>, PlaceContext)>,
     excluded_pointees: &'a [Ty<'tcx>],
-    borrow_check_mode: BorrowCheckMode,
+    field_projection_mode: BorrowedFieldProjectionMode,
 }
 
 impl<'a, 'tcx> PointerFinder<'a, 'tcx> {
@@ -137,7 +137,7 @@ impl<'a, 'tcx> PointerFinder<'a, 'tcx> {
         local_decls: &'a mut LocalDecls<'tcx>,
         typing_env: ty::TypingEnv<'tcx>,
         excluded_pointees: &'a [Ty<'tcx>],
-        borrow_check_mode: BorrowCheckMode,
+        field_projection_mode: BorrowedFieldProjectionMode,
     ) -> Self {
         PointerFinder {
             tcx,
@@ -145,7 +145,7 @@ impl<'a, 'tcx> PointerFinder<'a, 'tcx> {
             typing_env,
             excluded_pointees,
             pointers: Vec::new(),
-            borrow_check_mode,
+            field_projection_mode,
         }
     }
 
@@ -163,15 +163,14 @@ impl<'a, 'tcx> PointerFinder<'a, 'tcx> {
                 MutatingUseContext::Store
                 | MutatingUseContext::Call
                 | MutatingUseContext::Yield
-                | MutatingUseContext::Drop,
+                | MutatingUseContext::Drop
+                | MutatingUseContext::Borrow,
             ) => true,
             PlaceContext::NonMutatingUse(
-                NonMutatingUseContext::Copy | NonMutatingUseContext::Move,
+                NonMutatingUseContext::Copy
+                | NonMutatingUseContext::Move
+                | NonMutatingUseContext::SharedBorrow,
             ) => true,
-            PlaceContext::MutatingUse(MutatingUseContext::Borrow)
-            | PlaceContext::NonMutatingUse(NonMutatingUseContext::SharedBorrow) => {
-                matches!(self.borrow_check_mode, BorrowCheckMode::IncludeBorrows)
-            }
             _ => false,
         }
     }
@@ -183,19 +182,29 @@ impl<'a, 'tcx> Visitor<'tcx> for PointerFinder<'a, 'tcx> {
             return;
         }
 
-        // Since Deref projections must come first and only once, the pointer for an indirect place
-        // is the Local that the Place is based on.
+        // Get the place and type we visit.
         let pointer = Place::from(place.local);
-        let pointer_ty = self.local_decls[place.local].ty;
+        let pointer_ty = pointer.ty(self.local_decls, self.tcx).ty;
 
         // We only want to check places based on raw pointers
-        if !pointer_ty.is_raw_ptr() {
+        let &ty::RawPtr(mut pointee_ty, _) = pointer_ty.kind() else {
             trace!("Indirect, but not based on an raw ptr, not checking {:?}", place);
             return;
+        };
+
+        // If we see a borrow of a field projection, we want to pass the field type to the
+        // check and not the pointee type.
+        if matches!(self.field_projection_mode, BorrowedFieldProjectionMode::FollowProjections)
+            && matches!(
+                context,
+                PlaceContext::NonMutatingUse(NonMutatingUseContext::SharedBorrow)
+                    | PlaceContext::MutatingUse(MutatingUseContext::Borrow)
+            )
+        {
+            // Naturally, the field type is type of the initial place we look at.
+            pointee_ty = place.ty(self.local_decls, self.tcx).ty;
         }
 
-        let pointee_ty =
-            pointer_ty.builtin_deref(true).expect("no builtin_deref for an raw pointer");
         // Ideally we'd support this in the future, but for now we are limited to sized types.
         if !pointee_ty.is_sized(self.tcx, self.typing_env) {
             trace!("Raw pointer, but pointee is not known to be sized: {:?}", pointer_ty);
@@ -207,6 +216,7 @@ impl<'a, 'tcx> Visitor<'tcx> for PointerFinder<'a, 'tcx> {
             ty::Array(ty, _) => *ty,
             _ => pointee_ty,
         };
+        // Check if we excluded this pointee type from the check.
         if self.excluded_pointees.contains(&element_ty) {
             trace!("Skipping pointer for type: {:?}", pointee_ty);
             return;