about summary refs log tree commit diff
path: root/compiler/rustc_const_eval
diff options
context:
space:
mode:
authorRalf Jung <post@ralfj.de>2024-05-09 12:35:11 +0200
committerRalf Jung <post@ralfj.de>2024-05-13 07:59:16 +0200
commit5c33a5690de89587b645fb0e7b657fe545e4d0e8 (patch)
tree63aa3ce58e14c1ff566a6182788fbe37a7c52db7 /compiler/rustc_const_eval
parentba956ef4b00c91579cff9b2220358ee3a46d982f (diff)
downloadrust-5c33a5690de89587b645fb0e7b657fe545e4d0e8.tar.gz
rust-5c33a5690de89587b645fb0e7b657fe545e4d0e8.zip
offset, offset_from: allow zero-byte offset on arbitrary pointers
Diffstat (limited to 'compiler/rustc_const_eval')
-rw-r--r--compiler/rustc_const_eval/src/const_eval/machine.rs20
-rw-r--r--compiler/rustc_const_eval/src/interpret/intrinsics.rs1
-rw-r--r--compiler/rustc_const_eval/src/interpret/memory.rs39
-rw-r--r--compiler/rustc_const_eval/src/interpret/validity.rs11
4 files changed, 44 insertions, 27 deletions
diff --git a/compiler/rustc_const_eval/src/const_eval/machine.rs b/compiler/rustc_const_eval/src/const_eval/machine.rs
index 836e548ae2b..310fd462d5f 100644
--- a/compiler/rustc_const_eval/src/const_eval/machine.rs
+++ b/compiler/rustc_const_eval/src/const_eval/machine.rs
@@ -25,9 +25,9 @@ use rustc_target::spec::abi::Abi as CallAbi;
 use crate::errors::{LongRunning, LongRunningWarn};
 use crate::fluent_generated as fluent;
 use crate::interpret::{
-    self, compile_time_machine, err_ub, throw_exhaust, throw_inval, throw_ub_custom,
+    self, compile_time_machine, err_ub, throw_exhaust, throw_inval, throw_ub_custom, throw_unsup,
     throw_unsup_format, AllocId, AllocRange, ConstAllocation, CtfeProvenance, FnArg, FnVal, Frame,
-    ImmTy, InterpCx, InterpResult, MPlaceTy, OpTy, Pointer, PointerArithmetic, Scalar,
+    GlobalAlloc, ImmTy, InterpCx, InterpResult, MPlaceTy, OpTy, Pointer, PointerArithmetic, Scalar,
 };
 
 use super::error::*;
@@ -759,11 +759,21 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for CompileTimeInterpreter<'mir,
         ecx: &InterpCx<'mir, 'tcx, Self>,
         alloc_id: AllocId,
     ) -> InterpResult<'tcx> {
+        // Check if this is the currently evaluated static.
         if Some(alloc_id) == ecx.machine.static_root_ids.map(|(id, _)| id) {
-            Err(ConstEvalErrKind::RecursiveStatic.into())
-        } else {
-            Ok(())
+            return Err(ConstEvalErrKind::RecursiveStatic.into());
         }
+        // If this is another static, make sure we fire off the query to detect cycles.
+        // But only do that when checks for static recursion are enabled.
+        if ecx.machine.static_root_ids.is_some() {
+            if let Some(GlobalAlloc::Static(def_id)) = ecx.tcx.try_get_global_alloc(alloc_id) {
+                if ecx.tcx.is_foreign_item(def_id) {
+                    throw_unsup!(ExternStatic(def_id));
+                }
+                ecx.ctfe_query(|tcx| tcx.eval_static_initializer(def_id))?;
+            }
+        }
+        Ok(())
     }
 }
 
diff --git a/compiler/rustc_const_eval/src/interpret/intrinsics.rs b/compiler/rustc_const_eval/src/interpret/intrinsics.rs
index dce4d56f7e0..72dad562695 100644
--- a/compiler/rustc_const_eval/src/interpret/intrinsics.rs
+++ b/compiler/rustc_const_eval/src/interpret/intrinsics.rs
@@ -255,6 +255,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
                                     name = intrinsic_name,
                                 );
                             }
+                            // This will always return 0.
                             (a, b)
                         }
                         (Err(_), _) | (_, Err(_)) => {
diff --git a/compiler/rustc_const_eval/src/interpret/memory.rs b/compiler/rustc_const_eval/src/interpret/memory.rs
index 350fd480fba..737f2fd8bb9 100644
--- a/compiler/rustc_const_eval/src/interpret/memory.rs
+++ b/compiler/rustc_const_eval/src/interpret/memory.rs
@@ -413,6 +413,8 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
     /// to the allocation it points to. Supports both shared and mutable references, as the actual
     /// checking is offloaded to a helper closure.
     ///
+    /// `alloc_size` will only get called for non-zero-sized accesses.
+    ///
     /// Returns `None` if and only if the size is 0.
     fn check_and_deref_ptr<T>(
         &self,
@@ -425,18 +427,19 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
             M::ProvenanceExtra,
         ) -> InterpResult<'tcx, (Size, Align, T)>,
     ) -> InterpResult<'tcx, Option<T>> {
+        // Everything is okay with size 0.
+        if size.bytes() == 0 {
+            return Ok(None);
+        }
+
         Ok(match self.ptr_try_get_alloc_id(ptr) {
             Err(addr) => {
-                // We couldn't get a proper allocation. This is only okay if the access size is 0,
-                // and the address is not null.
-                if size.bytes() > 0 || addr == 0 {
-                    throw_ub!(DanglingIntPointer(addr, msg));
-                }
-                None
+                // We couldn't get a proper allocation.
+                throw_ub!(DanglingIntPointer(addr, msg));
             }
             Ok((alloc_id, offset, prov)) => {
                 let (alloc_size, _alloc_align, ret_val) = alloc_size(alloc_id, offset, prov)?;
-                // Test bounds. This also ensures non-null.
+                // Test bounds.
                 // It is sufficient to check this for the end pointer. Also check for overflow!
                 if offset.checked_add(size, &self.tcx).map_or(true, |end| end > alloc_size) {
                     throw_ub!(PointerOutOfBounds {
@@ -447,14 +450,8 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
                         msg,
                     })
                 }
-                // Ensure we never consider the null pointer dereferenceable.
-                if M::Provenance::OFFSET_IS_ADDR {
-                    assert_ne!(ptr.addr(), Size::ZERO);
-                }
 
-                // We can still be zero-sized in this branch, in which case we have to
-                // return `None`.
-                if size.bytes() == 0 { None } else { Some(ret_val) }
+                Some(ret_val)
             }
         })
     }
@@ -641,16 +638,18 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
             size,
             CheckInAllocMsg::MemoryAccessTest,
             |alloc_id, offset, prov| {
-                if !self.memory.validation_in_progress.get() {
-                    // We want to call the hook on *all* accesses that involve an AllocId,
-                    // including zero-sized accesses. That means we have to do it here
-                    // rather than below in the `Some` branch.
-                    M::before_alloc_read(self, alloc_id)?;
-                }
                 let alloc = self.get_alloc_raw(alloc_id)?;
                 Ok((alloc.size(), alloc.align, (alloc_id, offset, prov, alloc)))
             },
         )?;
+        // We want to call the hook on *all* accesses that involve an AllocId, including zero-sized
+        // accesses. That means we cannot rely on the closure above or the `Some` branch below. We
+        // do this after `check_and_deref_ptr` to ensure some basic sanity has already been checked.
+        if !self.memory.validation_in_progress.get() {
+            if let Ok((alloc_id, ..)) = self.ptr_try_get_alloc_id(ptr) {
+                M::before_alloc_read(self, alloc_id)?;
+            }
+        }
 
         if let Some((alloc_id, offset, prov, alloc)) = ptr_and_alloc {
             let range = alloc_range(offset, size);
diff --git a/compiler/rustc_const_eval/src/interpret/validity.rs b/compiler/rustc_const_eval/src/interpret/validity.rs
index 2bd4d9dc07a..a47828bb63c 100644
--- a/compiler/rustc_const_eval/src/interpret/validity.rs
+++ b/compiler/rustc_const_eval/src/interpret/validity.rs
@@ -434,6 +434,10 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, '
                 found_bytes: has.bytes()
             },
         );
+        // Make sure this is non-null. (ZST references can be dereferenceable and null.)
+        if self.ecx.scalar_may_be_null(Scalar::from_maybe_pointer(place.ptr(), self.ecx))? {
+            throw_validation_failure!(self.path, NullPtr { ptr_kind })
+        }
         // Do not allow pointers to uninhabited types.
         if place.layout.abi.is_uninhabited() {
             let ty = place.layout.ty;
@@ -456,8 +460,8 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, '
             // `!` is a ZST and we want to validate it.
             if let Ok((alloc_id, _offset, _prov)) = self.ecx.ptr_try_get_alloc_id(place.ptr()) {
                 let mut skip_recursive_check = false;
-                let alloc_actual_mutbl = mutability(self.ecx, alloc_id);
-                if let GlobalAlloc::Static(did) = self.ecx.tcx.global_alloc(alloc_id) {
+                if let Some(GlobalAlloc::Static(did)) = self.ecx.tcx.try_get_global_alloc(alloc_id)
+                {
                     let DefKind::Static { nested, .. } = self.ecx.tcx.def_kind(did) else { bug!() };
                     // Special handling for pointers to statics (irrespective of their type).
                     assert!(!self.ecx.tcx.is_thread_local_static(did));
@@ -495,6 +499,7 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, '
                 // If this allocation has size zero, there is no actual mutability here.
                 let (size, _align, _alloc_kind) = self.ecx.get_alloc_info(alloc_id);
                 if size != Size::ZERO {
+                    let alloc_actual_mutbl = mutability(self.ecx, alloc_id);
                     // Mutable pointer to immutable memory is no good.
                     if ptr_expected_mutbl == Mutability::Mut
                         && alloc_actual_mutbl == Mutability::Not
@@ -831,6 +836,8 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValueVisitor<'mir, 'tcx, M>
         trace!("visit_value: {:?}, {:?}", *op, op.layout);
 
         // Check primitive types -- the leaves of our recursive descent.
+        // We assume that the Scalar validity range does not restrict these values
+        // any further than `try_visit_primitive` does!
         if self.try_visit_primitive(op)? {
             return Ok(());
         }