about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src/asm.rs
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2024-07-11 17:10:09 +0000
committerbors <bors@rust-lang.org>2024-07-11 17:10:09 +0000
commita3d6efc9bcdc379e766735cc53e065d85a1755e9 (patch)
tree24923a8226aad3064a6c5739540c2d404d64f76e /compiler/rustc_codegen_llvm/src/asm.rs
parent45609a995e766d614da694d353f7aee03820b0f9 (diff)
parent62bbce2ad2b20d5cf1282da407d01de5c54161f1 (diff)
downloadrust-a3d6efc9bcdc379e766735cc53e065d85a1755e9.tar.gz
rust-a3d6efc9bcdc379e766735cc53e065d85a1755e9.zip
Auto merge of #17581 - lnicola:sync-from-rust, r=lnicola
minor: Sync from rust
Diffstat (limited to 'compiler/rustc_codegen_llvm/src/asm.rs')
-rw-r--r--compiler/rustc_codegen_llvm/src/asm.rs94
1 files changed, 88 insertions, 6 deletions
diff --git a/compiler/rustc_codegen_llvm/src/asm.rs b/compiler/rustc_codegen_llvm/src/asm.rs
index 60e63b956db..597ebd97365 100644
--- a/compiler/rustc_codegen_llvm/src/asm.rs
+++ b/compiler/rustc_codegen_llvm/src/asm.rs
@@ -13,7 +13,7 @@ use rustc_codegen_ssa::traits::*;
 use rustc_data_structures::fx::FxHashMap;
 use rustc_middle::ty::layout::TyAndLayout;
 use rustc_middle::{bug, span_bug, ty::Instance};
-use rustc_span::{Pos, Span};
+use rustc_span::{sym, Pos, Span, Symbol};
 use rustc_target::abi::*;
 use rustc_target::asm::*;
 use tracing::debug;
@@ -64,7 +64,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
                     let mut layout = None;
                     let ty = if let Some(ref place) = place {
                         layout = Some(&place.layout);
-                        llvm_fixup_output_type(self.cx, reg.reg_class(), &place.layout)
+                        llvm_fixup_output_type(self.cx, reg.reg_class(), &place.layout, instance)
                     } else if matches!(
                         reg.reg_class(),
                         InlineAsmRegClass::X86(
@@ -112,7 +112,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
                         // so we just use the type of the input.
                         &in_value.layout
                     };
-                    let ty = llvm_fixup_output_type(self.cx, reg.reg_class(), layout);
+                    let ty = llvm_fixup_output_type(self.cx, reg.reg_class(), layout, instance);
                     output_types.push(ty);
                     op_idx.insert(idx, constraints.len());
                     let prefix = if late { "=" } else { "=&" };
@@ -127,8 +127,13 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
         for (idx, op) in operands.iter().enumerate() {
             match *op {
                 InlineAsmOperandRef::In { reg, value } => {
-                    let llval =
-                        llvm_fixup_input(self, value.immediate(), reg.reg_class(), &value.layout);
+                    let llval = llvm_fixup_input(
+                        self,
+                        value.immediate(),
+                        reg.reg_class(),
+                        &value.layout,
+                        instance,
+                    );
                     inputs.push(llval);
                     op_idx.insert(idx, constraints.len());
                     constraints.push(reg_to_llvm(reg, Some(&value.layout)));
@@ -139,6 +144,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
                         in_value.immediate(),
                         reg.reg_class(),
                         &in_value.layout,
+                        instance,
                     );
                     inputs.push(value);
 
@@ -341,7 +347,8 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
                 } else {
                     self.extract_value(result, op_idx[&idx] as u64)
                 };
-                let value = llvm_fixup_output(self, value, reg.reg_class(), &place.layout);
+                let value =
+                    llvm_fixup_output(self, value, reg.reg_class(), &place.layout, instance);
                 OperandValue::Immediate(value).store(self, place);
             }
         }
@@ -913,12 +920,22 @@ fn llvm_asm_scalar_type<'ll>(cx: &CodegenCx<'ll, '_>, scalar: Scalar) -> &'ll Ty
     }
 }
 
+fn any_target_feature_enabled(
+    cx: &CodegenCx<'_, '_>,
+    instance: Instance<'_>,
+    features: &[Symbol],
+) -> bool {
+    let enabled = cx.tcx.asm_target_features(instance.def_id());
+    features.iter().any(|feat| enabled.contains(feat))
+}
+
 /// Fix up an input value to work around LLVM bugs.
 fn llvm_fixup_input<'ll, 'tcx>(
     bx: &mut Builder<'_, 'll, 'tcx>,
     mut value: &'ll Value,
     reg: InlineAsmRegClass,
     layout: &TyAndLayout<'tcx>,
+    instance: Instance<'_>,
 ) -> &'ll Value {
     let dl = &bx.tcx.data_layout;
     match (reg, layout.abi) {
@@ -1020,6 +1037,19 @@ fn llvm_fixup_input<'ll, 'tcx>(
                 value
             }
         }
+        (
+            InlineAsmRegClass::Arm(
+                ArmInlineAsmRegClass::dreg
+                | ArmInlineAsmRegClass::dreg_low8
+                | ArmInlineAsmRegClass::dreg_low16
+                | ArmInlineAsmRegClass::qreg
+                | ArmInlineAsmRegClass::qreg_low4
+                | ArmInlineAsmRegClass::qreg_low8,
+            ),
+            Abi::Vector { element, count: count @ (4 | 8) },
+        ) if element.primitive() == Primitive::Float(Float::F16) => {
+            bx.bitcast(value, bx.type_vector(bx.type_i16(), count))
+        }
         (InlineAsmRegClass::Mips(MipsInlineAsmRegClass::reg), Abi::Scalar(s)) => {
             match s.primitive() {
                 // MIPS only supports register-length arithmetics.
@@ -1029,6 +1059,16 @@ fn llvm_fixup_input<'ll, 'tcx>(
                 _ => value,
             }
         }
+        (InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s))
+            if s.primitive() == Primitive::Float(Float::F16)
+                && !any_target_feature_enabled(bx, instance, &[sym::zfhmin, sym::zfh]) =>
+        {
+            // Smaller floats are always "NaN-boxed" inside larger floats on RISC-V.
+            let value = bx.bitcast(value, bx.type_i16());
+            let value = bx.zext(value, bx.type_i32());
+            let value = bx.or(value, bx.const_u32(0xFFFF_0000));
+            bx.bitcast(value, bx.type_f32())
+        }
         _ => value,
     }
 }
@@ -1039,6 +1079,7 @@ fn llvm_fixup_output<'ll, 'tcx>(
     mut value: &'ll Value,
     reg: InlineAsmRegClass,
     layout: &TyAndLayout<'tcx>,
+    instance: Instance<'_>,
 ) -> &'ll Value {
     match (reg, layout.abi) {
         (InlineAsmRegClass::AArch64(AArch64InlineAsmRegClass::vreg), Abi::Scalar(s)) => {
@@ -1130,6 +1171,19 @@ fn llvm_fixup_output<'ll, 'tcx>(
                 value
             }
         }
+        (
+            InlineAsmRegClass::Arm(
+                ArmInlineAsmRegClass::dreg
+                | ArmInlineAsmRegClass::dreg_low8
+                | ArmInlineAsmRegClass::dreg_low16
+                | ArmInlineAsmRegClass::qreg
+                | ArmInlineAsmRegClass::qreg_low4
+                | ArmInlineAsmRegClass::qreg_low8,
+            ),
+            Abi::Vector { element, count: count @ (4 | 8) },
+        ) if element.primitive() == Primitive::Float(Float::F16) => {
+            bx.bitcast(value, bx.type_vector(bx.type_f16(), count))
+        }
         (InlineAsmRegClass::Mips(MipsInlineAsmRegClass::reg), Abi::Scalar(s)) => {
             match s.primitive() {
                 // MIPS only supports register-length arithmetics.
@@ -1140,6 +1194,14 @@ fn llvm_fixup_output<'ll, 'tcx>(
                 _ => value,
             }
         }
+        (InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s))
+            if s.primitive() == Primitive::Float(Float::F16)
+                && !any_target_feature_enabled(bx, instance, &[sym::zfhmin, sym::zfh]) =>
+        {
+            let value = bx.bitcast(value, bx.type_i32());
+            let value = bx.trunc(value, bx.type_i16());
+            bx.bitcast(value, bx.type_f16())
+        }
         _ => value,
     }
 }
@@ -1149,6 +1211,7 @@ fn llvm_fixup_output_type<'ll, 'tcx>(
     cx: &CodegenCx<'ll, 'tcx>,
     reg: InlineAsmRegClass,
     layout: &TyAndLayout<'tcx>,
+    instance: Instance<'_>,
 ) -> &'ll Type {
     match (reg, layout.abi) {
         (InlineAsmRegClass::AArch64(AArch64InlineAsmRegClass::vreg), Abi::Scalar(s)) => {
@@ -1233,6 +1296,19 @@ fn llvm_fixup_output_type<'ll, 'tcx>(
                 layout.llvm_type(cx)
             }
         }
+        (
+            InlineAsmRegClass::Arm(
+                ArmInlineAsmRegClass::dreg
+                | ArmInlineAsmRegClass::dreg_low8
+                | ArmInlineAsmRegClass::dreg_low16
+                | ArmInlineAsmRegClass::qreg
+                | ArmInlineAsmRegClass::qreg_low4
+                | ArmInlineAsmRegClass::qreg_low8,
+            ),
+            Abi::Vector { element, count: count @ (4 | 8) },
+        ) if element.primitive() == Primitive::Float(Float::F16) => {
+            cx.type_vector(cx.type_i16(), count)
+        }
         (InlineAsmRegClass::Mips(MipsInlineAsmRegClass::reg), Abi::Scalar(s)) => {
             match s.primitive() {
                 // MIPS only supports register-length arithmetics.
@@ -1242,6 +1318,12 @@ fn llvm_fixup_output_type<'ll, 'tcx>(
                 _ => layout.llvm_type(cx),
             }
         }
+        (InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s))
+            if s.primitive() == Primitive::Float(Float::F16)
+                && !any_target_feature_enabled(cx, instance, &[sym::zfhmin, sym::zfh]) =>
+        {
+            cx.type_f32()
+        }
         _ => layout.llvm_type(cx),
     }
 }