about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src/check_null.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src/check_null.rs')
-rw-r--r--compiler/rustc_mir_transform/src/check_null.rs69
1 files changed, 46 insertions, 23 deletions
diff --git a/compiler/rustc_mir_transform/src/check_null.rs b/compiler/rustc_mir_transform/src/check_null.rs
index 0b6c0ceaac1..543e1845e65 100644
--- a/compiler/rustc_mir_transform/src/check_null.rs
+++ b/compiler/rustc_mir_transform/src/check_null.rs
@@ -1,5 +1,5 @@
 use rustc_index::IndexVec;
-use rustc_middle::mir::interpret::Scalar;
+use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext};
 use rustc_middle::mir::*;
 use rustc_middle::ty::{Ty, TyCtxt};
 use rustc_session::Session;
@@ -26,6 +26,7 @@ fn insert_null_check<'tcx>(
     tcx: TyCtxt<'tcx>,
     pointer: Place<'tcx>,
     pointee_ty: Ty<'tcx>,
+    context: PlaceContext,
     local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
     stmts: &mut Vec<Statement<'tcx>>,
     source_info: SourceInfo,
@@ -42,30 +43,51 @@ fn insert_null_check<'tcx>(
     let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
     stmts.push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) });
 
-    // Get size of the pointee (zero-sized reads and writes are allowed).
-    let rvalue = Rvalue::NullaryOp(NullOp::SizeOf, pointee_ty);
-    let sizeof_pointee =
-        local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
-    stmts.push(Statement {
-        source_info,
-        kind: StatementKind::Assign(Box::new((sizeof_pointee, rvalue))),
-    });
-
-    // Check that the pointee is not a ZST.
     let zero = Operand::Constant(Box::new(ConstOperand {
         span: source_info.span,
         user_ty: None,
-        const_: Const::Val(ConstValue::Scalar(Scalar::from_target_usize(0, &tcx)), tcx.types.usize),
+        const_: Const::Val(ConstValue::from_target_usize(0, &tcx), tcx.types.usize),
     }));
-    let is_pointee_no_zst =
-        local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
-    stmts.push(Statement {
-        source_info,
-        kind: StatementKind::Assign(Box::new((
-            is_pointee_no_zst,
-            Rvalue::BinaryOp(BinOp::Ne, Box::new((Operand::Copy(sizeof_pointee), zero.clone()))),
-        ))),
-    });
+
+    let pointee_should_be_checked = match context {
+        // Borrows pointing to "null" are UB even if the pointee is a ZST.
+        PlaceContext::NonMutatingUse(NonMutatingUseContext::SharedBorrow)
+        | PlaceContext::MutatingUse(MutatingUseContext::Borrow) => {
+            // Pointer should be checked unconditionally.
+            Operand::Constant(Box::new(ConstOperand {
+                span: source_info.span,
+                user_ty: None,
+                const_: Const::Val(ConstValue::from_bool(true), tcx.types.bool),
+            }))
+        }
+        // Other usages of null pointers only are UB if the pointee is not a ZST.
+        _ => {
+            let rvalue = Rvalue::NullaryOp(NullOp::SizeOf, pointee_ty);
+            let sizeof_pointee =
+                local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
+            stmts.push(Statement {
+                source_info,
+                kind: StatementKind::Assign(Box::new((sizeof_pointee, rvalue))),
+            });
+
+            // Check that the pointee is not a ZST.
+            let is_pointee_not_zst =
+                local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
+            stmts.push(Statement {
+                source_info,
+                kind: StatementKind::Assign(Box::new((
+                    is_pointee_not_zst,
+                    Rvalue::BinaryOp(
+                        BinOp::Ne,
+                        Box::new((Operand::Copy(sizeof_pointee), zero.clone())),
+                    ),
+                ))),
+            });
+
+            // Pointer needs to be checked only if pointee is not a ZST.
+            Operand::Copy(is_pointee_not_zst)
+        }
+    };
 
     // Check whether the pointer is null.
     let is_null = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
@@ -77,7 +99,8 @@ fn insert_null_check<'tcx>(
         ))),
     });
 
-    // We want to throw an exception if the pointer is null and doesn't point to a ZST.
+    // We want to throw an exception if the pointer is null and the pointee is not unconditionally
+    // allowed (which for all non-borrow place uses, is when the pointee is ZST).
     let should_throw_exception =
         local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
     stmts.push(Statement {
@@ -86,7 +109,7 @@ fn insert_null_check<'tcx>(
             should_throw_exception,
             Rvalue::BinaryOp(
                 BinOp::BitAnd,
-                Box::new((Operand::Copy(is_null), Operand::Copy(is_pointee_no_zst))),
+                Box::new((Operand::Copy(is_null), pointee_should_be_checked)),
             ),
         ))),
     });