about summary refs log tree commit diff
path: root/compiler/rustc_const_eval/src/interpret/validity.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_const_eval/src/interpret/validity.rs')
-rw-r--r--compiler/rustc_const_eval/src/interpret/validity.rs54
1 files changed, 33 insertions, 21 deletions
diff --git a/compiler/rustc_const_eval/src/interpret/validity.rs b/compiler/rustc_const_eval/src/interpret/validity.rs
index 108394d224b..21c655988a0 100644
--- a/compiler/rustc_const_eval/src/interpret/validity.rs
+++ b/compiler/rustc_const_eval/src/interpret/validity.rs
@@ -19,7 +19,9 @@ use rustc_middle::mir::interpret::{
 use rustc_middle::ty;
 use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
 use rustc_span::symbol::{sym, Symbol};
-use rustc_target::abi::{Abi, FieldIdx, Scalar as ScalarAbi, Size, VariantIdx, Variants};
+use rustc_target::abi::{
+    Abi, FieldIdx, Scalar as ScalarAbi, Size, VariantIdx, Variants, WrappingRange,
+};
 
 use std::hash::Hash;
 
@@ -552,7 +554,7 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, '
                     // FIXME: Check if the signature matches
                 } else {
                     // Otherwise (for standalone Miri), we have to still check it to be non-null.
-                    if self.ecx.ptr_scalar_range(value)?.contains(&0) {
+                    if self.ecx.scalar_may_be_null(value)? {
                         throw_validation_failure!(self.path, NullFnPtr);
                     }
                 }
@@ -593,36 +595,46 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, '
     ) -> InterpResult<'tcx> {
         let size = scalar_layout.size(self.ecx);
         let valid_range = scalar_layout.valid_range(self.ecx);
+        let WrappingRange { start, end } = valid_range;
         let max_value = size.unsigned_int_max();
-        assert!(valid_range.end <= max_value);
-        match scalar.try_to_int() {
-            Ok(int) => {
-                // We have an explicit int: check it against the valid range.
-                let bits = int.assert_bits(size);
-                if valid_range.contains(bits) {
-                    Ok(())
-                } else {
-                    throw_validation_failure!(
-                        self.path,
-                        OutOfRange { value: format!("{bits}"), range: valid_range, max_value }
-                    )
-                }
-            }
+        assert!(end <= max_value);
+        let bits = match scalar.try_to_int() {
+            Ok(int) => int.assert_bits(size),
             Err(_) => {
                 // So this is a pointer then, and casting to an int failed.
                 // Can only happen during CTFE.
-                // We check if the possible addresses are compatible with the valid range.
-                let range = self.ecx.ptr_scalar_range(scalar)?;
-                if valid_range.contains_range(range) {
-                    Ok(())
+                // We support 2 kinds of ranges here: full range, and excluding zero.
+                if start == 1 && end == max_value {
+                    // Only null is the niche. So make sure the ptr is NOT null.
+                    if self.ecx.scalar_may_be_null(scalar)? {
+                        throw_validation_failure!(
+                            self.path,
+                            NullablePtrOutOfRange { range: valid_range, max_value }
+                        )
+                    } else {
+                        return Ok(());
+                    }
+                } else if scalar_layout.is_always_valid(self.ecx) {
+                    // Easy. (This is reachable if `enforce_number_validity` is set.)
+                    return Ok(());
                 } else {
-                    // Reject conservatively, because the pointer *could* have a bad value.
+                    // Conservatively, we reject, because the pointer *could* have a bad
+                    // value.
                     throw_validation_failure!(
                         self.path,
                         PtrOutOfRange { range: valid_range, max_value }
                     )
                 }
             }
+        };
+        // Now compare.
+        if valid_range.contains(bits) {
+            Ok(())
+        } else {
+            throw_validation_failure!(
+                self.path,
+                OutOfRange { value: format!("{bits}"), range: valid_range, max_value }
+            )
         }
     }
 }