about summary refs log tree commit diff
path: root/compiler/rustc_const_eval/src/interpret/validity.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_const_eval/src/interpret/validity.rs')
-rw-r--r--compiler/rustc_const_eval/src/interpret/validity.rs85
1 files changed, 39 insertions, 46 deletions
diff --git a/compiler/rustc_const_eval/src/interpret/validity.rs b/compiler/rustc_const_eval/src/interpret/validity.rs
index 5360c51c48d..792e1c9e736 100644
--- a/compiler/rustc_const_eval/src/interpret/validity.rs
+++ b/compiler/rustc_const_eval/src/interpret/validity.rs
@@ -445,22 +445,22 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, '
             // Determine whether this pointer expects to be pointing to something mutable.
             let ptr_expected_mutbl = match ptr_kind {
                 PointerKind::Box => Mutability::Mut,
-                PointerKind::Ref => {
-                    let tam = value.layout.ty.builtin_deref(false).unwrap();
-                    // ZST never require mutability. We do not take into account interior mutability
-                    // here since we cannot know if there really is an `UnsafeCell` inside
-                    // `Option<UnsafeCell>` -- so we check that in the recursive descent behind this
-                    // reference.
-                    if size == Size::ZERO { Mutability::Not } else { tam.mutbl }
+                PointerKind::Ref(mutbl) => {
+                    // We do not take into account interior mutability here since we cannot know if
+                    // there really is an `UnsafeCell` inside `Option<UnsafeCell>` -- so we check
+                    // that in the recursive descent behind this reference (controlled by
+                    // `allow_immutable_unsafe_cell`).
+                    mutbl
                 }
             };
             // Proceed recursively even for ZST, no reason to skip them!
             // `!` is a ZST and we want to validate it.
             if let Ok((alloc_id, _offset, _prov)) = self.ecx.ptr_try_get_alloc_id(place.ptr()) {
+                let mut skip_recursive_check = false;
                 // Let's see what kind of memory this points to.
                 // `unwrap` since dangling pointers have already been handled.
                 let alloc_kind = self.ecx.tcx.try_get_global_alloc(alloc_id).unwrap();
-                match alloc_kind {
+                let alloc_actual_mutbl = match alloc_kind {
                     GlobalAlloc::Static(did) => {
                         // Special handling for pointers to statics (irrespective of their type).
                         assert!(!self.ecx.tcx.is_thread_local_static(did));
@@ -474,12 +474,6 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, '
                                     .no_bound_vars()
                                     .expect("statics should not have generic parameters")
                                     .is_freeze(*self.ecx.tcx, ty::ParamEnv::reveal_all());
-                        // Mutability check.
-                        if ptr_expected_mutbl == Mutability::Mut {
-                            if !is_mut {
-                                throw_validation_failure!(self.path, MutableRefToImmutable);
-                            }
-                        }
                         // Mode-specific checks
                         match self.ctfe_mode {
                             Some(
@@ -494,15 +488,9 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, '
                                 // trigger cycle errors if we try to compute the value of the other static
                                 // and that static refers back to us (potentially through a promoted).
                                 // This could miss some UB, but that's fine.
-                                return Ok(());
+                                skip_recursive_check = true;
                             }
                             Some(CtfeValidationMode::Const { .. }) => {
-                                // For consts on the other hand we have to recursively check;
-                                // pattern matching assumes a valid value. However we better make
-                                // sure this is not mutable.
-                                if is_mut {
-                                    throw_validation_failure!(self.path, ConstRefToMutable);
-                                }
                                 // We can't recursively validate `extern static`, so we better reject them.
                                 if self.ecx.tcx.is_foreign_item(did) {
                                     throw_validation_failure!(self.path, ConstRefToExtern);
@@ -510,26 +498,39 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, '
                             }
                             None => {}
                         }
+                        // Return alloc mutability
+                        if is_mut { Mutability::Mut } else { Mutability::Not }
                     }
-                    GlobalAlloc::Memory(alloc) => {
-                        if alloc.inner().mutability == Mutability::Mut
-                            && matches!(self.ctfe_mode, Some(CtfeValidationMode::Const { .. }))
-                        {
-                            throw_validation_failure!(self.path, ConstRefToMutable);
-                        }
-                        if ptr_expected_mutbl == Mutability::Mut
-                            && alloc.inner().mutability == Mutability::Not
-                        {
-                            throw_validation_failure!(self.path, MutableRefToImmutable);
-                        }
-                    }
+                    GlobalAlloc::Memory(alloc) => alloc.inner().mutability,
                     GlobalAlloc::Function(..) | GlobalAlloc::VTable(..) => {
                         // These are immutable, we better don't allow mutable pointers here.
-                        if ptr_expected_mutbl == Mutability::Mut {
-                            throw_validation_failure!(self.path, MutableRefToImmutable);
-                        }
+                        Mutability::Not
+                    }
+                };
+                // Mutability check.
+                // If this allocation has size zero, there is no actual mutability here.
+                let (size, _align, _alloc_kind) = self.ecx.get_alloc_info(alloc_id);
+                if size != Size::ZERO {
+                    if ptr_expected_mutbl == Mutability::Mut
+                        && alloc_actual_mutbl == Mutability::Not
+                    {
+                        throw_validation_failure!(self.path, MutableRefToImmutable);
+                    }
+                    if ptr_expected_mutbl == Mutability::Mut
+                        && self.ctfe_mode.is_some_and(|c| !c.may_contain_mutable_ref())
+                    {
+                        throw_validation_failure!(self.path, MutableRefInConstOrStatic);
+                    }
+                    if alloc_actual_mutbl == Mutability::Mut
+                        && matches!(self.ctfe_mode, Some(CtfeValidationMode::Const { .. }))
+                    {
+                        throw_validation_failure!(self.path, ConstRefToMutable);
                     }
                 }
+                // Potentially skip recursive check.
+                if skip_recursive_check {
+                    return Ok(());
+                }
             }
             let path = &self.path;
             ref_tracking.track(place, || {
@@ -598,16 +599,8 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, '
                 }
                 Ok(true)
             }
-            ty::Ref(_, ty, mutbl) => {
-                if self.ctfe_mode.is_some_and(|c| !c.may_contain_mutable_ref())
-                    && *mutbl == Mutability::Mut
-                {
-                    let layout = self.ecx.layout_of(*ty)?;
-                    if !layout.is_zst() {
-                        throw_validation_failure!(self.path, MutableRefInConst);
-                    }
-                }
-                self.check_safe_pointer(value, PointerKind::Ref)?;
+            ty::Ref(_, _ty, mutbl) => {
+                self.check_safe_pointer(value, PointerKind::Ref(*mutbl))?;
                 Ok(true)
             }
             ty::FnPtr(_sig) => {