about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_codegen_llvm/src')
-rw-r--r--compiler/rustc_codegen_llvm/src/abi.rs11
-rw-r--r--compiler/rustc_codegen_llvm/src/back/lto.rs2
-rw-r--r--compiler/rustc_codegen_llvm/src/back/write.rs6
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs44
-rw-r--r--compiler/rustc_codegen_llvm/src/debuginfo/metadata.rs13
-rw-r--r--compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs73
-rw-r--r--compiler/rustc_codegen_llvm/src/intrinsic.rs3
-rw-r--r--compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs8
-rw-r--r--compiler/rustc_codegen_llvm/src/llvm/ffi.rs91
9 files changed, 159 insertions, 92 deletions
diff --git a/compiler/rustc_codegen_llvm/src/abi.rs b/compiler/rustc_codegen_llvm/src/abi.rs
index ac7583f5666..11be7041167 100644
--- a/compiler/rustc_codegen_llvm/src/abi.rs
+++ b/compiler/rustc_codegen_llvm/src/abi.rs
@@ -215,9 +215,9 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
                 let align = attrs.pointee_align.unwrap_or(self.layout.align.abi);
                 OperandValue::Ref(PlaceValue::new_sized(val, align)).store(bx, dst);
             }
-            // Unsized indirect qrguments
+            // Unsized indirect arguments cannot be stored
             PassMode::Indirect { attrs: _, meta_attrs: Some(_), on_stack: _ } => {
-                bug!("unsized `ArgAbi` must be handled through `store_fn_arg`");
+                bug!("unsized `ArgAbi` cannot be stored");
             }
             PassMode::Cast { cast, pad_i32: _ } => {
                 // The ABI mandates that the value is passed as a different struct representation.
@@ -272,12 +272,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
                 OperandValue::Pair(next(), next()).store(bx, dst);
             }
             PassMode::Indirect { attrs: _, meta_attrs: Some(_), on_stack: _ } => {
-                let place_val = PlaceValue {
-                    llval: next(),
-                    llextra: Some(next()),
-                    align: self.layout.align.abi,
-                };
-                OperandValue::Ref(place_val).store(bx, dst);
+                bug!("unsized `ArgAbi` cannot be stored");
             }
             PassMode::Direct(_)
             | PassMode::Indirect { attrs: _, meta_attrs: None, on_stack: _ }
diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs
index f571716d9dd..78107d95e5a 100644
--- a/compiler/rustc_codegen_llvm/src/back/lto.rs
+++ b/compiler/rustc_codegen_llvm/src/back/lto.rs
@@ -617,7 +617,7 @@ pub(crate) fn run_pass_manager(
         crate::builder::gpu_offload::handle_gpu_code(cgcx, &cx);
     }
 
-    if cfg!(llvm_enzyme) && enable_ad && !thin {
+    if cfg!(feature = "llvm_enzyme") && enable_ad && !thin {
         let opt_stage = llvm::OptStage::FatLTO;
         let stage = write::AutodiffStage::PostAD;
         if !config.autodiff.contains(&config::AutoDiff::NoPostopt) {
diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs
index 423f0da4878..bda81fbd19e 100644
--- a/compiler/rustc_codegen_llvm/src/back/write.rs
+++ b/compiler/rustc_codegen_llvm/src/back/write.rs
@@ -574,7 +574,8 @@ pub(crate) unsafe fn llvm_optimize(
     // FIXME(ZuseZ4): In a future update we could figure out how to only optimize individual functions getting
     // differentiated.
 
-    let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable);
+    let consider_ad =
+        cfg!(feature = "llvm_enzyme") && config.autodiff.contains(&config::AutoDiff::Enable);
     let run_enzyme = autodiff_stage == AutodiffStage::DuringAD;
     let print_before_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModBefore);
     let print_after_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModAfter);
@@ -740,7 +741,8 @@ pub(crate) fn optimize(
 
         // If we know that we will later run AD, then we disable vectorization and loop unrolling.
         // Otherwise we pretend AD is already done and run the normal opt pipeline (=PostAD).
-        let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable);
+        let consider_ad =
+            cfg!(feature = "llvm_enzyme") && config.autodiff.contains(&config::AutoDiff::Enable);
         let autodiff_stage = if consider_ad { AutodiffStage::PreAD } else { AutodiffStage::PostAD };
         // The embedded bitcode is used to run LTO/ThinLTO.
         // The bitcode obtained during the `codegen` phase is no longer suitable for performing LTO.
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index 6ddf53cdc87..b66e3dfdeec 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -3,8 +3,9 @@ use std::ptr;
 use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
 use rustc_codegen_ssa::common::TypeKind;
 use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
-use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
+use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
 use rustc_middle::{bug, ty};
+use rustc_target::callconv::PassMode;
 use tracing::debug;
 
 use crate::builder::{Builder, PlaceRef, UNNAMED};
@@ -16,9 +17,12 @@ use crate::value::Value;
 
 pub(crate) fn adjust_activity_to_abi<'tcx>(
     tcx: TyCtxt<'tcx>,
-    fn_ty: Ty<'tcx>,
+    instance: Instance<'tcx>,
+    typing_env: TypingEnv<'tcx>,
     da: &mut Vec<DiffActivity>,
 ) {
+    let fn_ty = instance.ty(tcx, typing_env);
+
     if !matches!(fn_ty.kind(), ty::FnDef(..)) {
         bug!("expected fn def for autodiff, got {:?}", fn_ty);
     }
@@ -27,8 +31,16 @@ pub(crate) fn adjust_activity_to_abi<'tcx>(
     // All we do is decide how to handle the arguments.
     let sig = fn_ty.fn_sig(tcx).skip_binder();
 
+    // FIXME(Sa4dUs): pass proper varargs once we have support for differentiating variadic functions
+    let Ok(fn_abi) =
+        tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty())))
+    else {
+        bug!("failed to get fn_abi of instance with empty varargs");
+    };
+
     let mut new_activities = vec![];
     let mut new_positions = vec![];
+    let mut del_activities = 0;
     for (i, ty) in sig.inputs().iter().enumerate() {
         if let Some(inner_ty) = ty.builtin_deref(true) {
             if inner_ty.is_slice() {
@@ -80,6 +92,34 @@ pub(crate) fn adjust_activity_to_abi<'tcx>(
                 continue;
             }
         }
+
+        let pci = PseudoCanonicalInput { typing_env: TypingEnv::fully_monomorphized(), value: *ty };
+
+        let layout = match tcx.layout_of(pci) {
+            Ok(layout) => layout.layout,
+            Err(_) => {
+                bug!("failed to compute layout for type {:?}", ty);
+            }
+        };
+
+        let pass_mode = &fn_abi.args[i].mode;
+
+        // For ZST, just ignore and don't add its activity, as this arg won't be present
+        // in the LLVM passed to Enzyme.
+        // Some targets pass ZST indirectly in the C ABI, in that case, handle it as a normal arg
+        // FIXME(Sa4dUs): Enforce ZST corresponding diff activity be `Const`
+        if *pass_mode == PassMode::Ignore {
+            del_activities += 1;
+            da.remove(i);
+        }
+
+        // If the argument is lowered as a `ScalarPair`, we need to duplicate its activity.
+        // Otherwise, the number of activities won't match the number of LLVM arguments and
+        // this will lead to errors when verifying the Enzyme call.
+        if let rustc_abi::BackendRepr::ScalarPair(_, _) = layout.backend_repr() {
+            new_activities.push(da[i].clone());
+            new_positions.push(i + 1 - del_activities);
+        }
     }
     // now add the extra activities coming from slices
     // Reverse order to not invalidate the indices
diff --git a/compiler/rustc_codegen_llvm/src/debuginfo/metadata.rs b/compiler/rustc_codegen_llvm/src/debuginfo/metadata.rs
index 0903ddab285..aa8b8bd152d 100644
--- a/compiler/rustc_codegen_llvm/src/debuginfo/metadata.rs
+++ b/compiler/rustc_codegen_llvm/src/debuginfo/metadata.rs
@@ -821,14 +821,15 @@ fn build_basic_type_di_node<'ll, 'tcx>(
     };
 
     let typedef_di_node = unsafe {
-        llvm::LLVMRustDIBuilderCreateTypedef(
+        llvm::LLVMDIBuilderCreateTypedef(
             DIB(cx),
             ty_di_node,
-            typedef_name.as_c_char_ptr(),
+            typedef_name.as_ptr(),
             typedef_name.len(),
             unknown_file_metadata(cx),
-            0,
-            None,
+            0,    // (no line number)
+            None, // (no scope)
+            0u32, // (no alignment specified)
         )
     };
 
@@ -1034,10 +1035,10 @@ fn create_member_type<'ll, 'tcx>(
     type_di_node: &'ll DIType,
 ) -> &'ll DIType {
     unsafe {
-        llvm::LLVMRustDIBuilderCreateMemberType(
+        llvm::LLVMDIBuilderCreateMemberType(
             DIB(cx),
             owner,
-            name.as_c_char_ptr(),
+            name.as_ptr(),
             name.len(),
             file_metadata,
             line_number,
diff --git a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs
index a5c80895741..4ecc3086e1b 100644
--- a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs
+++ b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs
@@ -11,7 +11,7 @@ use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
 use rustc_middle::ty::{self, AdtDef, CoroutineArgs, CoroutineArgsExt, Ty};
 use smallvec::smallvec;
 
-use crate::common::{AsCCharPtr, CodegenCx};
+use crate::common::CodegenCx;
 use crate::debuginfo::dwarf_const::DW_TAG_const_type;
 use crate::debuginfo::metadata::enums::DiscrResult;
 use crate::debuginfo::metadata::type_map::{self, Stub, UniqueTypeId};
@@ -378,20 +378,17 @@ fn build_single_variant_union_fields<'ll, 'tcx>(
             variant_struct_type_wrapper_di_node,
             None,
         ),
-        unsafe {
-            llvm::LLVMRustDIBuilderCreateStaticMemberType(
-                DIB(cx),
-                enum_type_di_node,
-                TAG_FIELD_NAME.as_c_char_ptr(),
-                TAG_FIELD_NAME.len(),
-                unknown_file_metadata(cx),
-                UNKNOWN_LINE_NUMBER,
-                variant_names_type_di_node,
-                visibility_flags,
-                Some(cx.const_u64(SINGLE_VARIANT_VIRTUAL_DISR)),
-                tag_base_type_align.bits() as u32,
-            )
-        }
+        create_static_member_type(
+            cx,
+            enum_type_di_node,
+            TAG_FIELD_NAME,
+            unknown_file_metadata(cx),
+            UNKNOWN_LINE_NUMBER,
+            variant_names_type_di_node,
+            visibility_flags,
+            Some(cx.const_u64(SINGLE_VARIANT_VIRTUAL_DISR)),
+            tag_base_type_align,
+        ),
     ]
 }
 
@@ -570,27 +567,28 @@ fn build_variant_struct_wrapper_type_di_node<'ll, 'tcx>(
             let build_assoc_const = |name: &str,
                                      type_di_node_: &'ll DIType,
                                      value: u64,
-                                     align: Align| unsafe {
+                                     align: Align|
+             -> &'ll llvm::Metadata {
                 // FIXME: Currently we force all DISCR_* values to be u64's as LLDB seems to have
                 // problems inspecting other value types. Since DISCR_* is typically only going to be
                 // directly inspected via the debugger visualizer - which compares it to the `tag` value
                 // (whose type is not modified at all) it shouldn't cause any real problems.
                 let (t_di, align) = if name == ASSOC_CONST_DISCR_NAME {
-                    (type_di_node_, align.bits() as u32)
+                    (type_di_node_, align)
                 } else {
                     let ty_u64 = Ty::new_uint(cx.tcx, ty::UintTy::U64);
-                    (type_di_node(cx, ty_u64), Align::EIGHT.bits() as u32)
+                    (type_di_node(cx, ty_u64), Align::EIGHT)
                 };
 
                 // must wrap type in a `const` modifier for LLDB to be able to inspect the value of the member
-                let field_type =
-                    llvm::LLVMRustDIBuilderCreateQualifiedType(DIB(cx), DW_TAG_const_type, t_di);
+                let field_type = unsafe {
+                    llvm::LLVMDIBuilderCreateQualifiedType(DIB(cx), DW_TAG_const_type, t_di)
+                };
 
-                llvm::LLVMRustDIBuilderCreateStaticMemberType(
-                    DIB(cx),
+                create_static_member_type(
+                    cx,
                     wrapper_struct_type_di_node,
-                    name.as_c_char_ptr(),
-                    name.len(),
+                    name,
                     unknown_file_metadata(cx),
                     UNKNOWN_LINE_NUMBER,
                     field_type,
@@ -975,3 +973,30 @@ fn variant_struct_wrapper_type_name(variant_index: VariantIdx) -> Cow<'static, s
         .map(|&s| Cow::from(s))
         .unwrap_or_else(|| format!("Variant{}", variant_index.as_usize()).into())
 }
+
+fn create_static_member_type<'ll>(
+    cx: &CodegenCx<'ll, '_>,
+    scope: &'ll llvm::Metadata,
+    name: &str,
+    file: &'ll llvm::Metadata,
+    line_number: c_uint,
+    ty: &'ll llvm::Metadata,
+    flags: DIFlags,
+    value: Option<&'ll llvm::Value>,
+    align: Align,
+) -> &'ll llvm::Metadata {
+    unsafe {
+        llvm::LLVMDIBuilderCreateStaticMemberType(
+            DIB(cx),
+            scope,
+            name.as_ptr(),
+            name.len(),
+            file,
+            line_number,
+            ty,
+            flags,
+            value,
+            align.bits() as c_uint,
+        )
+    }
+}
diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs
index 85f71f331a4..e7f4a357048 100644
--- a/compiler/rustc_codegen_llvm/src/intrinsic.rs
+++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs
@@ -1208,7 +1208,8 @@ fn codegen_autodiff<'ll, 'tcx>(
 
     adjust_activity_to_abi(
         tcx,
-        fn_source.ty(tcx, TypingEnv::fully_monomorphized()),
+        fn_source,
+        TypingEnv::fully_monomorphized(),
         &mut diff_attrs.input_activity,
     );
 
diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
index 56d756e52cc..695435eb6da 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
@@ -59,10 +59,10 @@ pub(crate) enum LLVMRustVerifierFailureAction {
     LLVMReturnStatusAction = 2,
 }
 
-#[cfg(llvm_enzyme)]
+#[cfg(feature = "llvm_enzyme")]
 pub(crate) use self::Enzyme_AD::*;
 
-#[cfg(llvm_enzyme)]
+#[cfg(feature = "llvm_enzyme")]
 pub(crate) mod Enzyme_AD {
     use std::ffi::{CString, c_char};
 
@@ -134,10 +134,10 @@ pub(crate) mod Enzyme_AD {
     }
 }
 
-#[cfg(not(llvm_enzyme))]
+#[cfg(not(feature = "llvm_enzyme"))]
 pub(crate) use self::Fallback_AD::*;
 
-#[cfg(not(llvm_enzyme))]
+#[cfg(not(feature = "llvm_enzyme"))]
 pub(crate) mod Fallback_AD {
     #![allow(unused_variables)]
 
diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
index 71d8b7d25fe..1124ebc3d44 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
@@ -24,7 +24,7 @@ use rustc_target::spec::SymbolVisibility;
 
 use super::RustString;
 use super::debuginfo::{
-    DIArray, DIBasicType, DIBuilder, DIDerivedType, DIDescriptor, DIEnumerator, DIFile, DIFlags,
+    DIArray, DIBuilder, DIDerivedType, DIDescriptor, DIEnumerator, DIFile, DIFlags,
     DIGlobalVariableExpression, DILocation, DISPFlags, DIScope, DISubprogram, DISubrange,
     DITemplateTypeParameter, DIType, DIVariable, DebugEmissionKind, DebugNameTableKind,
 };
@@ -1943,6 +1943,52 @@ unsafe extern "C" {
         UniqueId: *const c_uchar, // See "PTR_LEN_STR".
         UniqueIdLen: size_t,
     ) -> &'ll Metadata;
+
+    pub(crate) fn LLVMDIBuilderCreateMemberType<'ll>(
+        Builder: &DIBuilder<'ll>,
+        Scope: &'ll Metadata,
+        Name: *const c_uchar, // See "PTR_LEN_STR".
+        NameLen: size_t,
+        File: &'ll Metadata,
+        LineNo: c_uint,
+        SizeInBits: u64,
+        AlignInBits: u32,
+        OffsetInBits: u64,
+        Flags: DIFlags,
+        Ty: &'ll Metadata,
+    ) -> &'ll Metadata;
+
+    pub(crate) fn LLVMDIBuilderCreateStaticMemberType<'ll>(
+        Builder: &DIBuilder<'ll>,
+        Scope: &'ll Metadata,
+        Name: *const c_uchar, // See "PTR_LEN_STR".
+        NameLen: size_t,
+        File: &'ll Metadata,
+        LineNumber: c_uint,
+        Type: &'ll Metadata,
+        Flags: DIFlags,
+        ConstantVal: Option<&'ll Value>,
+        AlignInBits: u32,
+    ) -> &'ll Metadata;
+
+    /// Creates a "qualified type" in the C/C++ sense, by adding modifiers
+    /// like `const` or `volatile`.
+    pub(crate) fn LLVMDIBuilderCreateQualifiedType<'ll>(
+        Builder: &DIBuilder<'ll>,
+        Tag: c_uint, // (DWARF tag, e.g. `DW_TAG_const_type`)
+        Type: &'ll Metadata,
+    ) -> &'ll Metadata;
+
+    pub(crate) fn LLVMDIBuilderCreateTypedef<'ll>(
+        Builder: &DIBuilder<'ll>,
+        Type: &'ll Metadata,
+        Name: *const c_uchar, // See "PTR_LEN_STR".
+        NameLen: size_t,
+        File: &'ll Metadata,
+        LineNo: c_uint,
+        Scope: Option<&'ll Metadata>,
+        AlignInBits: u32, // (optional; default is 0)
+    ) -> &'ll Metadata;
 }
 
 #[link(name = "llvm-wrapper", kind = "static")]
@@ -2278,30 +2324,6 @@ unsafe extern "C" {
         TParam: &'a DIArray,
     ) -> &'a DISubprogram;
 
-    pub(crate) fn LLVMRustDIBuilderCreateTypedef<'a>(
-        Builder: &DIBuilder<'a>,
-        Type: &'a DIBasicType,
-        Name: *const c_char,
-        NameLen: size_t,
-        File: &'a DIFile,
-        LineNo: c_uint,
-        Scope: Option<&'a DIScope>,
-    ) -> &'a DIDerivedType;
-
-    pub(crate) fn LLVMRustDIBuilderCreateMemberType<'a>(
-        Builder: &DIBuilder<'a>,
-        Scope: &'a DIDescriptor,
-        Name: *const c_char,
-        NameLen: size_t,
-        File: &'a DIFile,
-        LineNo: c_uint,
-        SizeInBits: u64,
-        AlignInBits: u32,
-        OffsetInBits: u64,
-        Flags: DIFlags,
-        Ty: &'a DIType,
-    ) -> &'a DIDerivedType;
-
     pub(crate) fn LLVMRustDIBuilderCreateVariantMemberType<'a>(
         Builder: &DIBuilder<'a>,
         Scope: &'a DIScope,
@@ -2317,25 +2339,6 @@ unsafe extern "C" {
         Ty: &'a DIType,
     ) -> &'a DIType;
 
-    pub(crate) fn LLVMRustDIBuilderCreateStaticMemberType<'a>(
-        Builder: &DIBuilder<'a>,
-        Scope: &'a DIDescriptor,
-        Name: *const c_char,
-        NameLen: size_t,
-        File: &'a DIFile,
-        LineNo: c_uint,
-        Ty: &'a DIType,
-        Flags: DIFlags,
-        val: Option<&'a Value>,
-        AlignInBits: u32,
-    ) -> &'a DIDerivedType;
-
-    pub(crate) fn LLVMRustDIBuilderCreateQualifiedType<'a>(
-        Builder: &DIBuilder<'a>,
-        Tag: c_uint,
-        Type: &'a DIType,
-    ) -> &'a DIDerivedType;
-
     pub(crate) fn LLVMRustDIBuilderCreateStaticVariable<'a>(
         Builder: &DIBuilder<'a>,
         Context: Option<&'a DIScope>,