about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2025-02-05 19:36:41 +0000
committerMichael Goulet <michael@errs.io>2025-02-08 21:38:16 +0000
commitb4641b2b3fd0616fa9597537f8fd954c34e6535f (patch)
tree183c6c4e11f0d552e7832dc48f66c25202bbcbe9 /compiler/rustc_mir_transform
parent73bf7947e9ab731bf2764db219cd9cda216a3aed (diff)
downloadrust-b4641b2b3fd0616fa9597537f8fd954c34e6535f.tar.gz
rust-b4641b2b3fd0616fa9597537f8fd954c34e6535f.zip
Detect (non-raw) borrows of null ZST pointers in CheckNull
Diffstat (limited to 'compiler/rustc_mir_transform')
-rw-r--r--compiler/rustc_mir_transform/src/check_alignment.rs2
-rw-r--r--compiler/rustc_mir_transform/src/check_null.rs69
-rw-r--r--compiler/rustc_mir_transform/src/check_pointers.rs14
3 files changed, 56 insertions, 29 deletions
diff --git a/compiler/rustc_mir_transform/src/check_alignment.rs b/compiler/rustc_mir_transform/src/check_alignment.rs
index ca5564e447a..b70cca14840 100644
--- a/compiler/rustc_mir_transform/src/check_alignment.rs
+++ b/compiler/rustc_mir_transform/src/check_alignment.rs
@@ -1,5 +1,6 @@
 use rustc_index::IndexVec;
 use rustc_middle::mir::interpret::Scalar;
+use rustc_middle::mir::visit::PlaceContext;
 use rustc_middle::mir::*;
 use rustc_middle::ty::{Ty, TyCtxt};
 use rustc_session::Session;
@@ -44,6 +45,7 @@ fn insert_alignment_check<'tcx>(
     tcx: TyCtxt<'tcx>,
     pointer: Place<'tcx>,
     pointee_ty: Ty<'tcx>,
+    _context: PlaceContext,
     local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
     stmts: &mut Vec<Statement<'tcx>>,
     source_info: SourceInfo,
diff --git a/compiler/rustc_mir_transform/src/check_null.rs b/compiler/rustc_mir_transform/src/check_null.rs
index 0b6c0ceaac1..543e1845e65 100644
--- a/compiler/rustc_mir_transform/src/check_null.rs
+++ b/compiler/rustc_mir_transform/src/check_null.rs
@@ -1,5 +1,5 @@
 use rustc_index::IndexVec;
-use rustc_middle::mir::interpret::Scalar;
+use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext};
 use rustc_middle::mir::*;
 use rustc_middle::ty::{Ty, TyCtxt};
 use rustc_session::Session;
@@ -26,6 +26,7 @@ fn insert_null_check<'tcx>(
     tcx: TyCtxt<'tcx>,
     pointer: Place<'tcx>,
     pointee_ty: Ty<'tcx>,
+    context: PlaceContext,
     local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
     stmts: &mut Vec<Statement<'tcx>>,
     source_info: SourceInfo,
@@ -42,30 +43,51 @@ fn insert_null_check<'tcx>(
     let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
     stmts.push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) });
 
-    // Get size of the pointee (zero-sized reads and writes are allowed).
-    let rvalue = Rvalue::NullaryOp(NullOp::SizeOf, pointee_ty);
-    let sizeof_pointee =
-        local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
-    stmts.push(Statement {
-        source_info,
-        kind: StatementKind::Assign(Box::new((sizeof_pointee, rvalue))),
-    });
-
-    // Check that the pointee is not a ZST.
     let zero = Operand::Constant(Box::new(ConstOperand {
         span: source_info.span,
         user_ty: None,
-        const_: Const::Val(ConstValue::Scalar(Scalar::from_target_usize(0, &tcx)), tcx.types.usize),
+        const_: Const::Val(ConstValue::from_target_usize(0, &tcx), tcx.types.usize),
     }));
-    let is_pointee_no_zst =
-        local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
-    stmts.push(Statement {
-        source_info,
-        kind: StatementKind::Assign(Box::new((
-            is_pointee_no_zst,
-            Rvalue::BinaryOp(BinOp::Ne, Box::new((Operand::Copy(sizeof_pointee), zero.clone()))),
-        ))),
-    });
+
+    let pointee_should_be_checked = match context {
+        // Borrows pointing to "null" are UB even if the pointee is a ZST.
+        PlaceContext::NonMutatingUse(NonMutatingUseContext::SharedBorrow)
+        | PlaceContext::MutatingUse(MutatingUseContext::Borrow) => {
+            // Pointer should be checked unconditionally.
+            Operand::Constant(Box::new(ConstOperand {
+                span: source_info.span,
+                user_ty: None,
+                const_: Const::Val(ConstValue::from_bool(true), tcx.types.bool),
+            }))
+        }
+        // Other usages of null pointers only are UB if the pointee is not a ZST.
+        _ => {
+            let rvalue = Rvalue::NullaryOp(NullOp::SizeOf, pointee_ty);
+            let sizeof_pointee =
+                local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
+            stmts.push(Statement {
+                source_info,
+                kind: StatementKind::Assign(Box::new((sizeof_pointee, rvalue))),
+            });
+
+            // Check that the pointee is not a ZST.
+            let is_pointee_not_zst =
+                local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
+            stmts.push(Statement {
+                source_info,
+                kind: StatementKind::Assign(Box::new((
+                    is_pointee_not_zst,
+                    Rvalue::BinaryOp(
+                        BinOp::Ne,
+                        Box::new((Operand::Copy(sizeof_pointee), zero.clone())),
+                    ),
+                ))),
+            });
+
+            // Pointer needs to be checked only if pointee is not a ZST.
+            Operand::Copy(is_pointee_not_zst)
+        }
+    };
 
     // Check whether the pointer is null.
     let is_null = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
@@ -77,7 +99,8 @@ fn insert_null_check<'tcx>(
         ))),
     });
 
-    // We want to throw an exception if the pointer is null and doesn't point to a ZST.
+    // We want to throw an exception if the pointer is null and the pointee is not unconditionally
+    // allowed (which for all non-borrow place uses, is when the pointee is ZST).
     let should_throw_exception =
         local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
     stmts.push(Statement {
@@ -86,7 +109,7 @@ fn insert_null_check<'tcx>(
             should_throw_exception,
             Rvalue::BinaryOp(
                 BinOp::BitAnd,
-                Box::new((Operand::Copy(is_null), Operand::Copy(is_pointee_no_zst))),
+                Box::new((Operand::Copy(is_null), pointee_should_be_checked)),
             ),
         ))),
     });
diff --git a/compiler/rustc_mir_transform/src/check_pointers.rs b/compiler/rustc_mir_transform/src/check_pointers.rs
index 72460542f87..ccaa83fd9e2 100644
--- a/compiler/rustc_mir_transform/src/check_pointers.rs
+++ b/compiler/rustc_mir_transform/src/check_pointers.rs
@@ -40,10 +40,10 @@ pub(crate) enum BorrowCheckMode {
 ///   success and fail the check otherwise.
 /// This utility will insert a terminator block that asserts on the condition
 /// and panics on failure.
-pub(crate) fn check_pointers<'a, 'tcx, F>(
+pub(crate) fn check_pointers<'tcx, F>(
     tcx: TyCtxt<'tcx>,
     body: &mut Body<'tcx>,
-    excluded_pointees: &'a [Ty<'tcx>],
+    excluded_pointees: &[Ty<'tcx>],
     on_finding: F,
     borrow_check_mode: BorrowCheckMode,
 ) where
@@ -51,6 +51,7 @@ pub(crate) fn check_pointers<'a, 'tcx, F>(
         /* tcx: */ TyCtxt<'tcx>,
         /* pointer: */ Place<'tcx>,
         /* pointee_ty: */ Ty<'tcx>,
+        /* context: */ PlaceContext,
         /* local_decls: */ &mut IndexVec<Local, LocalDecl<'tcx>>,
         /* stmts: */ &mut Vec<Statement<'tcx>>,
         /* source_info: */ SourceInfo,
@@ -86,7 +87,7 @@ pub(crate) fn check_pointers<'a, 'tcx, F>(
             );
             finder.visit_statement(statement, location);
 
-            for (local, ty) in finder.into_found_pointers() {
+            for (local, ty, context) in finder.into_found_pointers() {
                 debug!("Inserting check for {:?}", ty);
                 let new_block = split_block(basic_blocks, location);
 
@@ -98,6 +99,7 @@ pub(crate) fn check_pointers<'a, 'tcx, F>(
                     tcx,
                     local,
                     ty,
+                    context,
                     local_decls,
                     &mut block_data.statements,
                     source_info,
@@ -125,7 +127,7 @@ struct PointerFinder<'a, 'tcx> {
     tcx: TyCtxt<'tcx>,
     local_decls: &'a mut LocalDecls<'tcx>,
     typing_env: ty::TypingEnv<'tcx>,
-    pointers: Vec<(Place<'tcx>, Ty<'tcx>)>,
+    pointers: Vec<(Place<'tcx>, Ty<'tcx>, PlaceContext)>,
     excluded_pointees: &'a [Ty<'tcx>],
     borrow_check_mode: BorrowCheckMode,
 }
@@ -148,7 +150,7 @@ impl<'a, 'tcx> PointerFinder<'a, 'tcx> {
         }
     }
 
-    fn into_found_pointers(self) -> Vec<(Place<'tcx>, Ty<'tcx>)> {
+    fn into_found_pointers(self) -> Vec<(Place<'tcx>, Ty<'tcx>, PlaceContext)> {
         self.pointers
     }
 
@@ -211,7 +213,7 @@ impl<'a, 'tcx> Visitor<'tcx> for PointerFinder<'a, 'tcx> {
             return;
         }
 
-        self.pointers.push((pointer, pointee_ty));
+        self.pointers.push((pointer, pointee_ty, context));
 
         self.super_place(place, context, location);
     }