about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src/intrinsic.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_codegen_llvm/src/intrinsic.rs')
-rw-r--r--compiler/rustc_codegen_llvm/src/intrinsic.rs197
1 files changed, 104 insertions, 93 deletions
diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs
index 91b44b084b0..eab4a9f30c9 100644
--- a/compiler/rustc_codegen_llvm/src/intrinsic.rs
+++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs
@@ -1182,6 +1182,60 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
         }};
     }
 
+    /// Returns the bitwidth of the `$ty` argument if it is an `Int` type.
+    macro_rules! require_int_ty {
+        ($ty: expr, $diag: expr) => {
+            match $ty {
+                ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
+                _ => {
+                    return_error!($diag);
+                }
+            }
+        };
+    }
+
+    /// Returns the bitwidth of the `$ty` argument if it is an `Int` or `Uint` type.
+    macro_rules! require_int_or_uint_ty {
+        ($ty: expr, $diag: expr) => {
+            match $ty {
+                ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
+                ty::Uint(i) => {
+                    i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
+                }
+                _ => {
+                    return_error!($diag);
+                }
+            }
+        };
+    }
+
+    /// Converts a vector mask, where each element has a bit width equal to the data elements it is used with,
+    /// down to an i1 based mask that can be used by llvm intrinsics.
+    ///
+    /// The rust simd semantics are that each element should either consist of all ones or all zeroes,
+    /// but this information is not available to llvm. Truncating the vector effectively uses the lowest bit,
+    /// but codegen for several targets is better if we consider the highest bit by shifting.
+    ///
+    /// For x86 SSE/AVX targets this is beneficial since most instructions with mask parameters only consider the highest bit.
+    /// So even though on llvm level we have an additional shift, in the final assembly there is no shift or truncate and
+    /// instead the mask can be used as is.
+    ///
+    /// For aarch64 and other targets there is a benefit because a mask from the sign bit can be more
+    /// efficiently converted to an all ones / all zeroes mask by comparing whether each element is negative.
+    fn vector_mask_to_bitmask<'a, 'll, 'tcx>(
+        bx: &mut Builder<'a, 'll, 'tcx>,
+        i_xn: &'ll Value,
+        in_elem_bitwidth: u64,
+        in_len: u64,
+    ) -> &'ll Value {
+        // Shift the MSB to the right by "in_elem_bitwidth - 1" into the first bit position.
+        let shift_idx = bx.cx.const_int(bx.type_ix(in_elem_bitwidth), (in_elem_bitwidth - 1) as _);
+        let shift_indices = vec![shift_idx; in_len as _];
+        let i_xn_msb = bx.lshr(i_xn, bx.const_vector(shift_indices.as_slice()));
+        // Truncate vector to an <i1 x N>
+        bx.trunc(i_xn_msb, bx.type_vector(bx.type_i1(), in_len))
+    }
+
     let tcx = bx.tcx();
     let sig = tcx.normalize_erasing_late_bound_regions(bx.typing_env(), callee_ty.fn_sig(tcx));
     let arg_tys = sig.inputs();
@@ -1433,14 +1487,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             m_len,
             v_len
         });
-        match m_elem_ty.kind() {
-            ty::Int(_) => {}
-            _ => return_error!(InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }),
-        }
-        // truncate the mask to a vector of i1s
-        let i1 = bx.type_i1();
-        let i1xn = bx.type_vector(i1, m_len as u64);
-        let m_i1s = bx.trunc(args[0].immediate(), i1xn);
+        let in_elem_bitwidth =
+            require_int_ty!(m_elem_ty.kind(), InvalidMonomorphization::MaskType {
+                span,
+                name,
+                ty: m_elem_ty
+            });
+        let m_i1s = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, m_len);
         return Ok(bx.select(m_i1s, args[1].immediate(), args[2].immediate()));
     }
 
@@ -1457,33 +1510,15 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
         let expected_bytes = in_len.div_ceil(8);
 
         // Integer vector <i{in_bitwidth} x in_len>:
-        let (i_xn, in_elem_bitwidth) = match in_elem.kind() {
-            ty::Int(i) => (
-                args[0].immediate(),
-                i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
-            ),
-            ty::Uint(i) => (
-                args[0].immediate(),
-                i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
-            ),
-            _ => return_error!(InvalidMonomorphization::VectorArgument {
+        let in_elem_bitwidth =
+            require_int_or_uint_ty!(in_elem.kind(), InvalidMonomorphization::VectorArgument {
                 span,
                 name,
                 in_ty,
                 in_elem
-            }),
-        };
+            });
 
-        // LLVM doesn't always know the inputs are `0` or `!0`, so we shift here so it optimizes to
-        // `pmovmskb` and similar on x86.
-        let shift_indices =
-            vec![
-                bx.cx.const_int(bx.type_ix(in_elem_bitwidth), (in_elem_bitwidth - 1) as _);
-                in_len as _
-            ];
-        let i_xn_msb = bx.lshr(i_xn, bx.const_vector(shift_indices.as_slice()));
-        // Truncate vector to an <i1 x N>
-        let i1xn = bx.trunc(i_xn_msb, bx.type_vector(bx.type_i1(), in_len));
+        let i1xn = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, in_len);
         // Bitcast <i1 x N> to iN:
         let i_ = bx.bitcast(i1xn, bx.type_ix(in_len));
 
@@ -1704,28 +1739,21 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );
 
-        match element_ty2.kind() {
-            ty::Int(_) => (),
-            _ => {
-                return_error!(InvalidMonomorphization::ThirdArgElementType {
-                    span,
-                    name,
-                    expected_element: element_ty2,
-                    third_arg: arg_tys[2]
-                });
-            }
-        }
+        let mask_elem_bitwidth =
+            require_int_ty!(element_ty2.kind(), InvalidMonomorphization::ThirdArgElementType {
+                span,
+                name,
+                expected_element: element_ty2,
+                third_arg: arg_tys[2]
+            });
 
         // Alignment of T, must be a constant integer value:
         let alignment_ty = bx.type_i32();
         let alignment = bx.const_i32(bx.align_of(in_elem).bytes() as i32);
 
         // Truncate the mask vector to a vector of i1s:
-        let (mask, mask_ty) = {
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, in_len);
-            (bx.trunc(args[2].immediate(), i1xn), i1xn)
-        };
+        let mask = vector_mask_to_bitmask(bx, args[2].immediate(), mask_elem_bitwidth, in_len);
+        let mask_ty = bx.type_vector(bx.type_i1(), in_len);
 
         // Type of the vector of pointers:
         let llvm_pointer_vec_ty = llvm_vector_ty(bx, element_ty1, in_len);
@@ -1810,27 +1838,21 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );
 
-        require!(
-            matches!(mask_elem.kind(), ty::Int(_)),
-            InvalidMonomorphization::ThirdArgElementType {
+        let m_elem_bitwidth =
+            require_int_ty!(mask_elem.kind(), InvalidMonomorphization::ThirdArgElementType {
                 span,
                 name,
                 expected_element: values_elem,
                 third_arg: mask_ty,
-            }
-        );
+            });
+
+        let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
+        let mask_ty = bx.type_vector(bx.type_i1(), mask_len);
 
         // Alignment of T, must be a constant integer value:
         let alignment_ty = bx.type_i32();
         let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
 
-        // Truncate the mask vector to a vector of i1s:
-        let (mask, mask_ty) = {
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, mask_len);
-            (bx.trunc(args[0].immediate(), i1xn), i1xn)
-        };
-
         let llvm_pointer = bx.type_ptr();
 
         // Type of the vector of elements:
@@ -1901,27 +1923,21 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );
 
-        require!(
-            matches!(mask_elem.kind(), ty::Int(_)),
-            InvalidMonomorphization::ThirdArgElementType {
+        let m_elem_bitwidth =
+            require_int_ty!(mask_elem.kind(), InvalidMonomorphization::ThirdArgElementType {
                 span,
                 name,
                 expected_element: values_elem,
                 third_arg: mask_ty,
-            }
-        );
+            });
+
+        let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
+        let mask_ty = bx.type_vector(bx.type_i1(), mask_len);
 
         // Alignment of T, must be a constant integer value:
         let alignment_ty = bx.type_i32();
         let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
 
-        // Truncate the mask vector to a vector of i1s:
-        let (mask, mask_ty) = {
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, in_len);
-            (bx.trunc(args[0].immediate(), i1xn), i1xn)
-        };
-
         let ret_t = bx.type_void();
 
         let llvm_pointer = bx.type_ptr();
@@ -1995,28 +2011,21 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
         );
 
         // The element type of the third argument must be a signed integer type of any width:
-        match element_ty2.kind() {
-            ty::Int(_) => (),
-            _ => {
-                return_error!(InvalidMonomorphization::ThirdArgElementType {
-                    span,
-                    name,
-                    expected_element: element_ty2,
-                    third_arg: arg_tys[2]
-                });
-            }
-        }
+        let mask_elem_bitwidth =
+            require_int_ty!(element_ty2.kind(), InvalidMonomorphization::ThirdArgElementType {
+                span,
+                name,
+                expected_element: element_ty2,
+                third_arg: arg_tys[2]
+            });
 
         // Alignment of T, must be a constant integer value:
         let alignment_ty = bx.type_i32();
         let alignment = bx.const_i32(bx.align_of(in_elem).bytes() as i32);
 
         // Truncate the mask vector to a vector of i1s:
-        let (mask, mask_ty) = {
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, in_len);
-            (bx.trunc(args[2].immediate(), i1xn), i1xn)
-        };
+        let mask = vector_mask_to_bitmask(bx, args[2].immediate(), mask_elem_bitwidth, in_len);
+        let mask_ty = bx.type_vector(bx.type_i1(), in_len);
 
         let ret_t = bx.type_void();
 
@@ -2164,8 +2173,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
                     });
                     args[0].immediate()
                 } else {
-                    match in_elem.kind() {
-                        ty::Int(_) | ty::Uint(_) => {}
+                    let bitwidth = match in_elem.kind() {
+                        ty::Int(i) => {
+                            i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
+                        }
+                        ty::Uint(i) => {
+                            i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
+                        }
                         _ => return_error!(InvalidMonomorphization::UnsupportedSymbol {
                             span,
                             name,
@@ -2174,12 +2188,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
                             in_elem,
                             ret_ty
                         }),
-                    }
+                    };
 
-                    // boolean reductions operate on vectors of i1s:
-                    let i1 = bx.type_i1();
-                    let i1xn = bx.type_vector(i1, in_len as u64);
-                    bx.trunc(args[0].immediate(), i1xn)
+                    vector_mask_to_bitmask(bx, args[0].immediate(), bitwidth, in_len as _)
                 };
                 return match in_elem.kind() {
                     ty::Int(_) | ty::Uint(_) => {