about summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Desjardins <erikdesjardins@users.noreply.github.com>2023-07-29 16:56:27 -0400
committerErik Desjardins <erikdesjardins@users.noreply.github.com>2023-07-29 16:56:27 -0400
commit55800123b73067ec98293f49ded6739036b0aca4 (patch)
tree1c112bac60f4e323b0d014ef1dc45131173baaed
parentcf7788d54b1b91cd3e778984f8dceada224e28ad (diff)
downloadrust-55800123b73067ec98293f49ded6739036b0aca4.tar.gz
rust-55800123b73067ec98293f49ded6739036b0aca4.zip
cg_llvm: simplify llvm.masked.gather/scatter naming with opaque pointers
With opaque pointers, there's no longer a need to generate a chain
of pointer types in the intrinsic name when arguments are pointers to
pointers.
-rw-r--r--compiler/rustc_codegen_llvm/src/intrinsic.rs152
-rw-r--r--tests/codegen/simd-intrinsic/simd-intrinsic-generic-gather.rs4
-rw-r--r--tests/codegen/simd-intrinsic/simd-intrinsic-generic-scatter.rs4
3 files changed, 51 insertions, 109 deletions
diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs
index 55f28088640..b62aa9cb308 100644
--- a/compiler/rustc_codegen_llvm/src/intrinsic.rs
+++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs
@@ -1307,49 +1307,34 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
     // FIXME: use:
     //  https://github.com/llvm-mirror/llvm/blob/master/include/llvm/IR/Function.h#L182
     //  https://github.com/llvm-mirror/llvm/blob/master/include/llvm/IR/Intrinsics.h#L81
-    fn llvm_vector_str(
-        elem_ty: Ty<'_>,
-        vec_len: u64,
-        no_pointers: usize,
-        bx: &Builder<'_, '_, '_>,
-    ) -> String {
-        let p0s: String = "p0".repeat(no_pointers);
+    fn llvm_vector_str(bx: &Builder<'_, '_, '_>, elem_ty: Ty<'_>, vec_len: u64) -> String {
         match *elem_ty.kind() {
             ty::Int(v) => format!(
-                "v{}{}i{}",
+                "v{}i{}",
                 vec_len,
-                p0s,
                 // Normalize to prevent crash if v: IntTy::Isize
                 v.normalize(bx.target_spec().pointer_width).bit_width().unwrap()
             ),
             ty::Uint(v) => format!(
-                "v{}{}i{}",
+                "v{}i{}",
                 vec_len,
-                p0s,
                 // Normalize to prevent crash if v: UIntTy::Usize
                 v.normalize(bx.target_spec().pointer_width).bit_width().unwrap()
             ),
-            ty::Float(v) => format!("v{}{}f{}", vec_len, p0s, v.bit_width()),
+            ty::Float(v) => format!("v{}f{}", vec_len, v.bit_width()),
+            ty::RawPtr(_) => format!("v{}p0", vec_len),
             _ => unreachable!(),
         }
     }
 
-    fn llvm_vector_ty<'ll>(
-        cx: &CodegenCx<'ll, '_>,
-        elem_ty: Ty<'_>,
-        vec_len: u64,
-        no_pointers: usize,
-    ) -> &'ll Type {
-        // FIXME: use cx.layout_of(ty).llvm_type() ?
-        let mut elem_ty = match *elem_ty.kind() {
+    fn llvm_vector_ty<'ll>(cx: &CodegenCx<'ll, '_>, elem_ty: Ty<'_>, vec_len: u64) -> &'ll Type {
+        let elem_ty = match *elem_ty.kind() {
             ty::Int(v) => cx.type_int_from_ty(v),
             ty::Uint(v) => cx.type_uint_from_ty(v),
             ty::Float(v) => cx.type_float_from_ty(v),
+            ty::RawPtr(_) => cx.type_ptr(),
             _ => unreachable!(),
         };
-        if no_pointers > 0 {
-            elem_ty = cx.type_ptr();
-        }
         cx.type_vector(elem_ty, vec_len)
     }
 
@@ -1404,47 +1389,26 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             InvalidMonomorphization::ExpectedReturnType { span, name, in_ty, ret_ty }
         );
 
-        // This counts how many pointers
-        fn ptr_count(t: Ty<'_>) -> usize {
-            match t.kind() {
-                ty::RawPtr(p) => 1 + ptr_count(p.ty),
-                _ => 0,
-            }
-        }
-
-        // Non-ptr type
-        fn non_ptr(t: Ty<'_>) -> Ty<'_> {
-            match t.kind() {
-                ty::RawPtr(p) => non_ptr(p.ty),
-                _ => t,
-            }
-        }
-
         // The second argument must be a simd vector with an element type that's a pointer
         // to the element type of the first argument
         let (_, element_ty0) = arg_tys[0].simd_size_and_type(bx.tcx());
         let (_, element_ty1) = arg_tys[1].simd_size_and_type(bx.tcx());
-        let (pointer_count, underlying_ty) = match element_ty1.kind() {
-            ty::RawPtr(p) if p.ty == in_elem => (ptr_count(element_ty1), non_ptr(element_ty1)),
-            _ => {
-                require!(
-                    false,
-                    InvalidMonomorphization::ExpectedElementType {
-                        span,
-                        name,
-                        expected_element: element_ty1,
-                        second_arg: arg_tys[1],
-                        in_elem,
-                        in_ty,
-                        mutability: ExpectedPointerMutability::Not,
-                    }
-                );
-                unreachable!();
+
+        require!(
+            matches!(
+                element_ty1.kind(),
+                ty::RawPtr(p) if p.ty == in_elem && p.ty.kind() == element_ty0.kind()
+            ),
+            InvalidMonomorphization::ExpectedElementType {
+                span,
+                name,
+                expected_element: element_ty1,
+                second_arg: arg_tys[1],
+                in_elem,
+                in_ty,
+                mutability: ExpectedPointerMutability::Not,
             }
-        };
-        assert!(pointer_count > 0);
-        assert_eq!(pointer_count - 1, ptr_count(element_ty0));
-        assert_eq!(underlying_ty, non_ptr(element_ty0));
+        );
 
         // The element type of the third argument must be a signed integer type of any width:
         let (_, element_ty2) = arg_tys[2].simd_size_and_type(bx.tcx());
@@ -1475,12 +1439,12 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
         };
 
         // Type of the vector of pointers:
-        let llvm_pointer_vec_ty = llvm_vector_ty(bx, underlying_ty, in_len, pointer_count);
-        let llvm_pointer_vec_str = llvm_vector_str(underlying_ty, in_len, pointer_count, bx);
+        let llvm_pointer_vec_ty = llvm_vector_ty(bx, element_ty1, in_len);
+        let llvm_pointer_vec_str = llvm_vector_str(bx, element_ty1, in_len);
 
         // Type of the vector of elements:
-        let llvm_elem_vec_ty = llvm_vector_ty(bx, underlying_ty, in_len, pointer_count - 1);
-        let llvm_elem_vec_str = llvm_vector_str(underlying_ty, in_len, pointer_count - 1, bx);
+        let llvm_elem_vec_ty = llvm_vector_ty(bx, element_ty0, in_len);
+        let llvm_elem_vec_str = llvm_vector_str(bx, element_ty0, in_len);
 
         let llvm_intrinsic =
             format!("llvm.masked.gather.{}.{}", llvm_elem_vec_str, llvm_pointer_vec_str);
@@ -1544,50 +1508,28 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );
 
-        // This counts how many pointers
-        fn ptr_count(t: Ty<'_>) -> usize {
-            match t.kind() {
-                ty::RawPtr(p) => 1 + ptr_count(p.ty),
-                _ => 0,
-            }
-        }
-
-        // Non-ptr type
-        fn non_ptr(t: Ty<'_>) -> Ty<'_> {
-            match t.kind() {
-                ty::RawPtr(p) => non_ptr(p.ty),
-                _ => t,
-            }
-        }
-
         // The second argument must be a simd vector with an element type that's a pointer
         // to the element type of the first argument
         let (_, element_ty0) = arg_tys[0].simd_size_and_type(bx.tcx());
         let (_, element_ty1) = arg_tys[1].simd_size_and_type(bx.tcx());
         let (_, element_ty2) = arg_tys[2].simd_size_and_type(bx.tcx());
-        let (pointer_count, underlying_ty) = match element_ty1.kind() {
-            ty::RawPtr(p) if p.ty == in_elem && p.mutbl.is_mut() => {
-                (ptr_count(element_ty1), non_ptr(element_ty1))
-            }
-            _ => {
-                require!(
-                    false,
-                    InvalidMonomorphization::ExpectedElementType {
-                        span,
-                        name,
-                        expected_element: element_ty1,
-                        second_arg: arg_tys[1],
-                        in_elem,
-                        in_ty,
-                        mutability: ExpectedPointerMutability::Mut,
-                    }
-                );
-                unreachable!();
+
+        require!(
+            matches!(
+                element_ty1.kind(),
+                ty::RawPtr(p)
+                    if p.ty == in_elem && p.mutbl.is_mut() && p.ty.kind() == element_ty0.kind()
+            ),
+            InvalidMonomorphization::ExpectedElementType {
+                span,
+                name,
+                expected_element: element_ty1,
+                second_arg: arg_tys[1],
+                in_elem,
+                in_ty,
+                mutability: ExpectedPointerMutability::Mut,
             }
-        };
-        assert!(pointer_count > 0);
-        assert_eq!(pointer_count - 1, ptr_count(element_ty0));
-        assert_eq!(underlying_ty, non_ptr(element_ty0));
+        );
 
         // The element type of the third argument must be a signed integer type of any width:
         match element_ty2.kind() {
@@ -1619,12 +1561,12 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
         let ret_t = bx.type_void();
 
         // Type of the vector of pointers:
-        let llvm_pointer_vec_ty = llvm_vector_ty(bx, underlying_ty, in_len, pointer_count);
-        let llvm_pointer_vec_str = llvm_vector_str(underlying_ty, in_len, pointer_count, bx);
+        let llvm_pointer_vec_ty = llvm_vector_ty(bx, element_ty1, in_len);
+        let llvm_pointer_vec_str = llvm_vector_str(bx, element_ty1, in_len);
 
         // Type of the vector of elements:
-        let llvm_elem_vec_ty = llvm_vector_ty(bx, underlying_ty, in_len, pointer_count - 1);
-        let llvm_elem_vec_str = llvm_vector_str(underlying_ty, in_len, pointer_count - 1, bx);
+        let llvm_elem_vec_ty = llvm_vector_ty(bx, element_ty0, in_len);
+        let llvm_elem_vec_str = llvm_vector_str(bx, element_ty0, in_len);
 
         let llvm_intrinsic =
             format!("llvm.masked.scatter.{}.{}", llvm_elem_vec_str, llvm_pointer_vec_str);
diff --git a/tests/codegen/simd-intrinsic/simd-intrinsic-generic-gather.rs b/tests/codegen/simd-intrinsic/simd-intrinsic-generic-gather.rs
index 7fe3ffd2086..0bb21019685 100644
--- a/tests/codegen/simd-intrinsic/simd-intrinsic-generic-gather.rs
+++ b/tests/codegen/simd-intrinsic/simd-intrinsic-generic-gather.rs
@@ -23,7 +23,7 @@ extern "platform-intrinsic" {
 #[no_mangle]
 pub unsafe fn gather_f32x2(pointers: Vec2<*const f32>, mask: Vec2<i32>,
                            values: Vec2<f32>) -> Vec2<f32> {
-    // CHECK: call <2 x float> @llvm.masked.gather.v2f32.{{.+}}(<2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> {{.*}}, <2 x float> {{.*}})
+    // CHECK: call <2 x float> @llvm.masked.gather.v2f32.v2p0(<2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> {{.*}}, <2 x float> {{.*}})
     simd_gather(values, pointers, mask)
 }
 
@@ -31,6 +31,6 @@ pub unsafe fn gather_f32x2(pointers: Vec2<*const f32>, mask: Vec2<i32>,
 #[no_mangle]
 pub unsafe fn gather_pf32x2(pointers: Vec2<*const *const f32>, mask: Vec2<i32>,
                            values: Vec2<*const f32>) -> Vec2<*const f32> {
-    // CHECK: call <2 x ptr> @llvm.masked.gather.{{.+}}(<2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> {{.*}}, <2 x ptr> {{.*}})
+    // CHECK: call <2 x ptr> @llvm.masked.gather.v2p0.v2p0(<2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> {{.*}}, <2 x ptr> {{.*}})
     simd_gather(values, pointers, mask)
 }
diff --git a/tests/codegen/simd-intrinsic/simd-intrinsic-generic-scatter.rs b/tests/codegen/simd-intrinsic/simd-intrinsic-generic-scatter.rs
index 5c917474e45..51953560b4f 100644
--- a/tests/codegen/simd-intrinsic/simd-intrinsic-generic-scatter.rs
+++ b/tests/codegen/simd-intrinsic/simd-intrinsic-generic-scatter.rs
@@ -23,7 +23,7 @@ extern "platform-intrinsic" {
 #[no_mangle]
 pub unsafe fn scatter_f32x2(pointers: Vec2<*mut f32>, mask: Vec2<i32>,
                             values: Vec2<f32>) {
-    // CHECK: call void @llvm.masked.scatter.v2f32.v2p0{{.*}}(<2 x float> {{.*}}, <2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> {{.*}})
+    // CHECK: call void @llvm.masked.scatter.v2f32.v2p0(<2 x float> {{.*}}, <2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> {{.*}})
     simd_scatter(values, pointers, mask)
 }
 
@@ -32,6 +32,6 @@ pub unsafe fn scatter_f32x2(pointers: Vec2<*mut f32>, mask: Vec2<i32>,
 #[no_mangle]
 pub unsafe fn scatter_pf32x2(pointers: Vec2<*mut *const f32>, mask: Vec2<i32>,
                              values: Vec2<*const f32>) {
-    // CHECK: call void @llvm.masked.scatter.v2p0{{.*}}.v2p0{{.*}}(<2 x ptr> {{.*}}, <2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> {{.*}})
+    // CHECK: call void @llvm.masked.scatter.v2p0.v2p0(<2 x ptr> {{.*}}, <2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> {{.*}})
     simd_scatter(values, pointers, mask)
 }