about summary refs log tree commit diff
path: root/compiler/rustc_abi/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_abi/src')
-rw-r--r--compiler/rustc_abi/src/layout.rs94
-rw-r--r--compiler/rustc_abi/src/lib.rs35
-rw-r--r--compiler/rustc_abi/src/tests.rs63
3 files changed, 158 insertions, 34 deletions
diff --git a/compiler/rustc_abi/src/layout.rs b/compiler/rustc_abi/src/layout.rs
index 80b44e432ee..c2405553756 100644
--- a/compiler/rustc_abi/src/layout.rs
+++ b/compiler/rustc_abi/src/layout.rs
@@ -1,3 +1,4 @@
+use std::collections::BTreeSet;
 use std::fmt::{self, Write};
 use std::ops::{Bound, Deref};
 use std::{cmp, iter};
@@ -5,7 +6,7 @@ use std::{cmp, iter};
 use rustc_hashes::Hash64;
 use rustc_index::Idx;
 use rustc_index::bit_set::BitMatrix;
-use tracing::debug;
+use tracing::{debug, trace};
 
 use crate::{
     AbiAlign, Align, BackendRepr, FieldsShape, HasDataLayout, IndexSlice, IndexVec, Integer,
@@ -313,7 +314,6 @@ impl<Cx: HasDataLayout> LayoutCalculator<Cx> {
         scalar_valid_range: (Bound<u128>, Bound<u128>),
         discr_range_of_repr: impl Fn(i128, i128) -> (Integer, bool),
         discriminants: impl Iterator<Item = (VariantIdx, i128)>,
-        dont_niche_optimize_enum: bool,
         always_sized: bool,
     ) -> LayoutCalculatorResult<FieldIdx, VariantIdx, F> {
         let (present_first, present_second) = {
@@ -352,13 +352,7 @@ impl<Cx: HasDataLayout> LayoutCalculator<Cx> {
             // structs. (We have also handled univariant enums
             // that allow representation optimization.)
             assert!(is_enum);
-            self.layout_of_enum(
-                repr,
-                variants,
-                discr_range_of_repr,
-                discriminants,
-                dont_niche_optimize_enum,
-            )
+            self.layout_of_enum(repr, variants, discr_range_of_repr, discriminants)
         }
     }
 
@@ -599,7 +593,6 @@ impl<Cx: HasDataLayout> LayoutCalculator<Cx> {
         variants: &IndexSlice<VariantIdx, IndexVec<FieldIdx, F>>,
         discr_range_of_repr: impl Fn(i128, i128) -> (Integer, bool),
         discriminants: impl Iterator<Item = (VariantIdx, i128)>,
-        dont_niche_optimize_enum: bool,
     ) -> LayoutCalculatorResult<FieldIdx, VariantIdx, F> {
         // Until we've decided whether to use the tagged or
         // niche filling LayoutData, we don't want to intern the
@@ -618,7 +611,7 @@ impl<Cx: HasDataLayout> LayoutCalculator<Cx> {
         }
 
         let calculate_niche_filling_layout = || -> Option<TmpLayout<FieldIdx, VariantIdx>> {
-            if dont_niche_optimize_enum {
+            if repr.inhibit_enum_layout_opt() {
                 return None;
             }
 
@@ -774,30 +767,63 @@ impl<Cx: HasDataLayout> LayoutCalculator<Cx> {
 
         let niche_filling_layout = calculate_niche_filling_layout();
 
-        let (mut min, mut max) = (i128::MAX, i128::MIN);
         let discr_type = repr.discr_type();
-        let bits = Integer::from_attr(dl, discr_type).size().bits();
-        for (i, mut val) in discriminants {
-            if !repr.c() && variants[i].iter().any(|f| f.is_uninhabited()) {
-                continue;
-            }
-            if discr_type.is_signed() {
-                // sign extend the raw representation to be an i128
-                val = (val << (128 - bits)) >> (128 - bits);
-            }
-            if val < min {
-                min = val;
-            }
-            if val > max {
-                max = val;
-            }
-        }
-        // We might have no inhabited variants, so pretend there's at least one.
-        if (min, max) == (i128::MAX, i128::MIN) {
-            min = 0;
-            max = 0;
-        }
-        assert!(min <= max, "discriminant range is {min}...{max}");
+        let discr_int = Integer::from_attr(dl, discr_type);
+        // Because we can only represent one range of valid values, we'll look for the
+        // largest range of invalid values and pick everything else as the range of valid
+        // values.
+
+        // First we need to sort the possible discriminant values so that we can look for the largest gap:
+        let valid_discriminants: BTreeSet<i128> = discriminants
+            .filter(|&(i, _)| repr.c() || variants[i].iter().all(|f| !f.is_uninhabited()))
+            .map(|(_, val)| {
+                if discr_type.is_signed() {
+                    // sign extend the raw representation to be an i128
+                    // FIXME: do this at the discriminant iterator creation sites
+                    discr_int.size().sign_extend(val as u128)
+                } else {
+                    val
+                }
+            })
+            .collect();
+        trace!(?valid_discriminants);
+        let discriminants = valid_discriminants.iter().copied();
+        //let next_discriminants = discriminants.clone().cycle().skip(1);
+        let next_discriminants =
+            discriminants.clone().chain(valid_discriminants.first().copied()).skip(1);
+        // Iterate over pairs of each discriminant together with the next one.
+        // Since they were sorted, we can now compute the niche sizes and pick the largest.
+        let discriminants = discriminants.zip(next_discriminants);
+        let largest_niche = discriminants.max_by_key(|&(start, end)| {
+            trace!(?start, ?end);
+            // If this is a wraparound range, the niche size is `MAX - abs(diff)`, as the diff between
+            // the two end points is actually the size of the valid discriminants.
+            let dist = if start > end {
+                // Overflow can happen for 128 bit discriminants if `end` is negative.
+                // But in that case casting to `u128` still gets us the right value,
+                // as the distance must be positive if the lhs of the subtraction is larger than the rhs.
+                let dist = start.wrapping_sub(end);
+                if discr_type.is_signed() {
+                    discr_int.signed_max().wrapping_sub(dist) as u128
+                } else {
+                    discr_int.size().unsigned_int_max() - dist as u128
+                }
+            } else {
+                // Overflow can happen for 128 bit discriminants if `start` is negative.
+                // But in that case casting to `u128` still gets us the right value,
+                // as the distance must be positive if the lhs of the subtraction is larger than the rhs.
+                end.wrapping_sub(start) as u128
+            };
+            trace!(?dist);
+            dist
+        });
+        trace!(?largest_niche);
+
+        // `max` is the last valid discriminant before the largest niche
+        // `min` is the first valid discriminant after the largest niche
+        let (max, min) = largest_niche
+            // We might have no inhabited variants, so pretend there's at least one.
+            .unwrap_or((0, 0));
         let (min_ity, signed) = discr_range_of_repr(min, max); //Integer::repr_discr(tcx, ty, &repr, min, max);
 
         let mut align = dl.aggregate_align;
diff --git a/compiler/rustc_abi/src/lib.rs b/compiler/rustc_abi/src/lib.rs
index 5bd73502d98..14e256b8045 100644
--- a/compiler/rustc_abi/src/lib.rs
+++ b/compiler/rustc_abi/src/lib.rs
@@ -1205,6 +1205,19 @@ impl Integer {
         }
     }
 
+    /// Returns the smallest signed value that can be represented by this Integer.
+    #[inline]
+    pub fn signed_min(self) -> i128 {
+        use Integer::*;
+        match self {
+            I8 => i8::MIN as i128,
+            I16 => i16::MIN as i128,
+            I32 => i32::MIN as i128,
+            I64 => i64::MIN as i128,
+            I128 => i128::MIN,
+        }
+    }
+
     /// Finds the smallest Integer type which can represent the signed value.
     #[inline]
     pub fn fit_signed(x: i128) -> Integer {
@@ -1376,6 +1389,28 @@ impl WrappingRange {
         }
     }
 
+    /// Returns `true` if all the values in `other` are contained in this range,
+    /// when the values are considered as having width `size`.
+    #[inline(always)]
+    pub fn contains_range(&self, other: Self, size: Size) -> bool {
+        if self.is_full_for(size) {
+            true
+        } else {
+            let trunc = |x| size.truncate(x);
+
+            let delta = self.start;
+            let max = trunc(self.end.wrapping_sub(delta));
+
+            let other_start = trunc(other.start.wrapping_sub(delta));
+            let other_end = trunc(other.end.wrapping_sub(delta));
+
+            // Having shifted both input ranges by `delta`, now we only need to check
+            // whether `0..=max` contains `other_start..=other_end`, which can only
+            // happen if the other doesn't wrap since `self` isn't everything.
+            (other_start <= other_end) && (other_end <= max)
+        }
+    }
+
     /// Returns `self` with replaced `start`
     #[inline(always)]
     fn with_start(mut self, start: u128) -> Self {
diff --git a/compiler/rustc_abi/src/tests.rs b/compiler/rustc_abi/src/tests.rs
index d993012378c..d49c2d44af8 100644
--- a/compiler/rustc_abi/src/tests.rs
+++ b/compiler/rustc_abi/src/tests.rs
@@ -5,3 +5,66 @@ fn align_constants() {
     assert_eq!(Align::ONE, Align::from_bytes(1).unwrap());
     assert_eq!(Align::EIGHT, Align::from_bytes(8).unwrap());
 }
+
+#[test]
+fn wrapping_range_contains_range() {
+    let size16 = Size::from_bytes(16);
+
+    let a = WrappingRange { start: 10, end: 20 };
+    assert!(a.contains_range(a, size16));
+    assert!(a.contains_range(WrappingRange { start: 11, end: 19 }, size16));
+    assert!(a.contains_range(WrappingRange { start: 10, end: 10 }, size16));
+    assert!(a.contains_range(WrappingRange { start: 20, end: 20 }, size16));
+    assert!(!a.contains_range(WrappingRange { start: 10, end: 21 }, size16));
+    assert!(!a.contains_range(WrappingRange { start: 9, end: 20 }, size16));
+    assert!(!a.contains_range(WrappingRange { start: 4, end: 6 }, size16));
+    assert!(!a.contains_range(WrappingRange { start: 24, end: 26 }, size16));
+
+    assert!(!a.contains_range(WrappingRange { start: 16, end: 14 }, size16));
+
+    let b = WrappingRange { start: 20, end: 10 };
+    assert!(b.contains_range(b, size16));
+    assert!(b.contains_range(WrappingRange { start: 20, end: 20 }, size16));
+    assert!(b.contains_range(WrappingRange { start: 10, end: 10 }, size16));
+    assert!(b.contains_range(WrappingRange { start: 0, end: 10 }, size16));
+    assert!(b.contains_range(WrappingRange { start: 20, end: 30 }, size16));
+    assert!(b.contains_range(WrappingRange { start: 20, end: 9 }, size16));
+    assert!(b.contains_range(WrappingRange { start: 21, end: 10 }, size16));
+    assert!(b.contains_range(WrappingRange { start: 999, end: 9999 }, size16));
+    assert!(b.contains_range(WrappingRange { start: 999, end: 9 }, size16));
+    assert!(!b.contains_range(WrappingRange { start: 19, end: 19 }, size16));
+    assert!(!b.contains_range(WrappingRange { start: 11, end: 11 }, size16));
+    assert!(!b.contains_range(WrappingRange { start: 19, end: 11 }, size16));
+    assert!(!b.contains_range(WrappingRange { start: 11, end: 19 }, size16));
+
+    let f = WrappingRange { start: 0, end: u128::MAX };
+    assert!(f.contains_range(WrappingRange { start: 10, end: 20 }, size16));
+    assert!(f.contains_range(WrappingRange { start: 20, end: 10 }, size16));
+
+    let g = WrappingRange { start: 2, end: 1 };
+    assert!(g.contains_range(WrappingRange { start: 10, end: 20 }, size16));
+    assert!(g.contains_range(WrappingRange { start: 20, end: 10 }, size16));
+
+    let size1 = Size::from_bytes(1);
+    let u8r = WrappingRange { start: 0, end: 255 };
+    let i8r = WrappingRange { start: 128, end: 127 };
+    assert!(u8r.contains_range(i8r, size1));
+    assert!(i8r.contains_range(u8r, size1));
+    assert!(!u8r.contains_range(i8r, size16));
+    assert!(i8r.contains_range(u8r, size16));
+
+    let boolr = WrappingRange { start: 0, end: 1 };
+    assert!(u8r.contains_range(boolr, size1));
+    assert!(i8r.contains_range(boolr, size1));
+    assert!(!boolr.contains_range(u8r, size1));
+    assert!(!boolr.contains_range(i8r, size1));
+
+    let cmpr = WrappingRange { start: 255, end: 1 };
+    assert!(u8r.contains_range(cmpr, size1));
+    assert!(i8r.contains_range(cmpr, size1));
+    assert!(!cmpr.contains_range(u8r, size1));
+    assert!(!cmpr.contains_range(i8r, size1));
+
+    assert!(!boolr.contains_range(cmpr, size1));
+    assert!(cmpr.contains_range(boolr, size1));
+}