about summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/pretty/autodiff_forward.pp23
-rw-r--r--tests/pretty/autodiff_forward.rs9
2 files changed, 32 insertions, 0 deletions
diff --git a/tests/pretty/autodiff_forward.pp b/tests/pretty/autodiff_forward.pp
index 4b2fb6166ff..713b8f541ae 100644
--- a/tests/pretty/autodiff_forward.pp
+++ b/tests/pretty/autodiff_forward.pp
@@ -29,6 +29,8 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
     // Make sure, that we add the None for the default return.
 
 
+    // We want to make sure that we can use the macro for functions defined inside of functions
+
     ::core::panicking::panic("not implemented")
 }
 #[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
@@ -158,4 +160,25 @@ fn f8_1(x: &f32, bx_0: &f32) -> f32 {
     ::core::hint::black_box((bx_0,));
     ::core::hint::black_box(<f32>::default())
 }
+pub fn f9() {
+    #[rustc_autodiff]
+    #[inline(never)]
+    fn inner(x: f32) -> f32 { x * x }
+    #[rustc_autodiff(Forward, 1, Dual, Dual)]
+    #[inline(never)]
+    fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) {
+        unsafe { asm!("NOP", options(pure, nomem)); };
+        ::core::hint::black_box(inner(x));
+        ::core::hint::black_box((bx_0,));
+        ::core::hint::black_box(<(f32, f32)>::default())
+    }
+    #[rustc_autodiff(Forward, 1, Dual, DualOnly)]
+    #[inline(never)]
+    fn d_inner_1(x: f32, bx_0: f32) -> f32 {
+        unsafe { asm!("NOP", options(pure, nomem)); };
+        ::core::hint::black_box(inner(x));
+        ::core::hint::black_box((bx_0,));
+        ::core::hint::black_box(<f32>::default())
+    }
+}
 fn main() {}
diff --git a/tests/pretty/autodiff_forward.rs b/tests/pretty/autodiff_forward.rs
index a765738c2a8..5a0660a08e5 100644
--- a/tests/pretty/autodiff_forward.rs
+++ b/tests/pretty/autodiff_forward.rs
@@ -54,4 +54,13 @@ fn f8(x: &f32) -> f32 {
     unimplemented!()
 }
 
+// We want to make sure that we can use the macro for functions defined inside of functions
+pub fn f9() {
+    #[autodiff(d_inner_1, Forward, Dual, DualOnly)]
+    #[autodiff(d_inner_2, Forward, Dual, Dual)]
+    fn inner(x: f32) -> f32 {
+        x * x
+    }
+}
+
 fn main() {}