about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs32
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs8
2 files changed, 21 insertions, 19 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index c260dca87c0..48d0795af5e 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -192,7 +192,6 @@ mod llvm_enzyme {
     /// which becomes expanded to:
     /// ```
     /// #[rustc_autodiff]
-    /// #[inline(never)]
     /// fn sin(x: &Box<f32>) -> f32 {
     ///     f32::sin(**x)
     /// }
@@ -371,7 +370,7 @@ mod llvm_enzyme {
         let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
         let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);
 
-        // We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`.
+        // We're avoid duplicating the attribute `#[rustc_autodiff]`.
         fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
             match (attr, item) {
                 (ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
@@ -384,14 +383,16 @@ mod llvm_enzyme {
             }
         }
 
+        let mut has_inline_never = false;
+
         // Don't add it multiple times:
         let orig_annotatable: Annotatable = match item {
             Annotatable::Item(ref mut iitem) => {
                 if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
                     iitem.attrs.push(attr);
                 }
-                if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
-                    iitem.attrs.push(inline_never.clone());
+                if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
+                    has_inline_never = true;
                 }
                 Annotatable::Item(iitem.clone())
             }
@@ -399,8 +400,8 @@ mod llvm_enzyme {
                 if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
                     assoc_item.attrs.push(attr);
                 }
-                if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
-                    assoc_item.attrs.push(inline_never.clone());
+                if assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
+                    has_inline_never = true;
                 }
                 Annotatable::AssocItem(assoc_item.clone(), i)
             }
@@ -410,9 +411,8 @@ mod llvm_enzyme {
                         if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
                             iitem.attrs.push(attr);
                         }
-                        if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
-                        {
-                            iitem.attrs.push(inline_never.clone());
+                        if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
+                            has_inline_never = true;
                         }
                     }
                     _ => unreachable!("stmt kind checked previously"),
@@ -433,11 +433,19 @@ mod llvm_enzyme {
 
         let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
         let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
+
+        // If the source function has the `#[inline(never)]` attribute, we'll also add it to the diff function
+        let mut d_attrs = thin_vec![d_attr];
+
+        if has_inline_never {
+            d_attrs.push(inline_never);
+        }
+
         let d_annotatable = match &item {
             Annotatable::AssocItem(_, _) => {
                 let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn);
                 let d_fn = Box::new(ast::AssocItem {
-                    attrs: thin_vec![d_attr],
+                    attrs: d_attrs,
                     id: ast::DUMMY_NODE_ID,
                     span,
                     vis,
@@ -447,13 +455,13 @@ mod llvm_enzyme {
                 Annotatable::AssocItem(d_fn, Impl { of_trait: false })
             }
             Annotatable::Item(_) => {
-                let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn));
+                let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
                 d_fn.vis = vis;
 
                 Annotatable::Item(d_fn)
             }
             Annotatable::Stmt(_) => {
-                let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn));
+                let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
                 d_fn.vis = vis;
 
                 Annotatable::Stmt(Box::new(ast::Stmt {
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index 56116959a62..e2df3265f6f 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -10,10 +10,9 @@ use tracing::debug;
 use crate::builder::{Builder, PlaceRef, UNNAMED};
 use crate::context::SimpleCx;
 use crate::declare::declare_simple_fn;
-use crate::llvm::AttributePlace::Function;
+use crate::llvm;
 use crate::llvm::{Metadata, True, Type};
 use crate::value::Value;
-use crate::{attributes, llvm};
 
 pub(crate) fn adjust_activity_to_abi<'tcx>(
     tcx: TyCtxt<'tcx>,
@@ -308,11 +307,6 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
         enzyme_ty,
     );
 
-    // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
-    // do it's work.
-    let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx);
-    attributes::apply_to_llfn(ad_fn, Function, &[attr]);
-
     let num_args = llvm::LLVMCountParams(&fn_to_diff);
     let mut args = Vec::with_capacity(num_args as usize + 1);
     args.push(fn_to_diff);