about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
diff options
context:
space:
mode:
authorManuel Drehwald <git@manuel.drehwald.info>2025-02-05 18:47:23 -0500
committerManuel Drehwald <git@manuel.drehwald.info>2025-02-05 18:47:23 -0500
commit70b9ba3d6e1d64e6b00da707e5b1b5127e63b1cf (patch)
treee7a6f4dc63c6baf0d1383f43858e4f840ea061fc /compiler/rustc_codegen_llvm/src/builder/autodiff.rs
parent335151f8bbadf31c2d8dae7d2a25dbcdab45a3b6 (diff)
downloadrust-70b9ba3d6e1d64e6b00da707e5b1b5127e63b1cf.tar.gz
rust-70b9ba3d6e1d64e6b00da707e5b1b5127e63b1cf.zip
fix fwd-mode autodiff case
Diffstat (limited to 'compiler/rustc_codegen_llvm/src/builder/autodiff.rs')
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs11
1 files changed, 8 insertions, 3 deletions
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index 9e8e4e1c567..474b0940203 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -164,10 +164,10 @@ fn generate_enzyme_call<'ll>(
         let mut activity_pos = 0;
         let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
         while activity_pos < inputs.len() {
-            let activity = inputs[activity_pos as usize];
+            let diff_activity = inputs[activity_pos as usize];
             // Duplicated arguments received a shadow argument, into which enzyme will write the
             // gradient.
-            let (activity, duplicated): (&Metadata, bool) = match activity {
+            let (activity, duplicated): (&Metadata, bool) = match diff_activity {
                 DiffActivity::None => panic!("not a valid input activity"),
                 DiffActivity::Const => (enzyme_const, false),
                 DiffActivity::Active => (enzyme_out, false),
@@ -222,7 +222,12 @@ fn generate_enzyme_call<'ll>(
                     // A duplicated pointer will have the following two outer_fn arguments:
                     // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
                     // (..., metadata! enzyme_dup, ptr, ptr, ...).
-                    assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer);
+                    if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly) {
+                        assert!(
+                            llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer
+                        );
+                    }
+                    // In the case of Dual we don't have assumptions, e.g. f32 would be valid.
                     args.push(next_outer_arg);
                     outer_pos += 2;
                     activity_pos += 1;