about summary refs log tree commit diff
path: root/compiler/rustc_const_eval/src
diff options
context:
space:
mode:
authorMoulins <arthur.heuillard@orange.fr>2023-06-25 19:32:40 +0200
committerMoulins <arthur.heuillard@orange.fr>2023-07-21 03:31:45 +0200
commit76c49aead6d49a993c4b2e59cceaf7d8d3324944 (patch)
tree0061384d8ed50ca481236be892a9982ebd8463e6 /compiler/rustc_const_eval/src
parent3c0527686644cb30291a03bdecdcfddb396765ab (diff)
downloadrust-76c49aead6d49a993c4b2e59cceaf7d8d3324944.tar.gz
rust-76c49aead6d49a993c4b2e59cceaf7d8d3324944.zip
support non-null pointer niches in CTFE
Diffstat (limited to 'compiler/rustc_const_eval/src')
-rw-r--r--compiler/rustc_const_eval/src/const_eval/machine.rs2
-rw-r--r--compiler/rustc_const_eval/src/errors.rs5
-rw-r--r--compiler/rustc_const_eval/src/interpret/discriminant.rs24
-rw-r--r--compiler/rustc_const_eval/src/interpret/memory.rs41
-rw-r--r--compiler/rustc_const_eval/src/interpret/validity.rs54
5 files changed, 63 insertions, 63 deletions
diff --git a/compiler/rustc_const_eval/src/const_eval/machine.rs b/compiler/rustc_const_eval/src/const_eval/machine.rs
index 267795a6cb4..51012da6b90 100644
--- a/compiler/rustc_const_eval/src/const_eval/machine.rs
+++ b/compiler/rustc_const_eval/src/const_eval/machine.rs
@@ -333,7 +333,7 @@ impl<'mir, 'tcx: 'mir> CompileTimeEvalContext<'mir, 'tcx> {
             // Inequality with integers other than null can never be known for sure.
             (Scalar::Int(int), ptr @ Scalar::Ptr(..))
             | (ptr @ Scalar::Ptr(..), Scalar::Int(int))
-                if int.is_null() && !self.scalar_may_be_null(ptr)? =>
+                if int.is_null() && !self.ptr_scalar_range(ptr)?.contains(&0) =>
             {
                 0
             }
diff --git a/compiler/rustc_const_eval/src/errors.rs b/compiler/rustc_const_eval/src/errors.rs
index ca38cce710e..61ce695ccd2 100644
--- a/compiler/rustc_const_eval/src/errors.rs
+++ b/compiler/rustc_const_eval/src/errors.rs
@@ -617,7 +617,6 @@ impl<'tcx> ReportErrorExt for ValidationErrorInfo<'tcx> {
             MutableRefInConst => const_eval_mutable_ref_in_const,
             NullFnPtr => const_eval_null_fn_ptr,
             NeverVal => const_eval_never_val,
-            NullablePtrOutOfRange { .. } => const_eval_nullable_ptr_out_of_range,
             PtrOutOfRange { .. } => const_eval_ptr_out_of_range,
             OutOfRange { .. } => const_eval_out_of_range,
             UnsafeCell => const_eval_unsafe_cell,
@@ -732,9 +731,7 @@ impl<'tcx> ReportErrorExt for ValidationErrorInfo<'tcx> {
             | InvalidFnPtr { value } => {
                 err.set_arg("value", value);
             }
-            NullablePtrOutOfRange { range, max_value } | PtrOutOfRange { range, max_value } => {
-                add_range_arg(range, max_value, handler, err)
-            }
+            PtrOutOfRange { range, max_value } => add_range_arg(range, max_value, handler, err),
             OutOfRange { range, max_value, value } => {
                 err.set_arg("value", value);
                 add_range_arg(range, max_value, handler, err);
diff --git a/compiler/rustc_const_eval/src/interpret/discriminant.rs b/compiler/rustc_const_eval/src/interpret/discriminant.rs
index f23a455c2ca..99ea0ab18bc 100644
--- a/compiler/rustc_const_eval/src/interpret/discriminant.rs
+++ b/compiler/rustc_const_eval/src/interpret/discriminant.rs
@@ -2,8 +2,7 @@
 
 use rustc_middle::ty::layout::{LayoutOf, PrimitiveExt};
 use rustc_middle::{mir, ty};
-use rustc_target::abi::{self, TagEncoding};
-use rustc_target::abi::{VariantIdx, Variants};
+use rustc_target::abi::{self, TagEncoding, VariantIdx, Variants, WrappingRange};
 
 use super::{ImmTy, InterpCx, InterpResult, Machine, OpTy, PlaceTy, Scalar};
 
@@ -180,19 +179,24 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
                 // discriminant (encoded in niche/tag) and variant index are the same.
                 let variants_start = niche_variants.start().as_u32();
                 let variants_end = niche_variants.end().as_u32();
+                let variants_len = u128::from(variants_end - variants_start);
                 let variant = match tag_val.try_to_int() {
                     Err(dbg_val) => {
                         // So this is a pointer then, and casting to an int failed.
                         // Can only happen during CTFE.
-                        // The niche must be just 0, and the ptr not null, then we know this is
-                        // okay. Everything else, we conservatively reject.
-                        let ptr_valid = niche_start == 0
-                            && variants_start == variants_end
-                            && !self.scalar_may_be_null(tag_val)?;
-                        if !ptr_valid {
+                        // The pointer and niches ranges must be disjoint, then we know
+                        // this is the untagged variant (as the value is not in the niche).
+                        // Everything else, we conservatively reject.
+                        let range = self.ptr_scalar_range(tag_val)?;
+                        let niches = WrappingRange {
+                            start: niche_start,
+                            end: niche_start.wrapping_add(variants_len),
+                        };
+                        if niches.overlaps_range(range) {
                             throw_ub!(InvalidTag(dbg_val))
+                        } else {
+                            untagged_variant
                         }
-                        untagged_variant
                     }
                     Ok(tag_bits) => {
                         let tag_bits = tag_bits.assert_bits(tag_layout.size);
@@ -205,7 +209,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
                         let variant_index_relative =
                             variant_index_relative_val.to_scalar().assert_bits(tag_val.layout.size);
                         // Check if this is in the range that indicates an actual discriminant.
-                        if variant_index_relative <= u128::from(variants_end - variants_start) {
+                        if variant_index_relative <= variants_len {
                             let variant_index_relative = u32::try_from(variant_index_relative)
                                 .expect("we checked that this fits into a u32");
                             // Then computing the absolute variant idx should not overflow any more.
diff --git a/compiler/rustc_const_eval/src/interpret/memory.rs b/compiler/rustc_const_eval/src/interpret/memory.rs
index 7b44a20ef03..10a2a70364b 100644
--- a/compiler/rustc_const_eval/src/interpret/memory.rs
+++ b/compiler/rustc_const_eval/src/interpret/memory.rs
@@ -10,6 +10,7 @@ use std::assert_matches::assert_matches;
 use std::borrow::Cow;
 use std::collections::VecDeque;
 use std::fmt;
+use std::ops::RangeInclusive;
 use std::ptr;
 
 use rustc_ast::Mutability;
@@ -1222,24 +1223,34 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
 
 /// Machine pointer introspection.
 impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
-    /// Test if this value might be null.
+    /// Turn a pointer-sized scalar into a (non-empty) range of possible values.
     /// If the machine does not support ptr-to-int casts, this is conservative.
-    pub fn scalar_may_be_null(&self, scalar: Scalar<M::Provenance>) -> InterpResult<'tcx, bool> {
-        Ok(match scalar.try_to_int() {
-            Ok(int) => int.is_null(),
-            Err(_) => {
-                // Can only happen during CTFE.
-                let ptr = scalar.to_pointer(self)?;
-                match self.ptr_try_get_alloc_id(ptr) {
-                    Ok((alloc_id, offset, _)) => {
-                        let (size, _align, _kind) = self.get_alloc_info(alloc_id);
-                        // If the pointer is out-of-bounds, it may be null.
-                        // Note that one-past-the-end (offset == size) is still inbounds, and never null.
-                        offset > size
-                    }
-                    Err(_offset) => bug!("a non-int scalar is always a pointer"),
+    pub fn ptr_scalar_range(
+        &self,
+        scalar: Scalar<M::Provenance>,
+    ) -> InterpResult<'tcx, RangeInclusive<u64>> {
+        if let Ok(int) = scalar.to_target_usize(self) {
+            return Ok(int..=int);
+        }
+
+        let ptr = scalar.to_pointer(self)?;
+
+        // Can only happen during CTFE.
+        Ok(match self.ptr_try_get_alloc_id(ptr) {
+            Ok((alloc_id, offset, _)) => {
+                let offset = offset.bytes();
+                let (size, align, _) = self.get_alloc_info(alloc_id);
+                let dl = self.data_layout();
+                if offset > size.bytes() {
+                    // If the pointer is out-of-bounds, we do not have a
+                    // meaningful range to return.
+                    0..=dl.max_address()
+                } else {
+                    let (min, max) = dl.address_range_for(size, align);
+                    (min + offset)..=(max + offset)
                 }
             }
+            Err(_offset) => bug!("a non-int scalar is always a pointer"),
         })
     }
 
diff --git a/compiler/rustc_const_eval/src/interpret/validity.rs b/compiler/rustc_const_eval/src/interpret/validity.rs
index 21c655988a0..108394d224b 100644
--- a/compiler/rustc_const_eval/src/interpret/validity.rs
+++ b/compiler/rustc_const_eval/src/interpret/validity.rs
@@ -19,9 +19,7 @@ use rustc_middle::mir::interpret::{
 use rustc_middle::ty;
 use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
 use rustc_span::symbol::{sym, Symbol};
-use rustc_target::abi::{
-    Abi, FieldIdx, Scalar as ScalarAbi, Size, VariantIdx, Variants, WrappingRange,
-};
+use rustc_target::abi::{Abi, FieldIdx, Scalar as ScalarAbi, Size, VariantIdx, Variants};
 
 use std::hash::Hash;
 
@@ -554,7 +552,7 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, '
                     // FIXME: Check if the signature matches
                 } else {
                     // Otherwise (for standalone Miri), we have to still check it to be non-null.
-                    if self.ecx.scalar_may_be_null(value)? {
+                    if self.ecx.ptr_scalar_range(value)?.contains(&0) {
                         throw_validation_failure!(self.path, NullFnPtr);
                     }
                 }
@@ -595,46 +593,36 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, '
     ) -> InterpResult<'tcx> {
         let size = scalar_layout.size(self.ecx);
         let valid_range = scalar_layout.valid_range(self.ecx);
-        let WrappingRange { start, end } = valid_range;
         let max_value = size.unsigned_int_max();
-        assert!(end <= max_value);
-        let bits = match scalar.try_to_int() {
-            Ok(int) => int.assert_bits(size),
+        assert!(valid_range.end <= max_value);
+        match scalar.try_to_int() {
+            Ok(int) => {
+                // We have an explicit int: check it against the valid range.
+                let bits = int.assert_bits(size);
+                if valid_range.contains(bits) {
+                    Ok(())
+                } else {
+                    throw_validation_failure!(
+                        self.path,
+                        OutOfRange { value: format!("{bits}"), range: valid_range, max_value }
+                    )
+                }
+            }
             Err(_) => {
                 // So this is a pointer then, and casting to an int failed.
                 // Can only happen during CTFE.
-                // We support 2 kinds of ranges here: full range, and excluding zero.
-                if start == 1 && end == max_value {
-                    // Only null is the niche. So make sure the ptr is NOT null.
-                    if self.ecx.scalar_may_be_null(scalar)? {
-                        throw_validation_failure!(
-                            self.path,
-                            NullablePtrOutOfRange { range: valid_range, max_value }
-                        )
-                    } else {
-                        return Ok(());
-                    }
-                } else if scalar_layout.is_always_valid(self.ecx) {
-                    // Easy. (This is reachable if `enforce_number_validity` is set.)
-                    return Ok(());
+                // We check if the possible addresses are compatible with the valid range.
+                let range = self.ecx.ptr_scalar_range(scalar)?;
+                if valid_range.contains_range(range) {
+                    Ok(())
                 } else {
-                    // Conservatively, we reject, because the pointer *could* have a bad
-                    // value.
+                    // Reject conservatively, because the pointer *could* have a bad value.
                     throw_validation_failure!(
                         self.path,
                         PtrOutOfRange { range: valid_range, max_value }
                     )
                 }
             }
-        };
-        // Now compare.
-        if valid_range.contains(bits) {
-            Ok(())
-        } else {
-            throw_validation_failure!(
-                self.path,
-                OutOfRange { value: format!("{bits}"), range: valid_range, max_value }
-            )
         }
     }
 }