diff options
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/builder/autodiff.rs | 20 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/intrinsic.rs | 3 |
2 files changed, 19 insertions, 4 deletions
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 78deffa3a7a..566877f4a1e 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -3,8 +3,9 @@ use std::ptr; use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_codegen_ssa::common::TypeKind; use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; -use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv}; +use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv}; use rustc_middle::{bug, ty}; +use rustc_target::callconv::PassMode; use tracing::debug; use crate::builder::{Builder, PlaceRef, UNNAMED}; @@ -16,9 +17,12 @@ use crate::value::Value; pub(crate) fn adjust_activity_to_abi<'tcx>( tcx: TyCtxt<'tcx>, - fn_ty: Ty<'tcx>, + instance: Instance<'tcx>, + typing_env: TypingEnv<'tcx>, da: &mut Vec<DiffActivity>, ) { + let fn_ty = instance.ty(tcx, typing_env); + if !matches!(fn_ty.kind(), ty::FnDef(..)) { bug!("expected fn def for autodiff, got {:?}", fn_ty); } @@ -27,6 +31,13 @@ pub(crate) fn adjust_activity_to_abi<'tcx>( // All we do is decide how to handle the arguments. let sig = fn_ty.fn_sig(tcx).skip_binder(); + // FIXME(Sa4dUs): pass proper varargs once we have support for differentiating variadic functions + let Ok(fn_abi) = + tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty()))) + else { + bug!("failed to get fn_abi of instance with empty varargs"); + }; + let mut new_activities = vec![]; let mut new_positions = vec![]; let mut del_activities = 0; @@ -91,10 +102,13 @@ pub(crate) fn adjust_activity_to_abi<'tcx>( } }; + let pass_mode = &fn_abi.args[i].mode; + // For ZST, just ignore and don't add its activity, as this arg won't be present // in the LLVM passed to Enzyme. + // Some targets pass ZST indirectly in the C ABI, in that case, handle it as a normal arg // FIXME(Sa4dUs): Enforce ZST corresponding diff activity be `Const` - if layout.is_zst() { + if *pass_mode == PassMode::Ignore { del_activities += 1; da.remove(i); } diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 06c3d8ed6bc..9e6e7606491 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -1198,7 +1198,8 @@ fn codegen_autodiff<'ll, 'tcx>( adjust_activity_to_abi( tcx, - fn_source.ty(tcx, TypingEnv::fully_monomorphized()), + fn_source, + TypingEnv::fully_monomorphized(), &mut diff_attrs.input_activity, ); |
