about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_const_eval/src/interpret/intrinsics.rs47
-rw-r--r--compiler/rustc_const_eval/src/interpret/operand.rs12
-rw-r--r--compiler/rustc_const_eval/src/interpret/place.rs28
-rw-r--r--compiler/rustc_middle/src/ty/sty.rs5
-rw-r--r--src/test/ui/consts/const-eval/simd/insert_extract.rs13
5 files changed, 67 insertions, 38 deletions
diff --git a/compiler/rustc_const_eval/src/interpret/intrinsics.rs b/compiler/rustc_const_eval/src/interpret/intrinsics.rs
index 698742fe98c..44da27a43db 100644
--- a/compiler/rustc_const_eval/src/interpret/intrinsics.rs
+++ b/compiler/rustc_const_eval/src/interpret/intrinsics.rs
@@ -413,48 +413,33 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
             sym::simd_insert => {
                 let index = u64::from(self.read_scalar(&args[1])?.to_u32()?);
                 let elem = &args[2];
-                let input = &args[0];
-                let (len, e_ty) = input.layout.ty.simd_size_and_type(*self.tcx);
+                let (input, input_len) = self.operand_to_simd(&args[0])?;
+                let (dest, dest_len) = self.place_to_simd(dest)?;
+                assert_eq!(input_len, dest_len, "Return vector length must match input length");
                 assert!(
-                    index < len,
-                    "Index `{}` must be in bounds of vector type `{}`: `[0, {})`",
+                    index < dest_len,
+                    "Index `{}` must be in bounds of vector with length {}`",
                     index,
-                    e_ty,
-                    len
-                );
-                assert_eq!(
-                    input.layout, dest.layout,
-                    "Return type `{}` must match vector type `{}`",
-                    dest.layout.ty, input.layout.ty
-                );
-                assert_eq!(
-                    elem.layout.ty, e_ty,
-                    "Scalar element type `{}` must match vector element type `{}`",
-                    elem.layout.ty, e_ty
+                    dest_len
                 );
 
-                for i in 0..len {
-                    let place = self.place_index(dest, i)?;
-                    let value = if i == index { *elem } else { self.operand_index(input, i)? };
-                    self.copy_op(&value, &place)?;
+                for i in 0..dest_len {
+                    let place = self.mplace_index(&dest, i)?;
+                    let value =
+                        if i == index { *elem } else { self.mplace_index(&input, i)?.into() };
+                    self.copy_op(&value, &place.into())?;
                 }
             }
             sym::simd_extract => {
                 let index = u64::from(self.read_scalar(&args[1])?.to_u32()?);
-                let (len, e_ty) = args[0].layout.ty.simd_size_and_type(*self.tcx);
+                let (input, input_len) = self.operand_to_simd(&args[0])?;
                 assert!(
-                    index < len,
-                    "index `{}` is out-of-bounds of vector type `{}` with length `{}`",
+                    index < input_len,
+                    "index `{}` must be in bounds of vector with length `{}`",
                     index,
-                    e_ty,
-                    len
-                );
-                assert_eq!(
-                    e_ty, dest.layout.ty,
-                    "Return type `{}` must match vector element type `{}`",
-                    dest.layout.ty, e_ty
+                    input_len
                 );
-                self.copy_op(&self.operand_index(&args[0], index)?, dest)?;
+                self.copy_op(&self.mplace_index(&input, index)?.into(), dest)?;
             }
             sym::likely | sym::unlikely | sym::black_box => {
                 // These just return their argument
diff --git a/compiler/rustc_const_eval/src/interpret/operand.rs b/compiler/rustc_const_eval/src/interpret/operand.rs
index b6682b13ed2..de9e94ce2ac 100644
--- a/compiler/rustc_const_eval/src/interpret/operand.rs
+++ b/compiler/rustc_const_eval/src/interpret/operand.rs
@@ -437,6 +437,18 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
         })
     }
 
+    /// Converts a repr(simd) operand into an operand where `place_index` accesses the SIMD elements.
+    /// Also returns the number of elements.
+    pub fn operand_to_simd(
+        &self,
+        base: &OpTy<'tcx, M::PointerTag>,
+    ) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::PointerTag>, u64)> {
+        // Basically we just transmute this place into an array following simd_size_and_type.
+        // This only works in memory, but repr(simd) types should never be immediates anyway.
+        assert!(base.layout.ty.is_simd());
+        self.mplace_to_simd(&base.assert_mem_place())
+    }
+
     /// Read from a local. Will not actually access the local if reading from a ZST.
     /// Will not access memory, instead an indirect `Operand` is returned.
     ///
diff --git a/compiler/rustc_const_eval/src/interpret/place.rs b/compiler/rustc_const_eval/src/interpret/place.rs
index d425b84bdaf..d7f2853fc86 100644
--- a/compiler/rustc_const_eval/src/interpret/place.rs
+++ b/compiler/rustc_const_eval/src/interpret/place.rs
@@ -200,7 +200,7 @@ impl<'tcx, Tag: Provenance> MPlaceTy<'tcx, Tag> {
             }
         } else {
             // Go through the layout.  There are lots of types that support a length,
-            // e.g., SIMD types.
+            // e.g., SIMD types. (But not all repr(simd) types even have FieldsShape::Array!)
             match self.layout.fields {
                 FieldsShape::Array { count, .. } => Ok(count),
                 _ => bug!("len not supported on sized type {:?}", self.layout.ty),
@@ -533,6 +533,22 @@ where
         })
     }
 
+    /// Converts a repr(simd) place into a place where `place_index` accesses the SIMD elements.
+    /// Also returns the number of elements.
+    pub fn mplace_to_simd(
+        &self,
+        base: &MPlaceTy<'tcx, M::PointerTag>,
+    ) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::PointerTag>, u64)> {
+        // Basically we just transmute this place into an array following simd_size_and_type.
+        // (Transmuting is okay since this is an in-memory place. We also double-check the size
+        // stays the same.)
+        let (len, e_ty) = base.layout.ty.simd_size_and_type(*self.tcx);
+        let array = self.tcx.mk_array(e_ty, len);
+        let layout = self.layout_of(array)?;
+        assert_eq!(layout.size, base.layout.size);
+        Ok((MPlaceTy { layout, ..*base }, len))
+    }
+
     /// Gets the place of a field inside the place, and also the field's type.
     /// Just a convenience function, but used quite a bit.
     /// This is the only projection that might have a side-effect: We cannot project
@@ -594,6 +610,16 @@ where
         })
     }
 
+    /// Converts a repr(simd) place into a place where `place_index` accesses the SIMD elements.
+    /// Also returns the number of elements.
+    pub fn place_to_simd(
+        &mut self,
+        base: &PlaceTy<'tcx, M::PointerTag>,
+    ) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::PointerTag>, u64)> {
+        let mplace = self.force_allocation(base)?;
+        self.mplace_to_simd(&mplace)
+    }
+
     /// Computes a place. You should only use this if you intend to write into this
     /// place; for reading, a more efficient alternative is `eval_place_for_read`.
     pub fn eval_place(
diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs
index 610f9bd8f82..c79e25f4781 100644
--- a/compiler/rustc_middle/src/ty/sty.rs
+++ b/compiler/rustc_middle/src/ty/sty.rs
@@ -1805,10 +1805,13 @@ impl<'tcx> TyS<'tcx> {
     pub fn simd_size_and_type(&self, tcx: TyCtxt<'tcx>) -> (u64, Ty<'tcx>) {
         match self.kind() {
             Adt(def, substs) => {
+                assert!(def.repr.simd(), "`simd_size_and_type` called on non-SIMD type");
                 let variant = def.non_enum_variant();
                 let f0_ty = variant.fields[0].ty(tcx, substs);
 
                 match f0_ty.kind() {
+                    // If the first field is an array, we assume it is the only field and its
+                    // elements are the SIMD components.
                     Array(f0_elem_ty, f0_len) => {
                         // FIXME(repr_simd): https://github.com/rust-lang/rust/pull/78863#discussion_r522784112
                         // The way we evaluate the `N` in `[T; N]` here only works since we use
@@ -1816,6 +1819,8 @@ impl<'tcx> TyS<'tcx> {
                         // if we use it in generic code. See the `simd-array-trait` ui test.
                         (f0_len.eval_usize(tcx, ParamEnv::empty()) as u64, f0_elem_ty)
                     }
+                    // Otherwise, the fields of this Adt are the SIMD components (and we assume they
+                    // all have the same type).
                     _ => (variant.fields.len() as u64, f0_ty),
                 }
             }
diff --git a/src/test/ui/consts/const-eval/simd/insert_extract.rs b/src/test/ui/consts/const-eval/simd/insert_extract.rs
index cae8fcf1068..a1cd24077c6 100644
--- a/src/test/ui/consts/const-eval/simd/insert_extract.rs
+++ b/src/test/ui/consts/const-eval/simd/insert_extract.rs
@@ -7,7 +7,8 @@
 
 #[repr(simd)] struct i8x1(i8);
 #[repr(simd)] struct u16x2(u16, u16);
-#[repr(simd)] struct f32x4(f32, f32, f32, f32);
+// Make one of them an array type to ensure those also work.
+#[repr(simd)] struct f32x4([f32; 4]);
 
 extern "platform-intrinsic" {
     #[rustc_const_stable(feature = "foo", since = "1.3.37")]
@@ -38,12 +39,12 @@ fn main() {
         assert_eq!(Y1, 42);
     }
     {
-        const U: f32x4 = f32x4(13., 14., 15., 16.);
+        const U: f32x4 = f32x4([13., 14., 15., 16.]);
         const V: f32x4 = unsafe { simd_insert(U, 1_u32, 42_f32) };
-        const X0: f32 = V.0;
-        const X1: f32 = V.1;
-        const X2: f32 = V.2;
-        const X3: f32 = V.3;
+        const X0: f32 = V.0[0];
+        const X1: f32 = V.0[1];
+        const X2: f32 = V.0[2];
+        const X3: f32 = V.0[3];
         const Y0: f32 = unsafe { simd_extract(V, 0) };
         const Y1: f32 = unsafe { simd_extract(V, 1) };
         const Y2: f32 = unsafe { simd_extract(V, 2) };