about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbeetrees <b@beetr.ee>2024-06-15 22:30:25 +0100
committerbeetrees <b@beetr.ee>2024-06-21 18:48:20 +0100
commit771e44ebd32df8e342efa1f246f5d5070af04ec4 (patch)
tree3f86a9e6e35a9a580a773337831d7fe8bbb6b06d
parent92af831290cf60434aa44ba7c6a5171ec48e98be (diff)
downloadrust-771e44ebd32df8e342efa1f246f5d5070af04ec4.tar.gz
rust-771e44ebd32df8e342efa1f246f5d5070af04ec4.zip
Add `f16` inline ASM support for RISC-V
-rw-r--r--compiler/rustc_codegen_llvm/src/asm.rs55
-rw-r--r--compiler/rustc_span/src/symbol.rs2
-rw-r--r--compiler/rustc_target/src/asm/riscv.rs7
-rw-r--r--tests/assembly/asm/riscv-types.rs55
4 files changed, 108 insertions, 11 deletions
diff --git a/compiler/rustc_codegen_llvm/src/asm.rs b/compiler/rustc_codegen_llvm/src/asm.rs
index 60e63b956db..34a0f9973f6 100644
--- a/compiler/rustc_codegen_llvm/src/asm.rs
+++ b/compiler/rustc_codegen_llvm/src/asm.rs
@@ -13,7 +13,7 @@ use rustc_codegen_ssa::traits::*;
 use rustc_data_structures::fx::FxHashMap;
 use rustc_middle::ty::layout::TyAndLayout;
 use rustc_middle::{bug, span_bug, ty::Instance};
-use rustc_span::{Pos, Span};
+use rustc_span::{sym, Pos, Span, Symbol};
 use rustc_target::abi::*;
 use rustc_target::asm::*;
 use tracing::debug;
@@ -64,7 +64,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
                     let mut layout = None;
                     let ty = if let Some(ref place) = place {
                         layout = Some(&place.layout);
-                        llvm_fixup_output_type(self.cx, reg.reg_class(), &place.layout)
+                        llvm_fixup_output_type(self.cx, reg.reg_class(), &place.layout, instance)
                     } else if matches!(
                         reg.reg_class(),
                         InlineAsmRegClass::X86(
@@ -112,7 +112,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
                         // so we just use the type of the input.
                         &in_value.layout
                     };
-                    let ty = llvm_fixup_output_type(self.cx, reg.reg_class(), layout);
+                    let ty = llvm_fixup_output_type(self.cx, reg.reg_class(), layout, instance);
                     output_types.push(ty);
                     op_idx.insert(idx, constraints.len());
                     let prefix = if late { "=" } else { "=&" };
@@ -127,8 +127,13 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
         for (idx, op) in operands.iter().enumerate() {
             match *op {
                 InlineAsmOperandRef::In { reg, value } => {
-                    let llval =
-                        llvm_fixup_input(self, value.immediate(), reg.reg_class(), &value.layout);
+                    let llval = llvm_fixup_input(
+                        self,
+                        value.immediate(),
+                        reg.reg_class(),
+                        &value.layout,
+                        instance,
+                    );
                     inputs.push(llval);
                     op_idx.insert(idx, constraints.len());
                     constraints.push(reg_to_llvm(reg, Some(&value.layout)));
@@ -139,6 +144,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
                         in_value.immediate(),
                         reg.reg_class(),
                         &in_value.layout,
+                        instance,
                     );
                     inputs.push(value);
 
@@ -341,7 +347,8 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
                 } else {
                     self.extract_value(result, op_idx[&idx] as u64)
                 };
-                let value = llvm_fixup_output(self, value, reg.reg_class(), &place.layout);
+                let value =
+                    llvm_fixup_output(self, value, reg.reg_class(), &place.layout, instance);
                 OperandValue::Immediate(value).store(self, place);
             }
         }
@@ -913,12 +920,22 @@ fn llvm_asm_scalar_type<'ll>(cx: &CodegenCx<'ll, '_>, scalar: Scalar) -> &'ll Ty
     }
 }
 
+fn any_target_feature_enabled(
+    cx: &CodegenCx<'_, '_>,
+    instance: Instance<'_>,
+    features: &[Symbol],
+) -> bool {
+    let enabled = cx.tcx.asm_target_features(instance.def_id());
+    features.iter().any(|feat| enabled.contains(feat))
+}
+
 /// Fix up an input value to work around LLVM bugs.
 fn llvm_fixup_input<'ll, 'tcx>(
     bx: &mut Builder<'_, 'll, 'tcx>,
     mut value: &'ll Value,
     reg: InlineAsmRegClass,
     layout: &TyAndLayout<'tcx>,
+    instance: Instance<'_>,
 ) -> &'ll Value {
     let dl = &bx.tcx.data_layout;
     match (reg, layout.abi) {
@@ -1029,6 +1046,16 @@ fn llvm_fixup_input<'ll, 'tcx>(
                 _ => value,
             }
         }
+        (InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s))
+            if s.primitive() == Primitive::Float(Float::F16)
+                && !any_target_feature_enabled(bx, instance, &[sym::zfhmin, sym::zfh]) =>
+        {
+            // Smaller floats are always "NaN-boxed" inside larger floats on RISC-V.
+            let value = bx.bitcast(value, bx.type_i16());
+            let value = bx.zext(value, bx.type_i32());
+            let value = bx.or(value, bx.const_u32(0xFFFF_0000));
+            bx.bitcast(value, bx.type_f32())
+        }
         _ => value,
     }
 }
@@ -1039,6 +1066,7 @@ fn llvm_fixup_output<'ll, 'tcx>(
     mut value: &'ll Value,
     reg: InlineAsmRegClass,
     layout: &TyAndLayout<'tcx>,
+    instance: Instance<'_>,
 ) -> &'ll Value {
     match (reg, layout.abi) {
         (InlineAsmRegClass::AArch64(AArch64InlineAsmRegClass::vreg), Abi::Scalar(s)) => {
@@ -1140,6 +1168,14 @@ fn llvm_fixup_output<'ll, 'tcx>(
                 _ => value,
             }
         }
+        (InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s))
+            if s.primitive() == Primitive::Float(Float::F16)
+                && !any_target_feature_enabled(bx, instance, &[sym::zfhmin, sym::zfh]) =>
+        {
+            let value = bx.bitcast(value, bx.type_i32());
+            let value = bx.trunc(value, bx.type_i16());
+            bx.bitcast(value, bx.type_f16())
+        }
         _ => value,
     }
 }
@@ -1149,6 +1185,7 @@ fn llvm_fixup_output_type<'ll, 'tcx>(
     cx: &CodegenCx<'ll, 'tcx>,
     reg: InlineAsmRegClass,
     layout: &TyAndLayout<'tcx>,
+    instance: Instance<'_>,
 ) -> &'ll Type {
     match (reg, layout.abi) {
         (InlineAsmRegClass::AArch64(AArch64InlineAsmRegClass::vreg), Abi::Scalar(s)) => {
@@ -1242,6 +1279,12 @@ fn llvm_fixup_output_type<'ll, 'tcx>(
                 _ => layout.llvm_type(cx),
             }
         }
+        (InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s))
+            if s.primitive() == Primitive::Float(Float::F16)
+                && !any_target_feature_enabled(cx, instance, &[sym::zfhmin, sym::zfh]) =>
+        {
+            cx.type_f32()
+        }
         _ => layout.llvm_type(cx),
     }
 }
diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs
index f44fa1bcb4f..018442af144 100644
--- a/compiler/rustc_span/src/symbol.rs
+++ b/compiler/rustc_span/src/symbol.rs
@@ -2054,6 +2054,8 @@ symbols! {
         yes,
         yield_expr,
         ymm_reg,
+        zfh,
+        zfhmin,
         zmm_reg,
     }
 }
diff --git a/compiler/rustc_target/src/asm/riscv.rs b/compiler/rustc_target/src/asm/riscv.rs
index 3845a0e14af..02a4a5e2ece 100644
--- a/compiler/rustc_target/src/asm/riscv.rs
+++ b/compiler/rustc_target/src/asm/riscv.rs
@@ -40,12 +40,13 @@ impl RiscVInlineAsmRegClass {
         match self {
             Self::reg => {
                 if arch == InlineAsmArch::RiscV64 {
-                    types! { _: I8, I16, I32, I64, F32, F64; }
+                    types! { _: I8, I16, I32, I64, F16, F32, F64; }
                 } else {
-                    types! { _: I8, I16, I32, F32; }
+                    types! { _: I8, I16, I32, F16, F32; }
                 }
             }
-            Self::freg => types! { f: F32; d: F64; },
+            // FIXME(f16_f128): Add `q: F128;` once LLVM support the `Q` extension.
+            Self::freg => types! { f: F16, F32; d: F64; },
             Self::vreg => &[],
         }
     }
diff --git a/tests/assembly/asm/riscv-types.rs b/tests/assembly/asm/riscv-types.rs
index 0d1f8305d37..51b3aaf99d9 100644
--- a/tests/assembly/asm/riscv-types.rs
+++ b/tests/assembly/asm/riscv-types.rs
@@ -1,12 +1,34 @@
-//@ revisions: riscv64 riscv32
+//@ revisions: riscv64 riscv32 riscv64-zfhmin riscv32-zfhmin riscv64-zfh riscv32-zfh
 //@ assembly-output: emit-asm
+
 //@[riscv64] compile-flags: --target riscv64imac-unknown-none-elf
 //@[riscv64] needs-llvm-components: riscv
+
 //@[riscv32] compile-flags: --target riscv32imac-unknown-none-elf
 //@[riscv32] needs-llvm-components: riscv
+
+//@[riscv64-zfhmin] compile-flags: --target riscv64imac-unknown-none-elf --cfg riscv64
+//@[riscv64-zfhmin] needs-llvm-components: riscv
+//@[riscv64-zfhmin] compile-flags: -C target-feature=+zfhmin
+//@[riscv64-zfhmin] filecheck-flags: --check-prefix riscv64
+
+//@[riscv32-zfhmin] compile-flags: --target riscv32imac-unknown-none-elf
+//@[riscv32-zfhmin] needs-llvm-components: riscv
+//@[riscv32-zfhmin] compile-flags: -C target-feature=+zfhmin
+
+//@[riscv64-zfh] compile-flags: --target riscv64imac-unknown-none-elf --cfg riscv64
+//@[riscv64-zfh] needs-llvm-components: riscv
+//@[riscv64-zfh] compile-flags: -C target-feature=+zfh
+//@[riscv64-zfh] filecheck-flags: --check-prefix riscv64 --check-prefix zfhmin
+
+//@[riscv32-zfh] compile-flags: --target riscv32imac-unknown-none-elf
+//@[riscv32-zfh] needs-llvm-components: riscv
+//@[riscv32-zfh] compile-flags: -C target-feature=+zfh
+//@[riscv32-zfh] filecheck-flags: --check-prefix zfhmin
+
 //@ compile-flags: -C target-feature=+d
 
-#![feature(no_core, lang_items, rustc_attrs)]
+#![feature(no_core, lang_items, rustc_attrs, f16)]
 #![crate_type = "rlib"]
 #![no_core]
 #![allow(asm_sub_register)]
@@ -33,6 +55,7 @@ type ptr = *mut u8;
 
 impl Copy for i8 {}
 impl Copy for i16 {}
+impl Copy for f16 {}
 impl Copy for i32 {}
 impl Copy for f32 {}
 impl Copy for i64 {}
@@ -103,6 +126,12 @@ macro_rules! check_reg {
 // CHECK: #NO_APP
 check!(reg_i8 i8 reg "mv");
 
+// CHECK-LABEL: reg_f16:
+// CHECK: #APP
+// CHECK: mv {{[a-z0-9]+}}, {{[a-z0-9]+}}
+// CHECK: #NO_APP
+check!(reg_f16 f16 reg "mv");
+
 // CHECK-LABEL: reg_i16:
 // CHECK: #APP
 // CHECK: mv {{[a-z0-9]+}}, {{[a-z0-9]+}}
@@ -141,6 +170,14 @@ check!(reg_f64 f64 reg "mv");
 // CHECK: #NO_APP
 check!(reg_ptr ptr reg "mv");
 
+// CHECK-LABEL: freg_f16:
+// zfhmin-NOT: or
+// CHECK: #APP
+// CHECK: fmv.s f{{[a-z0-9]+}}, f{{[a-z0-9]+}}
+// CHECK: #NO_APP
+// zfhmin-NOT: or
+check!(freg_f16 f16 freg "fmv.s");
+
 // CHECK-LABEL: freg_f32:
 // CHECK: #APP
 // CHECK: fmv.s f{{[a-z0-9]+}}, f{{[a-z0-9]+}}
@@ -165,6 +202,12 @@ check_reg!(a0_i8 i8 "a0" "mv");
 // CHECK: #NO_APP
 check_reg!(a0_i16 i16 "a0" "mv");
 
+// CHECK-LABEL: a0_f16:
+// CHECK: #APP
+// CHECK: mv a0, a0
+// CHECK: #NO_APP
+check_reg!(a0_f16 f16 "a0" "mv");
+
 // CHECK-LABEL: a0_i32:
 // CHECK: #APP
 // CHECK: mv a0, a0
@@ -197,6 +240,14 @@ check_reg!(a0_f64 f64 "a0" "mv");
 // CHECK: #NO_APP
 check_reg!(a0_ptr ptr "a0" "mv");
 
+// CHECK-LABEL: fa0_f16:
+// zfhmin-NOT: or
+// CHECK: #APP
+// CHECK: fmv.s fa0, fa0
+// CHECK: #NO_APP
+// zfhmin-NOT: or
+check_reg!(fa0_f16 f16 "fa0" "fmv.s");
+
 // CHECK-LABEL: fa0_f32:
 // CHECK: #APP
 // CHECK: fmv.s fa0, fa0