about summary refs log tree commit diff
diff options
context:
space:
mode:
authorHaeNoe <git@haenoe.party>2025-04-19 19:28:33 +0200
committerHaeNoe <git@haenoe.party>2025-05-11 17:54:57 +0200
commit8b3228233e079b8d5f02484c1e9a06183e855a6d (patch)
tree19ab9cf7feb2eae6481186de724138fd7461422c
parent56a0c7dfea60a721948e29458fc714b6303e8e4a (diff)
downloadrust-8b3228233e079b8d5f02484c1e9a06183e855a6d.tar.gz
rust-8b3228233e079b8d5f02484c1e9a06183e855a6d.zip
feat: add test for generics in generated function
-rw-r--r--tests/pretty/autodiff/autodiff_forward.pp14
-rw-r--r--tests/pretty/autodiff/autodiff_forward.rs6
2 files changed, 20 insertions, 0 deletions
diff --git a/tests/pretty/autodiff/autodiff_forward.pp b/tests/pretty/autodiff/autodiff_forward.pp
index 713b8f541ae..2fa122a618d 100644
--- a/tests/pretty/autodiff/autodiff_forward.pp
+++ b/tests/pretty/autodiff/autodiff_forward.pp
@@ -31,6 +31,8 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
 
     // We want to make sure that we can use the macro for functions defined inside of functions
 
+    // Make sure we can handle generics
+
     ::core::panicking::panic("not implemented")
 }
 #[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
@@ -181,4 +183,16 @@ pub fn f9() {
         ::core::hint::black_box(<f32>::default())
     }
 }
+#[rustc_autodiff]
+#[inline(never)]
+pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x }
+#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
+#[inline(never)]
+pub fn d_square<T: std::ops::Mul<Output = T> +
+    Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
+    unsafe { asm!("NOP", options(pure, nomem)); };
+    ::core::hint::black_box(f10(x));
+    ::core::hint::black_box((dx_0, dret));
+    ::core::hint::black_box(f10(x))
+}
 fn main() {}
diff --git a/tests/pretty/autodiff/autodiff_forward.rs b/tests/pretty/autodiff/autodiff_forward.rs
index 5a0660a08e5..ae974f9b4db 100644
--- a/tests/pretty/autodiff/autodiff_forward.rs
+++ b/tests/pretty/autodiff/autodiff_forward.rs
@@ -63,4 +63,10 @@ pub fn f9() {
     }
 }
 
+// Make sure we can handle generics
+#[autodiff(d_square, Reverse, Duplicated, Active)]
+pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
+    *x * *x
+}
+
 fn main() {}