about summary refs log tree commit diff
diff options
context:
space:
mode:
authorCelina G. Val <celinval@amazon.com>2024-02-29 15:23:44 -0800
committerCelina G. Val <celinval@amazon.com>2024-03-01 11:02:05 -0800
commite3ac2c68b89c95e01440c1ab5fc7897e93a35c41 (patch)
treebf2008aeee1de7b4df1687b4aba12b651641294e
parent6db96de66c2c0ea3f4f2f348ed1a83c2c507687d (diff)
downloadrust-e3ac2c68b89c95e01440c1ab5fc7897e93a35c41.tar.gz
rust-e3ac2c68b89c95e01440c1ab5fc7897e93a35c41.zip
Implement missing ABI structures in StableMIR
-rw-r--r--compiler/rustc_smir/src/rustc_smir/convert/abi.rs66
-rw-r--r--compiler/stable_mir/src/abi.rs159
-rw-r--r--compiler/stable_mir/src/error.rs11
-rw-r--r--compiler/stable_mir/src/target.rs10
-rw-r--r--compiler/stable_mir/src/ty.rs7
-rw-r--r--tests/ui-fulldeps/stable-mir/check_abi.rs45
6 files changed, 273 insertions, 25 deletions
diff --git a/compiler/rustc_smir/src/rustc_smir/convert/abi.rs b/compiler/rustc_smir/src/rustc_smir/convert/abi.rs
index 088a836c901..071c02e0381 100644
--- a/compiler/rustc_smir/src/rustc_smir/convert/abi.rs
+++ b/compiler/rustc_smir/src/rustc_smir/convert/abi.rs
@@ -6,11 +6,12 @@ use crate::rustc_smir::{Stable, Tables};
 use rustc_middle::ty;
 use rustc_target::abi::call::Conv;
 use stable_mir::abi::{
-    ArgAbi, CallConvention, FieldsShape, FnAbi, Layout, LayoutShape, PassMode, TagEncoding,
-    TyAndLayout, ValueAbi, VariantsShape,
+    AddressSpace, ArgAbi, CallConvention, FieldsShape, FnAbi, IntegerLength, Layout, LayoutShape,
+    PassMode, Primitive, Scalar, TagEncoding, TyAndLayout, ValueAbi, VariantsShape, WrappingRange,
 };
-use stable_mir::ty::{Align, IndexedVal, Size, VariantIdx};
-use stable_mir::{opaque, Opaque};
+use stable_mir::opaque;
+use stable_mir::target::MachineSize as Size;
+use stable_mir::ty::{Align, IndexedVal, VariantIdx};
 
 impl<'tcx> Stable<'tcx> for rustc_target::abi::VariantIdx {
     type T = VariantIdx;
@@ -220,7 +221,7 @@ impl<'tcx> Stable<'tcx> for rustc_abi::Size {
     type T = Size;
 
     fn stable(&self, _tables: &mut Tables<'_>) -> Self::T {
-        self.bytes_usize()
+        Size::from_bits(self.bits_usize())
     }
 }
 
@@ -233,9 +234,60 @@ impl<'tcx> Stable<'tcx> for rustc_abi::Align {
 }
 
 impl<'tcx> Stable<'tcx> for rustc_abi::Scalar {
-    type T = Opaque;
+    type T = Scalar;
+
+    fn stable(&self, tables: &mut Tables<'_>) -> Self::T {
+        match self {
+            rustc_abi::Scalar::Initialized { value, valid_range } => Scalar::Initialized {
+                value: value.stable(tables),
+                valid_range: valid_range.stable(tables),
+            },
+            rustc_abi::Scalar::Union { value } => Scalar::Union { value: value.stable(tables) },
+        }
+    }
+}
+
+impl<'tcx> Stable<'tcx> for rustc_abi::Primitive {
+    type T = Primitive;
+
+    fn stable(&self, tables: &mut Tables<'_>) -> Self::T {
+        match self {
+            rustc_abi::Primitive::Int(length, signed) => {
+                Primitive::Int { length: length.stable(tables), signed: *signed }
+            }
+            rustc_abi::Primitive::F32 => Primitive::F32,
+            rustc_abi::Primitive::F64 => Primitive::F64,
+            rustc_abi::Primitive::Pointer(space) => Primitive::Pointer(space.stable(tables)),
+        }
+    }
+}
+
+impl<'tcx> Stable<'tcx> for rustc_abi::AddressSpace {
+    type T = AddressSpace;
+
+    fn stable(&self, _tables: &mut Tables<'_>) -> Self::T {
+        AddressSpace(self.0)
+    }
+}
+
+impl<'tcx> Stable<'tcx> for rustc_abi::Integer {
+    type T = IntegerLength;
+
+    fn stable(&self, _tables: &mut Tables<'_>) -> Self::T {
+        match self {
+            rustc_abi::Integer::I8 => IntegerLength::I8,
+            rustc_abi::Integer::I16 => IntegerLength::I16,
+            rustc_abi::Integer::I32 => IntegerLength::I32,
+            rustc_abi::Integer::I64 => IntegerLength::I64,
+            rustc_abi::Integer::I128 => IntegerLength::I128,
+        }
+    }
+}
+
+impl<'tcx> Stable<'tcx> for rustc_abi::WrappingRange {
+    type T = WrappingRange;
 
     fn stable(&self, _tables: &mut Tables<'_>) -> Self::T {
-        opaque(self)
+        WrappingRange { start: self.start, end: self.end }
     }
 }
diff --git a/compiler/stable_mir/src/abi.rs b/compiler/stable_mir/src/abi.rs
index a15fd3e0999..1c5e3275673 100644
--- a/compiler/stable_mir/src/abi.rs
+++ b/compiler/stable_mir/src/abi.rs
@@ -1,7 +1,11 @@
 use crate::compiler_interface::with;
+use crate::error;
 use crate::mir::FieldIdx;
-use crate::ty::{Align, IndexedVal, Size, Ty, VariantIdx};
+use crate::target::{MachineInfo, MachineSize as Size};
+use crate::ty::{Align, IndexedVal, Ty, VariantIdx};
+use crate::Error;
 use crate::Opaque;
+use std::fmt::{self, Debug};
 use std::num::NonZeroUsize;
 use std::ops::RangeInclusive;
 
@@ -100,7 +104,7 @@ impl LayoutShape {
 
     /// Returns `true` if the type is sized and a 1-ZST (meaning it has size 0 and alignment 1).
     pub fn is_1zst(&self) -> bool {
-        self.is_sized() && self.size == 0 && self.abi_align == 1
+        self.is_sized() && self.size.bits() == 0 && self.abi_align == 1
     }
 }
 
@@ -245,8 +249,155 @@ impl ValueAbi {
     }
 }
 
-/// We currently do not support `Scalar`, and use opaque instead.
-type Scalar = Opaque;
+/// Information about one scalar component of a Rust type.
+#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
+pub enum Scalar {
+    Initialized {
+        /// The primitive type used to represent this value.
+        value: Primitive,
+        /// The range that represents valid values.
+        /// The range must be valid for the `primitive` size.
+        valid_range: WrappingRange,
+    },
+    Union {
+        /// Unions never have niches, so there is no `valid_range`.
+        /// Even for unions, we need to use the correct registers for the kind of
+        /// values inside the union, so we keep the `Primitive` type around.
+        /// It is also used to compute the size of the scalar.
+        value: Primitive,
+    },
+}
+
+impl Scalar {
+    pub fn has_niche(&self, target: &MachineInfo) -> bool {
+        match self {
+            Scalar::Initialized { value, valid_range } => {
+                !valid_range.is_full(value.size(target)).unwrap()
+            }
+            Scalar::Union { .. } => false,
+        }
+    }
+}
+
+/// Fundamental unit of memory access and layout.
+#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
+pub enum Primitive {
+    /// The `bool` is the signedness of the `Integer` type.
+    ///
+    /// One would think we would not care about such details this low down,
+    /// but some ABIs are described in terms of C types and ISAs where the
+    /// integer arithmetic is done on {sign,zero}-extended registers, e.g.
+    /// a negative integer passed by zero-extension will appear positive in
+    /// the callee, and most operations on it will produce the wrong values.
+    Int {
+        length: IntegerLength,
+        signed: bool,
+    },
+    F32,
+    F64,
+    Pointer(AddressSpace),
+}
+
+impl Primitive {
+    pub fn size(self, target: &MachineInfo) -> Size {
+        match self {
+            Primitive::Int { length, .. } => Size::from_bits(length.bits()),
+            Primitive::F32 => Size::from_bits(32),
+            Primitive::F64 => Size::from_bits(64),
+            Primitive::Pointer(_) => target.pointer_width,
+        }
+    }
+}
+
+/// Enum representing the existing integer lengths.
+#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
+pub enum IntegerLength {
+    I8,
+    I16,
+    I32,
+    I64,
+    I128,
+}
+
+impl IntegerLength {
+    pub fn bits(self) -> usize {
+        match self {
+            IntegerLength::I8 => 8,
+            IntegerLength::I16 => 16,
+            IntegerLength::I32 => 32,
+            IntegerLength::I64 => 64,
+            IntegerLength::I128 => 128,
+        }
+    }
+}
+
+/// An identifier that specifies the address space that some operation
+/// should operate on. Special address spaces have an effect on code generation,
+/// depending on the target and the address spaces it implements.
+#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub struct AddressSpace(pub u32);
+
+impl AddressSpace {
+    /// The default address space, corresponding to data space.
+    pub const DATA: Self = AddressSpace(0);
+}
+
+/// Inclusive wrap-around range of valid values (bitwise representation), that is, if
+/// start > end, it represents `start..=MAX`, followed by `0..=end`.
+///
+/// That is, for an i8 primitive, a range of `254..=2` means following
+/// sequence:
+///
+///    254 (-2), 255 (-1), 0, 1, 2
+#[derive(Clone, Copy, PartialEq, Eq, Hash)]
+pub struct WrappingRange {
+    pub start: u128,
+    pub end: u128,
+}
+
+impl WrappingRange {
+    /// Returns `true` if `size` completely fills the range.
+    #[inline]
+    pub fn is_full(&self, size: Size) -> Result<bool, Error> {
+        let Some(max_value) = size.unsigned_int_max() else {
+            return Err(error!("Expected size <= 128 bits, but found {} instead", size.bits()));
+        };
+        if self.start <= max_value && self.end <= max_value {
+            Ok(self.start == 0 && max_value == self.end)
+        } else {
+            Err(error!("Range `{self:?}` out of bounds for size `{}` bits.", size.bits()))
+        }
+    }
+
+    /// Returns `true` if `v` is contained in the range.
+    #[inline(always)]
+    pub fn contains(&self, v: u128) -> bool {
+        if self.wraps_around() {
+            self.start <= v || v <= self.end
+        } else {
+            self.start <= v && v <= self.end
+        }
+    }
+
+    /// Returns `true` if the range wraps around.
+    /// I.e., the range represents the union of `self.start..=MAX` and `0..=self.end`.
+    /// Returns `false` if this is a non-wrapping range, i.e.: `self.start..=self.end`.
+    #[inline]
+    pub fn wraps_around(&self) -> bool {
+        self.start > self.end
+    }
+}
+
+impl Debug for WrappingRange {
+    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+        if self.start > self.end {
+            write!(fmt, "(..={}) | ({}..)", self.end, self.start)?;
+        } else {
+            write!(fmt, "{}..={}", self.start, self.end)?;
+        }
+        Ok(())
+    }
+}
 
 /// General language calling conventions.
 #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
diff --git a/compiler/stable_mir/src/error.rs b/compiler/stable_mir/src/error.rs
index 9e3f4936944..050752e41eb 100644
--- a/compiler/stable_mir/src/error.rs
+++ b/compiler/stable_mir/src/error.rs
@@ -5,12 +5,14 @@
 //! - [Error]: Generic error that represents the reason why a request that could not be fulfilled.
 
 use std::fmt::{Debug, Display, Formatter};
-use std::{error, fmt, io};
+use std::{fmt, io};
 
 macro_rules! error {
      ($fmt: literal $(,)?) => { Error(format!($fmt)) };
      ($fmt: literal, $($arg:tt)*) => { Error(format!($fmt, $($arg)*)) };
- }
+}
+
+pub(crate) use error;
 
 /// An error type used to represent an error that has already been reported by the compiler.
 #[derive(Clone, Copy, PartialEq, Eq)]
@@ -72,8 +74,9 @@ where
     }
 }
 
-impl error::Error for Error {}
-impl<T> error::Error for CompilerError<T> where T: Display + Debug {}
+impl std::error::Error for Error {}
+
+impl<T> std::error::Error for CompilerError<T> where T: Display + Debug {}
 
 impl From<io::Error> for Error {
     fn from(value: io::Error) -> Self {
diff --git a/compiler/stable_mir/src/target.rs b/compiler/stable_mir/src/target.rs
index 41ec205cfc7..3a9011a2ffe 100644
--- a/compiler/stable_mir/src/target.rs
+++ b/compiler/stable_mir/src/target.rs
@@ -30,21 +30,29 @@ pub enum Endian {
 }
 
 /// Represent the size of a component.
-#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
+#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
 pub struct MachineSize {
     num_bits: usize,
 }
 
 impl MachineSize {
+    #[inline(always)]
     pub fn bytes(self) -> usize {
         self.num_bits / 8
     }
 
+    #[inline(always)]
     pub fn bits(self) -> usize {
         self.num_bits
     }
 
+    #[inline(always)]
     pub fn from_bits(num_bits: usize) -> MachineSize {
         MachineSize { num_bits }
     }
+
+    #[inline]
+    pub fn unsigned_int_max(self) -> Option<u128> {
+        (self.num_bits <= 128).then(|| u128::MAX >> (128 - self.bits()))
+    }
 }
diff --git a/compiler/stable_mir/src/ty.rs b/compiler/stable_mir/src/ty.rs
index ed4a4290246..86cc748eaec 100644
--- a/compiler/stable_mir/src/ty.rs
+++ b/compiler/stable_mir/src/ty.rs
@@ -324,7 +324,9 @@ impl TyKind {
 
     #[inline]
     pub fn is_cstr(&self) -> bool {
-        let TyKind::RigidTy(RigidTy::Adt(def, _)) = self else { return false };
+        let TyKind::RigidTy(RigidTy::Adt(def, _)) = self else {
+            return false;
+        };
         with(|cx| cx.adt_is_cstr(*def))
     }
 
@@ -1032,10 +1034,13 @@ pub struct BoundTy {
 }
 
 pub type Bytes = Vec<Option<u8>>;
+
+/// Size in bytes.
 pub type Size = usize;
 
 #[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
 pub struct Prov(pub AllocId);
+
 pub type Align = u64;
 pub type Promoted = u32;
 pub type InitMaskMaterialized = Vec<u64>;
diff --git a/tests/ui-fulldeps/stable-mir/check_abi.rs b/tests/ui-fulldeps/stable-mir/check_abi.rs
index c345987955e..74801e007c4 100644
--- a/tests/ui-fulldeps/stable-mir/check_abi.rs
+++ b/tests/ui-fulldeps/stable-mir/check_abi.rs
@@ -19,8 +19,12 @@ extern crate rustc_interface;
 extern crate stable_mir;
 
 use rustc_smir::rustc_internal;
-use stable_mir::abi::{ArgAbi, CallConvention, FieldsShape, PassMode, VariantsShape};
+use stable_mir::abi::{
+    ArgAbi, CallConvention, FieldsShape, IntegerLength, PassMode, Primitive, Scalar, ValueAbi,
+    VariantsShape,
+};
 use stable_mir::mir::mono::Instance;
+use stable_mir::target::MachineInfo;
 use stable_mir::{CrateDef, CrateItem, CrateItems, ItemKind};
 use std::assert_matches::assert_matches;
 use std::convert::TryFrom;
@@ -39,11 +43,12 @@ fn test_stable_mir() -> ControlFlow<()> {
     let instance = Instance::try_from(target_fn).unwrap();
     let fn_abi = instance.fn_abi().unwrap();
     assert_eq!(fn_abi.conv, CallConvention::Rust);
-    assert_eq!(fn_abi.args.len(), 2);
+    assert_eq!(fn_abi.args.len(), 3);
 
     check_ignore(&fn_abi.args[0]);
     check_primitive(&fn_abi.args[1]);
-    check_result(fn_abi.ret);
+    check_niche(&fn_abi.args[2]);
+    check_result(&fn_abi.ret);
 
     // Test variadic function.
     let variadic_fn = *get_item(&items, (ItemKind::Fn, "variadic_fn")).unwrap();
@@ -85,7 +90,7 @@ fn check_primitive(abi: &ArgAbi) {
 }
 
 /// Check the return value: `Result<usize, &str>`.
-fn check_result(abi: ArgAbi) {
+fn check_result(abi: &ArgAbi) {
     assert!(abi.ty.kind().is_enum());
     assert_matches!(abi.mode, PassMode::Indirect { .. });
     let layout = abi.layout.shape();
@@ -94,6 +99,25 @@ fn check_result(abi: ArgAbi) {
     assert_matches!(layout.variants, VariantsShape::Multiple { .. })
 }
 
+/// Check the niche information about: `NonZeroU8`
+fn check_niche(abi: &ArgAbi) {
+    assert!(abi.ty.kind().is_struct());
+    assert_matches!(abi.mode, PassMode::Direct { .. });
+    let layout = abi.layout.shape();
+    assert!(layout.is_sized());
+    assert_eq!(layout.size.bytes(), 1);
+
+    let ValueAbi::Scalar(scalar) = layout.abi else { unreachable!() };
+    assert!(scalar.has_niche(&MachineInfo::target()), "Opps: {:?}", scalar);
+
+    let Scalar::Initialized { value, valid_range } = scalar else { unreachable!() };
+    assert_matches!(value, Primitive::Int { length: IntegerLength::I8, signed: false });
+    assert_eq!(valid_range.start, 1);
+    assert_eq!(valid_range.end, u8::MAX.into());
+    assert!(!valid_range.contains(0));
+    assert!(!valid_range.wraps_around());
+}
+
 fn get_item<'a>(
     items: &'a CrateItems,
     item: (ItemKind, &str),
@@ -126,11 +150,16 @@ fn generate_input(path: &str) -> std::io::Result<()> {
         #![feature(c_variadic)]
         #![allow(unused_variables)]
 
-        pub fn fn_abi(ignore: [u8; 0], primitive: char) -> Result<usize, &'static str> {{
-            // We only care about the signature.
-            todo!()
-        }}
+        use std::num::NonZeroU8;
 
+        pub fn fn_abi(
+            ignore: [u8; 0],
+            primitive: char,
+            niche: NonZeroU8,
+        ) -> Result<usize, &'static str> {{
+                // We only care about the signature.
+                todo!()
+        }}
 
         pub unsafe extern "C" fn variadic_fn(n: usize, mut args: ...) -> usize {{
             0