author     Jakub Okoński <jakub@okonski.org>   2023-11-15 23:08:02 +0100
committer  Jakub Okoński <jakub@okonski.org>   2023-12-09 12:36:08 +0100
commit     97ae5095f52100b98170ab476e516d2be5b2c297 (patch)
tree       5d67c444603f183d86c69b19a8bfeb24dd6e6264 /compiler
parent     c41669970a181b07ecf57c4607e50706f5d1e0c8 (diff)
Add simd_masked_{load,store} platform-intrinsics
These map to the LLVM intrinsics llvm.masked.load and llvm.masked.store.
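
For context, a minimal usage sketch of the intrinsics from the caller's side follows. It is not part of this commit: the vector type, the buffer values, and the exact extern "platform-intrinsic" declarations are illustrative, but the (mask, pointer, values) argument order and the return types follow the signatures registered in rustc_hir_analysis below, and a mask lane of -1 enables a lane while 0 disables it.

#![feature(repr_simd, platform_intrinsics)]

extern "platform-intrinsic" {
    fn simd_masked_load<M, P, T>(mask: M, pointer: P, values: T) -> T;
    fn simd_masked_store<M, P, T>(mask: M, pointer: P, values: T);
}

#[repr(simd)]
#[derive(Copy, Clone)]
#[allow(non_camel_case_types)]
struct i32x4([i32; 4]);

fn main() {
    unsafe {
        // The buffers are i32x4 values so the cast pointers carry the vector
        // alignment that the LLVM lowering passes to llvm.masked.load.
        let src = i32x4([10, 11, 12, 13]);
        let fallback = i32x4([-1; 4]);
        // -1 enables a lane, 0 disables it: lanes 0 and 2 are loaded from
        // memory, lanes 1 and 3 keep the fallback values.
        let mask = i32x4([-1, 0, -1, 0]);

        let loaded: i32x4 =
            simd_masked_load(mask, &src as *const i32x4 as *const i32, fallback);
        assert_eq!(std::mem::transmute::<i32x4, [i32; 4]>(loaded), [10, -1, 12, -1]);

        // Only the enabled lanes are written back to dst.
        let mut dst = i32x4([0; 4]);
        simd_masked_store(mask, &mut dst as *mut i32x4 as *mut i32, loaded);
        assert_eq!(std::mem::transmute::<i32x4, [i32; 4]>(dst), [10, 0, 12, 0]);
    }
}
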
Diffstat (limited to 'compiler')
 -rw-r--r--  compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs |  52
 -rw-r--r--  compiler/rustc_codegen_llvm/src/intrinsic.rs            | 192
 -rw-r--r--  compiler/rustc_hir_analysis/src/check/intrinsic.rs      |   2
 -rw-r--r--  compiler/rustc_span/src/symbol.rs                       |   2
 4 files changed, 247 insertions(+), 1 deletion(-)
diff --git a/compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs b/compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs
index 0bd211fd614..5997e6026b4 100644
--- a/compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs
+++ b/compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs
@@ -1,5 +1,6 @@
 //! Codegen `extern "platform-intrinsic"` intrinsics.
 
+use cranelift_codegen::ir::immediates::Offset32;
 use rustc_middle::ty::GenericArgsRef;
 use rustc_span::Symbol;
 use rustc_target::abi::Endian;
@@ -1008,8 +1009,57 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
             }
         }
 
+        sym::simd_masked_load => {
+            intrinsic_args!(fx, args => (mask, ptr, val); intrinsic);
+
+            let (val_lane_count, val_lane_ty) = val.layout().ty.simd_size_and_type(fx.tcx);
+            let (mask_lane_count, _mask_lane_ty) = mask.layout().ty.simd_size_and_type(fx.tcx);
+            let (ret_lane_count, ret_lane_ty) = ret.layout().ty.simd_size_and_type(fx.tcx);
+            assert_eq!(val_lane_count, mask_lane_count);
+            assert_eq!(val_lane_count, ret_lane_count);
+
+            let lane_clif_ty = fx.clif_type(val_lane_ty).unwrap();
+            let ret_lane_layout = fx.layout_of(ret_lane_ty);
+            let ptr_val = ptr.load_scalar(fx);
+
+            for lane_idx in 0..ret_lane_count {
+                let val_lane = val.value_lane(fx, lane_idx).load_scalar(fx);
+                let mask_lane = mask.value_lane(fx, lane_idx).load_scalar(fx);
+
+                let if_enabled = fx.bcx.create_block();
+                let if_disabled = fx.bcx.create_block();
+                let next = fx.bcx.create_block();
+                let res_lane = fx.bcx.append_block_param(next, lane_clif_ty);
+
+                fx.bcx.ins().brif(mask_lane, if_enabled, &[], if_disabled, &[]);
+                fx.bcx.seal_block(if_enabled);
+                fx.bcx.seal_block(if_disabled);
+
+                fx.bcx.switch_to_block(if_enabled);
+                let offset = lane_idx as i32 * lane_clif_ty.bytes() as i32;
+                let res = fx.bcx.ins().load(
+                    lane_clif_ty,
+                    MemFlags::trusted(),
+                    ptr_val,
+                    Offset32::new(offset),
+                );
+                fx.bcx.ins().jump(next, &[res]);
+
+                fx.bcx.switch_to_block(if_disabled);
+                fx.bcx.ins().jump(next, &[val_lane]);
+
+                fx.bcx.seal_block(next);
+                fx.bcx.switch_to_block(next);
+
+                fx.bcx.ins().nop();
+
+                ret.place_lane(fx, lane_idx)
+                    .write_cvalue(fx, CValue::by_val(res_lane, ret_lane_layout));
+            }
+        }
+
         sym::simd_scatter => {
-            intrinsic_args!(fx, args => (val, ptr, mask); intrinsic);
+            intrinsic_args!(fx, args => (mask, ptr, val); intrinsic);
 
             let (val_lane_count, _val_lane_ty) = val.layout().ty.simd_size_and_type(fx.tcx);
             let (ptr_lane_count, _ptr_lane_ty) = ptr.layout().ty.simd_size_and_type(fx.tcx);
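
The lowering above open-codes the operation lane by lane: it branches on each mask lane, loads from ptr_val at offset lane_idx times the lane size when the lane is enabled, and falls back to the corresponding lane of the passthrough vector otherwise. A hedged scalar model of those semantics, with illustrative names and slice-based signatures that are not part of the patch:

// Scalar reference model (illustrative only). Enabled lanes read or write
// element i of the contiguous buffer; disabled lanes keep the passthrough
// value. Unlike the real intrinsics, slice indexing here still requires every
// element to exist, whereas the intrinsics only promise that the addresses of
// disabled lanes are not accessed.
fn masked_load_model(mask: &[bool], src: &[i32], passthrough: &[i32]) -> Vec<i32> {
    mask.iter()
        .zip(passthrough)
        .enumerate()
        .map(|(i, (&enabled, &fallback))| if enabled { src[i] } else { fallback })
        .collect()
}

fn masked_store_model(mask: &[bool], dst: &mut [i32], values: &[i32]) {
    for (i, &enabled) in mask.iter().enumerate() {
        if enabled {
            dst[i] = values[i];
        }
    }
}
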
diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs
index cc7e78b9c62..f16014e1361 100644
--- a/compiler/rustc_codegen_llvm/src/intrinsic.rs
+++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs
@@ -1492,6 +1492,198 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
         return Ok(v);
     }
 
+    if name == sym::simd_masked_load {
+        // simd_masked_load(mask: <N x i{M}>, pointer: *_ T, values: <N x T>) -> <N x T>
+        // * N: number of elements in the input vectors
+        // * T: type of the element to load
+        // * M: any integer width is supported, will be truncated to i1
+        // Loads contiguous elements from memory behind `pointer`, but only for
+        // those lanes whose `mask` bit is enabled.
+        // The memory addresses corresponding to the “off” lanes are not accessed.
+
+        // The element type of the "mask" argument must be a signed integer type of any width
+        let mask_ty = in_ty;
+        let (mask_len, mask_elem) = (in_len, in_elem);
+
+        // The second argument must be a pointer matching the element type
+        let pointer_ty = arg_tys[1];
+
+        // The last argument is a passthrough vector providing values for disabled lanes
+        let values_ty = arg_tys[2];
+        let (values_len, values_elem) = require_simd!(values_ty, SimdThird);
+
+        require_simd!(ret_ty, SimdReturn);
+
+        // Of the same length:
+        require!(
+            values_len == mask_len,
+            InvalidMonomorphization::ThirdArgumentLength {
+                span,
+                name,
+                in_len: mask_len,
+                in_ty: mask_ty,
+                arg_ty: values_ty,
+                out_len: values_len
+            }
+        );
+
+        // The return type must match the last argument type
+        require!(
+            ret_ty == values_ty,
+            InvalidMonomorphization::ExpectedReturnType { span, name, in_ty: values_ty, ret_ty }
+        );
+
+        require!(
+            matches!(
+                pointer_ty.kind(),
+                ty::RawPtr(p) if p.ty == values_elem && p.ty.kind() == values_elem.kind()
+            ),
+            InvalidMonomorphization::ExpectedElementType {
+                span,
+                name,
+                expected_element: values_elem,
+                second_arg: pointer_ty,
+                in_elem: values_elem,
+                in_ty: values_ty,
+                mutability: ExpectedPointerMutability::Not,
+            }
+        );
+
+        require!(
+            matches!(mask_elem.kind(), ty::Int(_)),
+            InvalidMonomorphization::ThirdArgElementType {
+                span,
+                name,
+                expected_element: values_elem,
+                third_arg: mask_ty,
+            }
+        );
+
+        // Alignment of T, must be a constant integer value:
+        let alignment_ty = bx.type_i32();
+        let alignment = bx.const_i32(bx.align_of(values_ty).bytes() as i32);
+
+        // Truncate the mask vector to a vector of i1s:
+        let (mask, mask_ty) = {
+            let i1 = bx.type_i1();
+            let i1xn = bx.type_vector(i1, mask_len);
+            (bx.trunc(args[0].immediate(), i1xn), i1xn)
+        };
+
+        let llvm_pointer = bx.type_ptr();
+
+        // Type of the vector of elements:
+        let llvm_elem_vec_ty = llvm_vector_ty(bx, values_elem, values_len);
+        let llvm_elem_vec_str = llvm_vector_str(bx, values_elem, values_len);
+
+        let llvm_intrinsic = format!("llvm.masked.load.{llvm_elem_vec_str}.p0");
+        let fn_ty = bx
+            .type_func(&[llvm_pointer, alignment_ty, mask_ty, llvm_elem_vec_ty], llvm_elem_vec_ty);
+        let f = bx.declare_cfn(&llvm_intrinsic, llvm::UnnamedAddr::No, fn_ty);
+        let v = bx.call(
+            fn_ty,
+            None,
+            None,
+            f,
+            &[args[1].immediate(), alignment, mask, args[2].immediate()],
+            None,
+        );
+        return Ok(v);
+    }
+
+    if name == sym::simd_masked_store {
+        // simd_masked_store(mask: <N x i{M}>, pointer: *mut T, values: <N x T>) -> ()
+        // * N: number of elements in the input vectors
+        // * T: type of the element to store
+        // * M: any integer width is supported, will be truncated to i1
+        // Stores contiguous elements to memory behind `pointer`, but only for
+        // those lanes whose `mask` bit is enabled.
+        // The memory addresses corresponding to the “off” lanes are not accessed.
+
+        // The element type of the "mask" argument must be a signed integer type of any width
+        let mask_ty = in_ty;
+        let (mask_len, mask_elem) = (in_len, in_elem);
+
+        // The second argument must be a pointer matching the element type
+        let pointer_ty = arg_tys[1];
+
+        // The last argument specifies the values to store to memory
+        let values_ty = arg_tys[2];
+        let (values_len, values_elem) = require_simd!(values_ty, SimdThird);
+
+        // Of the same length:
+        require!(
+            values_len == mask_len,
+            InvalidMonomorphization::ThirdArgumentLength {
+                span,
+                name,
+                in_len: mask_len,
+                in_ty: mask_ty,
+                arg_ty: values_ty,
+                out_len: values_len
+            }
+        );
+
+        // The second argument must be a mutable pointer type matching the element type
+        require!(
+            matches!(
+                pointer_ty.kind(),
+                ty::RawPtr(p) if p.ty == values_elem && p.ty.kind() == values_elem.kind() && p.mutbl.is_mut()
+            ),
+            InvalidMonomorphization::ExpectedElementType {
+                span,
+                name,
+                expected_element: values_elem,
+                second_arg: pointer_ty,
+                in_elem: values_elem,
+                in_ty: values_ty,
+                mutability: ExpectedPointerMutability::Mut,
+            }
+        );
+
+        require!(
+            matches!(mask_elem.kind(), ty::Int(_)),
+            InvalidMonomorphization::ThirdArgElementType {
+                span,
+                name,
+                expected_element: values_elem,
+                third_arg: mask_ty,
+            }
+        );
+
+        // Alignment of T, must be a constant integer value:
+        let alignment_ty = bx.type_i32();
+        let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
+
+        // Truncate the mask vector to a vector of i1s:
+        let (mask, mask_ty) = {
+            let i1 = bx.type_i1();
+            let i1xn = bx.type_vector(i1, in_len);
+            (bx.trunc(args[0].immediate(), i1xn), i1xn)
+        };
+
+        let ret_t = bx.type_void();
+
+        let llvm_pointer = bx.type_ptr();
+
+        // Type of the vector of elements:
+        let llvm_elem_vec_ty = llvm_vector_ty(bx, values_elem, values_len);
+        let llvm_elem_vec_str = llvm_vector_str(bx, values_elem, values_len);
+
+        let llvm_intrinsic = format!("llvm.masked.store.{llvm_elem_vec_str}.p0");
+        let fn_ty = bx.type_func(&[llvm_elem_vec_ty, llvm_pointer, alignment_ty, mask_ty], ret_t);
+        let f = bx.declare_cfn(&llvm_intrinsic, llvm::UnnamedAddr::No, fn_ty);
+        let v = bx.call(
+            fn_ty,
+            None,
+            None,
+            f,
+            &[args[2].immediate(), args[1].immediate(), alignment, mask],
+            None,
+        );
+        return Ok(v);
+    }
+
     if name == sym::simd_scatter {
         // simd_scatter(values: <N x T>, pointers: <N x *mut T>,
         //             mask: <N x i{M}>) -> ()
diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs
index 7ea21b24fc8..33337190562 100644
--- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs
+++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs
@@ -521,6 +521,8 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
         sym::simd_fpowi => (1, 0, vec![param(0), tcx.types.i32], param(0)),
         sym::simd_fma => (1, 0, vec![param(0), param(0), param(0)], param(0)),
         sym::simd_gather => (3, 0, vec![param(0), param(1), param(2)], param(0)),
+        sym::simd_masked_load => (3, 0, vec![param(0), param(1), param(2)], param(2)),
+        sym::simd_masked_store => (3, 0, vec![param(0), param(1), param(2)], Ty::new_unit(tcx)),
         sym::simd_scatter => (3, 0, vec![param(0), param(1), param(2)], Ty::new_unit(tcx)),
         sym::simd_insert => (2, 0, vec![param(0), tcx.types.u32, param(1)], param(0)),
         sym::simd_extract => (2, 0, vec![param(0), tcx.types.u32], param(1)),
diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs
index 7b9b7b85293..485265e6889 100644
--- a/compiler/rustc_span/src/symbol.rs
+++ b/compiler/rustc_span/src/symbol.rs
@@ -1516,6 +1516,8 @@ symbols! {
         simd_insert,
         simd_le,
         simd_lt,
+        simd_masked_load,
+        simd_masked_store,
         simd_mul,
         simd_ne,
         simd_neg,