diff options
| author | Chris Denton <chris@chrisdenton.dev> | 2025-04-22 15:24:07 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-22 15:24:07 +0000 |
| commit | 264249fbe1d9d7ccbf5f7bebdd5fcfa9594c2c24 (patch) | |
| tree | 693998003cfa838a00ef67cec20e659c30e84ce3 | |
| parent | 107f04daa81212e8973344a10e90a02d8d9e5eac (diff) | |
| parent | b8ca0073c8e4d1d9e324db832500beb41c129040 (diff) | |
| download | rust-264249fbe1d9d7ccbf5f7bebdd5fcfa9594c2c24.tar.gz rust-264249fbe1d9d7ccbf5f7bebdd5fcfa9594c2c24.zip | |
Rollup merge of #140104 - Shourya742:2025-04-21-auto-diff-fails-on-impl-block, r=ZuseZ4
Fix auto diff failing on inherent impl blocks closes: #139557 r? ``@ZuseZ4``
| -rw-r--r-- | compiler/rustc_builtin_macros/src/autodiff.rs | 14 | ||||
| -rw-r--r-- | tests/pretty/autodiff/autodiff_forward.pp (renamed from tests/pretty/autodiff_forward.pp) | 0 | ||||
| -rw-r--r-- | tests/pretty/autodiff/autodiff_forward.rs (renamed from tests/pretty/autodiff_forward.rs) | 0 | ||||
| -rw-r--r-- | tests/pretty/autodiff/autodiff_reverse.pp (renamed from tests/pretty/autodiff_reverse.pp) | 0 | ||||
| -rw-r--r-- | tests/pretty/autodiff/autodiff_reverse.rs (renamed from tests/pretty/autodiff_reverse.rs) | 0 | ||||
| -rw-r--r-- | tests/pretty/autodiff/inherent_impl.pp | 41 | ||||
| -rw-r--r-- | tests/pretty/autodiff/inherent_impl.rs | 24 |
7 files changed, 71 insertions, 8 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index e60efdbefd9..6d97dfa3a4d 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -217,14 +217,12 @@ mod llvm_enzyme { ast::StmtKind::Item(iitem) => extract_item_info(iitem), _ => None, }, - Annotatable::AssocItem(assoc_item, Impl { of_trait: false }) => { - match &assoc_item.kind { - ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => { - Some((assoc_item.vis.clone(), sig.clone(), ident.clone())) - } - _ => None, + Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind { + ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => { + Some((assoc_item.vis.clone(), sig.clone(), ident.clone())) } - } + _ => None, + }, _ => None, }) else { dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); @@ -365,7 +363,7 @@ mod llvm_enzyme { } Annotatable::Item(iitem.clone()) } - Annotatable::AssocItem(ref mut assoc_item, i @ Impl { of_trait: false }) => { + Annotatable::AssocItem(ref mut assoc_item, i @ Impl { .. }) => { if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) { assoc_item.attrs.push(attr); } diff --git a/tests/pretty/autodiff_forward.pp b/tests/pretty/autodiff/autodiff_forward.pp index 713b8f541ae..713b8f541ae 100644 --- a/tests/pretty/autodiff_forward.pp +++ b/tests/pretty/autodiff/autodiff_forward.pp diff --git a/tests/pretty/autodiff_forward.rs b/tests/pretty/autodiff/autodiff_forward.rs index 5a0660a08e5..5a0660a08e5 100644 --- a/tests/pretty/autodiff_forward.rs +++ b/tests/pretty/autodiff/autodiff_forward.rs diff --git a/tests/pretty/autodiff_reverse.pp b/tests/pretty/autodiff/autodiff_reverse.pp index 31920694a3a..31920694a3a 100644 --- a/tests/pretty/autodiff_reverse.pp +++ b/tests/pretty/autodiff/autodiff_reverse.pp diff --git a/tests/pretty/autodiff_reverse.rs b/tests/pretty/autodiff/autodiff_reverse.rs index 3c024272f40..3c024272f40 100644 --- a/tests/pretty/autodiff_reverse.rs +++ b/tests/pretty/autodiff/autodiff_reverse.rs diff --git a/tests/pretty/autodiff/inherent_impl.pp b/tests/pretty/autodiff/inherent_impl.pp new file mode 100644 index 00000000000..97ac766b6b9 --- /dev/null +++ b/tests/pretty/autodiff/inherent_impl.pp @@ -0,0 +1,41 @@ +#![feature(prelude_import)] +#![no_std] +//@ needs-enzyme + +#![feature(autodiff)] +#[prelude_import] +use ::std::prelude::rust_2015::*; +#[macro_use] +extern crate std; +//@ pretty-mode:expanded +//@ pretty-compare-only +//@ pp-exact:inherent_impl.pp + +use std::autodiff::autodiff; + +struct Foo { + a: f64, +} + +trait MyTrait { + fn f(&self, x: f64) + -> f64; + fn df(&self, x: f64, seed: f64) + -> (f64, f64); +} + +impl MyTrait for Foo { + #[rustc_autodiff] + #[inline(never)] + fn f(&self, x: f64) -> f64 { + self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln()) + } + #[rustc_autodiff(Reverse, 1, Const, Active, Active)] + #[inline(never)] + fn df(&self, x: f64, dret: f64) -> (f64, f64) { + unsafe { asm!("NOP", options(pure, nomem)); }; + ::core::hint::black_box(self.f(x)); + ::core::hint::black_box((dret,)); + ::core::hint::black_box((self.f(x), f64::default())) + } +} diff --git a/tests/pretty/autodiff/inherent_impl.rs b/tests/pretty/autodiff/inherent_impl.rs new file mode 100644 index 00000000000..59de93f7e0f --- /dev/null +++ b/tests/pretty/autodiff/inherent_impl.rs @@ -0,0 +1,24 @@ +//@ needs-enzyme + +#![feature(autodiff)] +//@ pretty-mode:expanded +//@ pretty-compare-only +//@ pp-exact:inherent_impl.pp + +use std::autodiff::autodiff; + +struct Foo { + a: f64, +} + +trait MyTrait { + fn f(&self, x: f64) -> f64; + fn df(&self, x: f64, seed: f64) -> (f64, f64); +} + +impl MyTrait for Foo { + #[autodiff(df, Reverse, Const, Active, Active)] + fn f(&self, x: f64) -> f64 { + self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln()) + } +} |
