about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src/asm.rs
diff options
context:
space:
mode:
authorbeetrees <b@beetr.ee>2024-06-15 22:30:25 +0100
committerbeetrees <b@beetr.ee>2024-06-21 18:48:20 +0100
commit771e44ebd32df8e342efa1f246f5d5070af04ec4 (patch)
tree3f86a9e6e35a9a580a773337831d7fe8bbb6b06d /compiler/rustc_codegen_llvm/src/asm.rs
parent92af831290cf60434aa44ba7c6a5171ec48e98be (diff)
downloadrust-771e44ebd32df8e342efa1f246f5d5070af04ec4.tar.gz
rust-771e44ebd32df8e342efa1f246f5d5070af04ec4.zip
Add `f16` inline ASM support for RISC-V
Diffstat (limited to 'compiler/rustc_codegen_llvm/src/asm.rs')
-rw-r--r--compiler/rustc_codegen_llvm/src/asm.rs55
1 files changed, 49 insertions, 6 deletions
diff --git a/compiler/rustc_codegen_llvm/src/asm.rs b/compiler/rustc_codegen_llvm/src/asm.rs
index 60e63b956db..34a0f9973f6 100644
--- a/compiler/rustc_codegen_llvm/src/asm.rs
+++ b/compiler/rustc_codegen_llvm/src/asm.rs
@@ -13,7 +13,7 @@ use rustc_codegen_ssa::traits::*;
 use rustc_data_structures::fx::FxHashMap;
 use rustc_middle::ty::layout::TyAndLayout;
 use rustc_middle::{bug, span_bug, ty::Instance};
-use rustc_span::{Pos, Span};
+use rustc_span::{sym, Pos, Span, Symbol};
 use rustc_target::abi::*;
 use rustc_target::asm::*;
 use tracing::debug;
@@ -64,7 +64,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
                     let mut layout = None;
                     let ty = if let Some(ref place) = place {
                         layout = Some(&place.layout);
-                        llvm_fixup_output_type(self.cx, reg.reg_class(), &place.layout)
+                        llvm_fixup_output_type(self.cx, reg.reg_class(), &place.layout, instance)
                     } else if matches!(
                         reg.reg_class(),
                         InlineAsmRegClass::X86(
@@ -112,7 +112,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
                         // so we just use the type of the input.
                         &in_value.layout
                     };
-                    let ty = llvm_fixup_output_type(self.cx, reg.reg_class(), layout);
+                    let ty = llvm_fixup_output_type(self.cx, reg.reg_class(), layout, instance);
                     output_types.push(ty);
                     op_idx.insert(idx, constraints.len());
                     let prefix = if late { "=" } else { "=&" };
@@ -127,8 +127,13 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
         for (idx, op) in operands.iter().enumerate() {
             match *op {
                 InlineAsmOperandRef::In { reg, value } => {
-                    let llval =
-                        llvm_fixup_input(self, value.immediate(), reg.reg_class(), &value.layout);
+                    let llval = llvm_fixup_input(
+                        self,
+                        value.immediate(),
+                        reg.reg_class(),
+                        &value.layout,
+                        instance,
+                    );
                     inputs.push(llval);
                     op_idx.insert(idx, constraints.len());
                     constraints.push(reg_to_llvm(reg, Some(&value.layout)));
@@ -139,6 +144,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
                         in_value.immediate(),
                         reg.reg_class(),
                         &in_value.layout,
+                        instance,
                     );
                     inputs.push(value);
 
@@ -341,7 +347,8 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
                 } else {
                     self.extract_value(result, op_idx[&idx] as u64)
                 };
-                let value = llvm_fixup_output(self, value, reg.reg_class(), &place.layout);
+                let value =
+                    llvm_fixup_output(self, value, reg.reg_class(), &place.layout, instance);
                 OperandValue::Immediate(value).store(self, place);
             }
         }
@@ -913,12 +920,22 @@ fn llvm_asm_scalar_type<'ll>(cx: &CodegenCx<'ll, '_>, scalar: Scalar) -> &'ll Ty
     }
 }
 
+fn any_target_feature_enabled(
+    cx: &CodegenCx<'_, '_>,
+    instance: Instance<'_>,
+    features: &[Symbol],
+) -> bool {
+    let enabled = cx.tcx.asm_target_features(instance.def_id());
+    features.iter().any(|feat| enabled.contains(feat))
+}
+
 /// Fix up an input value to work around LLVM bugs.
 fn llvm_fixup_input<'ll, 'tcx>(
     bx: &mut Builder<'_, 'll, 'tcx>,
     mut value: &'ll Value,
     reg: InlineAsmRegClass,
     layout: &TyAndLayout<'tcx>,
+    instance: Instance<'_>,
 ) -> &'ll Value {
     let dl = &bx.tcx.data_layout;
     match (reg, layout.abi) {
@@ -1029,6 +1046,16 @@ fn llvm_fixup_input<'ll, 'tcx>(
                 _ => value,
             }
         }
+        (InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s))
+            if s.primitive() == Primitive::Float(Float::F16)
+                && !any_target_feature_enabled(bx, instance, &[sym::zfhmin, sym::zfh]) =>
+        {
+            // Smaller floats are always "NaN-boxed" inside larger floats on RISC-V.
+            let value = bx.bitcast(value, bx.type_i16());
+            let value = bx.zext(value, bx.type_i32());
+            let value = bx.or(value, bx.const_u32(0xFFFF_0000));
+            bx.bitcast(value, bx.type_f32())
+        }
         _ => value,
     }
 }
@@ -1039,6 +1066,7 @@ fn llvm_fixup_output<'ll, 'tcx>(
     mut value: &'ll Value,
     reg: InlineAsmRegClass,
     layout: &TyAndLayout<'tcx>,
+    instance: Instance<'_>,
 ) -> &'ll Value {
     match (reg, layout.abi) {
         (InlineAsmRegClass::AArch64(AArch64InlineAsmRegClass::vreg), Abi::Scalar(s)) => {
@@ -1140,6 +1168,14 @@ fn llvm_fixup_output<'ll, 'tcx>(
                 _ => value,
             }
         }
+        (InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s))
+            if s.primitive() == Primitive::Float(Float::F16)
+                && !any_target_feature_enabled(bx, instance, &[sym::zfhmin, sym::zfh]) =>
+        {
+            let value = bx.bitcast(value, bx.type_i32());
+            let value = bx.trunc(value, bx.type_i16());
+            bx.bitcast(value, bx.type_f16())
+        }
         _ => value,
     }
 }
@@ -1149,6 +1185,7 @@ fn llvm_fixup_output_type<'ll, 'tcx>(
     cx: &CodegenCx<'ll, 'tcx>,
     reg: InlineAsmRegClass,
     layout: &TyAndLayout<'tcx>,
+    instance: Instance<'_>,
 ) -> &'ll Type {
     match (reg, layout.abi) {
         (InlineAsmRegClass::AArch64(AArch64InlineAsmRegClass::vreg), Abi::Scalar(s)) => {
@@ -1242,6 +1279,12 @@ fn llvm_fixup_output_type<'ll, 'tcx>(
                 _ => layout.llvm_type(cx),
             }
         }
+        (InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s))
+            if s.primitive() == Primitive::Float(Float::F16)
+                && !any_target_feature_enabled(cx, instance, &[sym::zfhmin, sym::zfh]) =>
+        {
+            cx.type_f32()
+        }
         _ => layout.llvm_type(cx),
     }
 }