about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMatthias Krüger <matthias.krueger@famsik.de>2023-12-12 06:52:51 +0100
committerGitHub <noreply@github.com>2023-12-12 06:52:51 +0100
commitd67e80f12dadf15656d76c5b517631cd783efa0c (patch)
tree9f8e86642fd363a270bc4816f20f96febbce929f
parentcdc4fc9f35a7d982d98f4c09baa324cb2506aec5 (diff)
parent638b08ebdfde428a54efce2fd896da8d3332f403 (diff)
downloadrust-d67e80f12dadf15656d76c5b517631cd783efa0c.tar.gz
rust-d67e80f12dadf15656d76c5b517631cd783efa0c.zip
Rollup merge of #118846 - celinval:smir-ty-methods, r=compiler-errors
Fix BinOp `ty()` assertion and `fn_sig()` for closures

`BinOp::ty()` was asserting that the argument types were primitives. However, the primitive check doesn't include pointers, which can be used in a `BinaryOperation`. Thus extend the arguments to include them.

Since I had to add methods to check for pointers in TyKind, I just went ahead and added a bunch more utility checks that can be handy for our users and fixed the `fn_sig()` method to also include closures.

`@compiler-errors` just wanted to confirm that today no `BinaryOperation` accept SIMD types. Is that correct?

r? `@compiler-errors`
-rw-r--r--compiler/rustc_smir/src/rustc_smir/context.rs12
-rw-r--r--compiler/stable_mir/src/compiler_interface.rs6
-rw-r--r--compiler/stable_mir/src/mir/body.rs30
-rw-r--r--compiler/stable_mir/src/ty.rs121
4 files changed, 156 insertions, 13 deletions
diff --git a/compiler/rustc_smir/src/rustc_smir/context.rs b/compiler/rustc_smir/src/rustc_smir/context.rs
index 22e9f66ba96..4ec5e2a5387 100644
--- a/compiler/rustc_smir/src/rustc_smir/context.rs
+++ b/compiler/rustc_smir/src/rustc_smir/context.rs
@@ -213,6 +213,11 @@ impl<'tcx> Context for TablesWrapper<'tcx> {
         def.internal(&mut *tables).is_box()
     }
 
+    fn adt_is_simd(&self, def: AdtDef) -> bool {
+        let mut tables = self.0.borrow_mut();
+        def.internal(&mut *tables).repr().simd()
+    }
+
     fn fn_sig(&self, def: FnDef, args: &GenericArgs) -> PolyFnSig {
         let mut tables = self.0.borrow_mut();
         let def_id = def.0.internal(&mut *tables);
@@ -220,6 +225,13 @@ impl<'tcx> Context for TablesWrapper<'tcx> {
         sig.stable(&mut *tables)
     }
 
+    fn closure_sig(&self, args: &GenericArgs) -> PolyFnSig {
+        let mut tables = self.0.borrow_mut();
+        let args_ref = args.internal(&mut *tables);
+        let sig = args_ref.as_closure().sig();
+        sig.stable(&mut *tables)
+    }
+
     fn adt_variants_len(&self, def: AdtDef) -> usize {
         let mut tables = self.0.borrow_mut();
         def.internal(&mut *tables).variants().len()
diff --git a/compiler/stable_mir/src/compiler_interface.rs b/compiler/stable_mir/src/compiler_interface.rs
index 17c5212fb9c..2fac59e71fd 100644
--- a/compiler/stable_mir/src/compiler_interface.rs
+++ b/compiler/stable_mir/src/compiler_interface.rs
@@ -69,9 +69,15 @@ pub trait Context {
     /// Returns if the ADT is a box.
     fn adt_is_box(&self, def: AdtDef) -> bool;
 
+    /// Returns whether this ADT is simd.
+    fn adt_is_simd(&self, def: AdtDef) -> bool;
+
     /// Retrieve the function signature for the given generic arguments.
     fn fn_sig(&self, def: FnDef, args: &GenericArgs) -> PolyFnSig;
 
+    /// Retrieve the closure signature for the given generic arguments.
+    fn closure_sig(&self, args: &GenericArgs) -> PolyFnSig;
+
     /// The number of variants in this ADT.
     fn adt_variants_len(&self, def: AdtDef) -> usize;
 
diff --git a/compiler/stable_mir/src/mir/body.rs b/compiler/stable_mir/src/mir/body.rs
index 663275d9a0f..3dfe7096399 100644
--- a/compiler/stable_mir/src/mir/body.rs
+++ b/compiler/stable_mir/src/mir/body.rs
@@ -228,7 +228,7 @@ pub struct InlineAsmOperand {
     pub raw_rpr: String,
 }
 
-#[derive(Clone, Debug, Eq, PartialEq)]
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum UnwindAction {
     Continue,
     Unreachable,
@@ -248,7 +248,7 @@ pub enum AssertMessage {
     MisalignedPointerDereference { required: Operand, found: Operand },
 }
 
-#[derive(Clone, Debug, Eq, PartialEq)]
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum BinOp {
     Add,
     AddUnchecked,
@@ -278,8 +278,6 @@ impl BinOp {
     /// Return the type of this operation for the given input Ty.
     /// This function does not perform type checking, and it currently doesn't handle SIMD.
     pub fn ty(&self, lhs_ty: Ty, rhs_ty: Ty) -> Ty {
-        assert!(lhs_ty.kind().is_primitive());
-        assert!(rhs_ty.kind().is_primitive());
         match self {
             BinOp::Add
             | BinOp::AddUnchecked
@@ -293,20 +291,30 @@ impl BinOp {
             | BinOp::BitAnd
             | BinOp::BitOr => {
                 assert_eq!(lhs_ty, rhs_ty);
+                assert!(lhs_ty.kind().is_primitive());
                 lhs_ty
             }
-            BinOp::Shl | BinOp::ShlUnchecked | BinOp::Shr | BinOp::ShrUnchecked | BinOp::Offset => {
+            BinOp::Shl | BinOp::ShlUnchecked | BinOp::Shr | BinOp::ShrUnchecked => {
+                assert!(lhs_ty.kind().is_primitive());
+                assert!(rhs_ty.kind().is_primitive());
+                lhs_ty
+            }
+            BinOp::Offset => {
+                assert!(lhs_ty.kind().is_raw_ptr());
+                assert!(rhs_ty.kind().is_integral());
                 lhs_ty
             }
             BinOp::Eq | BinOp::Lt | BinOp::Le | BinOp::Ne | BinOp::Ge | BinOp::Gt => {
                 assert_eq!(lhs_ty, rhs_ty);
+                let lhs_kind = lhs_ty.kind();
+                assert!(lhs_kind.is_primitive() || lhs_kind.is_raw_ptr() || lhs_kind.is_fn_ptr());
                 Ty::bool_ty()
             }
         }
     }
 }
 
-#[derive(Clone, Debug, Eq, PartialEq)]
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum UnOp {
     Not,
     Neg,
@@ -319,7 +327,7 @@ pub enum CoroutineKind {
     Gen(CoroutineSource),
 }
 
-#[derive(Clone, Debug, Eq, PartialEq)]
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum CoroutineSource {
     Block,
     Closure,
@@ -343,7 +351,7 @@ pub enum FakeReadCause {
 }
 
 /// Describes what kind of retag is to be performed
-#[derive(Clone, Debug, Eq, PartialEq, Hash)]
+#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
 pub enum RetagKind {
     FnEntry,
     TwoPhase,
@@ -351,7 +359,7 @@ pub enum RetagKind {
     Default,
 }
 
-#[derive(Clone, Debug, Eq, PartialEq, Hash)]
+#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
 pub enum Variance {
     Covariant,
     Invariant,
@@ -862,7 +870,7 @@ pub enum Safety {
     Normal,
 }
 
-#[derive(Clone, Debug, Eq, PartialEq)]
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum PointerCoercion {
     /// Go from a fn-item type to a fn-pointer type.
     ReifyFnPointer,
@@ -889,7 +897,7 @@ pub enum PointerCoercion {
     Unsize,
 }
 
-#[derive(Clone, Debug, Eq, PartialEq)]
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum CastKind {
     PointerExposeAddress,
     PointerFromExposedAddress,
diff --git a/compiler/stable_mir/src/ty.rs b/compiler/stable_mir/src/ty.rs
index bea7702bd34..f473fd8dbb7 100644
--- a/compiler/stable_mir/src/ty.rs
+++ b/compiler/stable_mir/src/ty.rs
@@ -214,38 +214,62 @@ impl TyKind {
         if let TyKind::RigidTy(inner) = self { Some(inner) } else { None }
     }
 
+    #[inline]
     pub fn is_unit(&self) -> bool {
         matches!(self, TyKind::RigidTy(RigidTy::Tuple(data)) if data.is_empty())
     }
 
+    #[inline]
     pub fn is_bool(&self) -> bool {
         matches!(self, TyKind::RigidTy(RigidTy::Bool))
     }
 
+    #[inline]
+    pub fn is_char(&self) -> bool {
+        matches!(self, TyKind::RigidTy(RigidTy::Char))
+    }
+
+    #[inline]
     pub fn is_trait(&self) -> bool {
         matches!(self, TyKind::RigidTy(RigidTy::Dynamic(_, _, DynKind::Dyn)))
     }
 
+    #[inline]
     pub fn is_enum(&self) -> bool {
         matches!(self, TyKind::RigidTy(RigidTy::Adt(def, _)) if def.kind() == AdtKind::Enum)
     }
 
+    #[inline]
     pub fn is_struct(&self) -> bool {
         matches!(self, TyKind::RigidTy(RigidTy::Adt(def, _)) if def.kind() == AdtKind::Struct)
     }
 
+    #[inline]
     pub fn is_union(&self) -> bool {
         matches!(self, TyKind::RigidTy(RigidTy::Adt(def, _)) if def.kind() == AdtKind::Union)
     }
 
+    #[inline]
+    pub fn is_adt(&self) -> bool {
+        matches!(self, TyKind::RigidTy(RigidTy::Adt(..)))
+    }
+
+    #[inline]
+    pub fn is_ref(&self) -> bool {
+        matches!(self, TyKind::RigidTy(RigidTy::Ref(..)))
+    }
+
+    #[inline]
     pub fn is_fn(&self) -> bool {
         matches!(self, TyKind::RigidTy(RigidTy::FnDef(..)))
     }
 
+    #[inline]
     pub fn is_fn_ptr(&self) -> bool {
         matches!(self, TyKind::RigidTy(RigidTy::FnPtr(..)))
     }
 
+    #[inline]
     pub fn is_primitive(&self) -> bool {
         matches!(
             self,
@@ -259,6 +283,84 @@ impl TyKind {
         )
     }
 
+    #[inline]
+    pub fn is_float(&self) -> bool {
+        matches!(self, TyKind::RigidTy(RigidTy::Float(_)))
+    }
+
+    #[inline]
+    pub fn is_integral(&self) -> bool {
+        matches!(self, TyKind::RigidTy(RigidTy::Int(_) | RigidTy::Uint(_)))
+    }
+
+    #[inline]
+    pub fn is_numeric(&self) -> bool {
+        self.is_integral() || self.is_float()
+    }
+
+    #[inline]
+    pub fn is_signed(&self) -> bool {
+        matches!(self, TyKind::RigidTy(RigidTy::Int(_)))
+    }
+
+    #[inline]
+    pub fn is_str(&self) -> bool {
+        *self == TyKind::RigidTy(RigidTy::Str)
+    }
+
+    #[inline]
+    pub fn is_slice(&self) -> bool {
+        matches!(self, TyKind::RigidTy(RigidTy::Slice(_)))
+    }
+
+    #[inline]
+    pub fn is_array(&self) -> bool {
+        matches!(self, TyKind::RigidTy(RigidTy::Array(..)))
+    }
+
+    #[inline]
+    pub fn is_mutable_ptr(&self) -> bool {
+        matches!(
+            self,
+            TyKind::RigidTy(RigidTy::RawPtr(_, Mutability::Mut))
+                | TyKind::RigidTy(RigidTy::Ref(_, _, Mutability::Mut))
+        )
+    }
+
+    #[inline]
+    pub fn is_raw_ptr(&self) -> bool {
+        matches!(self, TyKind::RigidTy(RigidTy::RawPtr(..)))
+    }
+
+    /// Tests if this is any kind of primitive pointer type (reference, raw pointer, fn pointer).
+    #[inline]
+    pub fn is_any_ptr(&self) -> bool {
+        self.is_ref() || self.is_raw_ptr() || self.is_fn_ptr()
+    }
+
+    #[inline]
+    pub fn is_coroutine(&self) -> bool {
+        matches!(self, TyKind::RigidTy(RigidTy::Coroutine(..)))
+    }
+
+    #[inline]
+    pub fn is_closure(&self) -> bool {
+        matches!(self, TyKind::RigidTy(RigidTy::Closure(..)))
+    }
+
+    #[inline]
+    pub fn is_box(&self) -> bool {
+        match self {
+            TyKind::RigidTy(RigidTy::Adt(def, _)) => def.is_box(),
+            _ => false,
+        }
+    }
+
+    #[inline]
+    pub fn is_simd(&self) -> bool {
+        matches!(self, TyKind::RigidTy(RigidTy::Adt(def, _)) if def.is_simd())
+    }
+
     pub fn trait_principal(&self) -> Option<Binder<ExistentialTraitRef>> {
         if let TyKind::RigidTy(RigidTy::Dynamic(predicates, _, _)) = self {
             if let Some(Binder { value: ExistentialPredicate::Trait(trait_ref), bound_vars }) =
@@ -300,12 +402,12 @@ impl TyKind {
         }
     }
 
-    /// Get the function signature for function like types (Fn, FnPtr, Closure, Coroutine)
-    /// FIXME(closure)
+    /// Get the function signature for function like types (Fn, FnPtr, and Closure)
     pub fn fn_sig(&self) -> Option<PolyFnSig> {
         match self {
             TyKind::RigidTy(RigidTy::FnDef(def, args)) => Some(with(|cx| cx.fn_sig(*def, args))),
             TyKind::RigidTy(RigidTy::FnPtr(sig)) => Some(sig.clone()),
+            TyKind::RigidTy(RigidTy::Closure(_def, args)) => Some(with(|cx| cx.closure_sig(args))),
             _ => None,
         }
     }
@@ -481,6 +583,10 @@ impl AdtDef {
         with(|cx| cx.adt_is_box(*self))
     }
 
+    pub fn is_simd(&self) -> bool {
+        with(|cx| cx.adt_is_simd(*self))
+    }
+
     /// The number of variants in this ADT.
     pub fn num_variants(&self) -> usize {
         with(|cx| cx.adt_variants_len(*self))
@@ -738,6 +844,7 @@ pub enum Abi {
     RiscvInterruptS,
 }
 
+/// A binder represents a possibly generic type and its bound vars.
 #[derive(Clone, Debug, Eq, PartialEq)]
 pub struct Binder<T> {
     pub value: T,
@@ -745,6 +852,16 @@ pub struct Binder<T> {
 }
 
 impl<T> Binder<T> {
+    /// Create a new binder with the given bound vars.
+    pub fn bind_with_vars(value: T, bound_vars: Vec<BoundVariableKind>) -> Self {
+        Binder { value, bound_vars }
+    }
+
+    /// Create a new binder with no bounded variable.
+    pub fn dummy(value: T) -> Self {
+        Binder { value, bound_vars: vec![] }
+    }
+
     pub fn skip_binder(self) -> T {
         self.value
     }