about summary refs log tree commit diff
diff options
context:
space:
mode:
authorChris Denton <chris@chrisdenton.dev>2025-04-22 15:24:07 +0000
committerGitHub <noreply@github.com>2025-04-22 15:24:07 +0000
commit264249fbe1d9d7ccbf5f7bebdd5fcfa9594c2c24 (patch)
tree693998003cfa838a00ef67cec20e659c30e84ce3
parent107f04daa81212e8973344a10e90a02d8d9e5eac (diff)
parentb8ca0073c8e4d1d9e324db832500beb41c129040 (diff)
downloadrust-264249fbe1d9d7ccbf5f7bebdd5fcfa9594c2c24.tar.gz
rust-264249fbe1d9d7ccbf5f7bebdd5fcfa9594c2c24.zip
Rollup merge of #140104 - Shourya742:2025-04-21-auto-diff-fails-on-impl-block, r=ZuseZ4
Fix auto diff failing on inherent impl blocks

closes: #139557

r? ``@ZuseZ4``
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs14
-rw-r--r--tests/pretty/autodiff/autodiff_forward.pp (renamed from tests/pretty/autodiff_forward.pp)0
-rw-r--r--tests/pretty/autodiff/autodiff_forward.rs (renamed from tests/pretty/autodiff_forward.rs)0
-rw-r--r--tests/pretty/autodiff/autodiff_reverse.pp (renamed from tests/pretty/autodiff_reverse.pp)0
-rw-r--r--tests/pretty/autodiff/autodiff_reverse.rs (renamed from tests/pretty/autodiff_reverse.rs)0
-rw-r--r--tests/pretty/autodiff/inherent_impl.pp41
-rw-r--r--tests/pretty/autodiff/inherent_impl.rs24
7 files changed, 71 insertions, 8 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index e60efdbefd9..6d97dfa3a4d 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -217,14 +217,12 @@ mod llvm_enzyme {
                 ast::StmtKind::Item(iitem) => extract_item_info(iitem),
                 _ => None,
             },
-            Annotatable::AssocItem(assoc_item, Impl { of_trait: false }) => {
-                match &assoc_item.kind {
-                    ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
-                        Some((assoc_item.vis.clone(), sig.clone(), ident.clone()))
-                    }
-                    _ => None,
+            Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind {
+                ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
+                    Some((assoc_item.vis.clone(), sig.clone(), ident.clone()))
                 }
-            }
+                _ => None,
+            },
             _ => None,
         }) else {
             dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
@@ -365,7 +363,7 @@ mod llvm_enzyme {
                 }
                 Annotatable::Item(iitem.clone())
             }
-            Annotatable::AssocItem(ref mut assoc_item, i @ Impl { of_trait: false }) => {
+            Annotatable::AssocItem(ref mut assoc_item, i @ Impl { .. }) => {
                 if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
                     assoc_item.attrs.push(attr);
                 }
diff --git a/tests/pretty/autodiff_forward.pp b/tests/pretty/autodiff/autodiff_forward.pp
index 713b8f541ae..713b8f541ae 100644
--- a/tests/pretty/autodiff_forward.pp
+++ b/tests/pretty/autodiff/autodiff_forward.pp
diff --git a/tests/pretty/autodiff_forward.rs b/tests/pretty/autodiff/autodiff_forward.rs
index 5a0660a08e5..5a0660a08e5 100644
--- a/tests/pretty/autodiff_forward.rs
+++ b/tests/pretty/autodiff/autodiff_forward.rs
diff --git a/tests/pretty/autodiff_reverse.pp b/tests/pretty/autodiff/autodiff_reverse.pp
index 31920694a3a..31920694a3a 100644
--- a/tests/pretty/autodiff_reverse.pp
+++ b/tests/pretty/autodiff/autodiff_reverse.pp
diff --git a/tests/pretty/autodiff_reverse.rs b/tests/pretty/autodiff/autodiff_reverse.rs
index 3c024272f40..3c024272f40 100644
--- a/tests/pretty/autodiff_reverse.rs
+++ b/tests/pretty/autodiff/autodiff_reverse.rs
diff --git a/tests/pretty/autodiff/inherent_impl.pp b/tests/pretty/autodiff/inherent_impl.pp
new file mode 100644
index 00000000000..97ac766b6b9
--- /dev/null
+++ b/tests/pretty/autodiff/inherent_impl.pp
@@ -0,0 +1,41 @@
+#![feature(prelude_import)]
+#![no_std]
+//@ needs-enzyme
+
+#![feature(autodiff)]
+#[prelude_import]
+use ::std::prelude::rust_2015::*;
+#[macro_use]
+extern crate std;
+//@ pretty-mode:expanded
+//@ pretty-compare-only
+//@ pp-exact:inherent_impl.pp
+
+use std::autodiff::autodiff;
+
+struct Foo {
+    a: f64,
+}
+
+trait MyTrait {
+    fn f(&self, x: f64)
+    -> f64;
+    fn df(&self, x: f64, seed: f64)
+    -> (f64, f64);
+}
+
+impl MyTrait for Foo {
+    #[rustc_autodiff]
+    #[inline(never)]
+    fn f(&self, x: f64) -> f64 {
+        self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
+    }
+    #[rustc_autodiff(Reverse, 1, Const, Active, Active)]
+    #[inline(never)]
+    fn df(&self, x: f64, dret: f64) -> (f64, f64) {
+        unsafe { asm!("NOP", options(pure, nomem)); };
+        ::core::hint::black_box(self.f(x));
+        ::core::hint::black_box((dret,));
+        ::core::hint::black_box((self.f(x), f64::default()))
+    }
+}
diff --git a/tests/pretty/autodiff/inherent_impl.rs b/tests/pretty/autodiff/inherent_impl.rs
new file mode 100644
index 00000000000..59de93f7e0f
--- /dev/null
+++ b/tests/pretty/autodiff/inherent_impl.rs
@@ -0,0 +1,24 @@
+//@ needs-enzyme
+
+#![feature(autodiff)]
+//@ pretty-mode:expanded
+//@ pretty-compare-only
+//@ pp-exact:inherent_impl.pp
+
+use std::autodiff::autodiff;
+
+struct Foo {
+    a: f64,
+}
+
+trait MyTrait {
+    fn f(&self, x: f64) -> f64;
+    fn df(&self, x: f64, seed: f64) -> (f64, f64);
+}
+
+impl MyTrait for Foo {
+    #[autodiff(df, Reverse, Const, Active, Active)]
+    fn f(&self, x: f64) -> f64 {
+        self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
+    }
+}