diff options
| author | Manuel Drehwald <git@manuel.drehwald.info> | 2025-02-05 18:47:23 -0500 |
|---|---|---|
| committer | Manuel Drehwald <git@manuel.drehwald.info> | 2025-02-05 18:47:23 -0500 |
| commit | 70b9ba3d6e1d64e6b00da707e5b1b5127e63b1cf (patch) | |
| tree | e7a6f4dc63c6baf0d1383f43858e4f840ea061fc /compiler/rustc_codegen_llvm/src/builder/autodiff.rs | |
| parent | 335151f8bbadf31c2d8dae7d2a25dbcdab45a3b6 (diff) | |
| download | rust-70b9ba3d6e1d64e6b00da707e5b1b5127e63b1cf.tar.gz rust-70b9ba3d6e1d64e6b00da707e5b1b5127e63b1cf.zip | |
fix fwd-mode autodiff case
Diffstat (limited to 'compiler/rustc_codegen_llvm/src/builder/autodiff.rs')
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/builder/autodiff.rs | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 9e8e4e1c567..474b0940203 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -164,10 +164,10 @@ fn generate_enzyme_call<'ll>( let mut activity_pos = 0; let outer_args: Vec<&llvm::Value> = get_params(outer_fn); while activity_pos < inputs.len() { - let activity = inputs[activity_pos as usize]; + let diff_activity = inputs[activity_pos as usize]; // Duplicated arguments received a shadow argument, into which enzyme will write the // gradient. - let (activity, duplicated): (&Metadata, bool) = match activity { + let (activity, duplicated): (&Metadata, bool) = match diff_activity { DiffActivity::None => panic!("not a valid input activity"), DiffActivity::Const => (enzyme_const, false), DiffActivity::Active => (enzyme_out, false), @@ -222,7 +222,12 @@ fn generate_enzyme_call<'ll>( // A duplicated pointer will have the following two outer_fn arguments: // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call: // (..., metadata! enzyme_dup, ptr, ptr, ...). - assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer); + if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly) { + assert!( + llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer + ); + } + // In the case of Dual we don't have assumptions, e.g. f32 would be valid. args.push(next_outer_arg); outer_pos += 2; activity_pos += 1; |
