diff options
| -rw-r--r-- | compiler/rustc_builtin_macros/src/autodiff.rs | 139 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/builder/autodiff.rs | 302 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/context.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/intrinsic.rs | 76 | ||||
| -rw-r--r-- | compiler/rustc_hir_analysis/src/check/intrinsic.rs | 5 | ||||
| -rw-r--r-- | compiler/rustc_span/src/symbol.rs | 1 | ||||
| -rw-r--r-- | library/core/src/intrinsics/mod.rs | 4 |
7 files changed, 284 insertions, 245 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index a662840eda5..3f8585d35bc 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -329,17 +329,22 @@ mod llvm_enzyme { .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly) .count() as u32; let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); - let d_body = gen_enzyme_body( + + // TODO(Sa4dUs): Remove this and all the related logic + let _d_body = gen_enzyme_body( ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, &generics, ); + let d_body = + call_autodiff(ecx, primal, first_ident(&meta_item_vec[0]), span, &d_sig); + // The first element of it is the name of the function to be generated let asdf = Box::new(ast::Fn { defaultness: ast::Defaultness::Final, sig: d_sig, ident: first_ident(&meta_item_vec[0]), - generics, + generics: generics.clone(), contract: None, body: Some(d_body), define_opaque: None, @@ -428,12 +433,15 @@ mod llvm_enzyme { tokens: ts, }); + let vis_clone = vis.clone(); + + let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span); let d_annotatable = match &item { Annotatable::AssocItem(_, _) => { let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); let d_fn = Box::new(ast::AssocItem { - attrs: thin_vec![d_attr, inline_never], + attrs: thin_vec![d_attr], id: ast::DUMMY_NODE_ID, span, vis, @@ -443,13 +451,13 @@ mod llvm_enzyme { Annotatable::AssocItem(d_fn, Impl { of_trait: false }) } Annotatable::Item(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); + let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Item(d_fn) } Annotatable::Stmt(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); + let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Stmt(Box::new(ast::Stmt { @@ -463,7 +471,9 @@ mod llvm_enzyme { } }; - return vec![orig_annotatable, d_annotatable]; + let dummy_const_annotatable = gen_dummy_const(ecx, span, primal, sig, generics, vis_clone); + + return vec![orig_annotatable, dummy_const_annotatable, d_annotatable]; } // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be @@ -484,6 +494,123 @@ mod llvm_enzyme { ty } + // Generate `autodiff` intrinsic call + // ``` + // std::intrinsics::autodiff(source, diff, (args)) + // ``` + fn call_autodiff( + ecx: &ExtCtxt<'_>, + primal: Ident, + diff: Ident, + span: Span, + d_sig: &FnSig, + ) -> P<ast::Block> { + let primal_path_expr = ecx.expr_path(ecx.path_ident(span, primal)); + let diff_path_expr = ecx.expr_path(ecx.path_ident(span, diff)); + + let tuple_expr = ecx.expr_tuple( + span, + d_sig + .decl + .inputs + .iter() + .map(|arg| match arg.pat.kind { + PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)), + _ => todo!(), + }) + .collect::<ThinVec<_>>() + .into(), + ); + + let enzyme_path = ecx.path( + span, + vec![ + Ident::from_str("std"), + Ident::from_str("intrinsics"), + Ident::from_str("autodiff"), + ], + ); + let call_expr = ecx.expr_call( + span, + ecx.expr_path(enzyme_path), + vec![primal_path_expr, diff_path_expr, tuple_expr].into(), + ); + + let block = ecx.block_expr(call_expr); + + block + } + + // Generate dummy const to prevent primal function + // from being optimized away before applying enzyme + // ``` + // const _: () = + // { + // #[used] + // pub static DUMMY_PTR: fn_type = primal_fn; + // }; + // ``` + fn gen_dummy_const( + ecx: &ExtCtxt<'_>, + span: Span, + primal: Ident, + sig: FnSig, + generics: Generics, + vis: Visibility, + ) -> Annotatable { + // #[used] + let used_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::used))); + let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); + let used_attr = outer_normal_attr(&used_attr, new_id, span); + + // static DUMMY_PTR: <fn_type> = <primal_ident> + let static_ident = Ident::from_str_and_span("DUMMY_PTR", span); + let fn_ptr_ty = ast::TyKind::BareFn(Box::new(ast::BareFnTy { + safety: sig.header.safety, + ext: sig.header.ext, + generic_params: generics.params, + decl: sig.decl, + decl_span: sig.span, + })); + let static_ty = ecx.ty(span, fn_ptr_ty); + + let static_expr = ecx.expr_path(ecx.path(span, vec![primal])); + let static_item_kind = ast::ItemKind::Static(Box::new(ast::StaticItem { + ident: static_ident, + ty: static_ty, + safety: ast::Safety::Default, + mutability: ast::Mutability::Not, + expr: Some(static_expr), + define_opaque: None, + })); + + let static_item = ast::Item { + attrs: thin_vec![used_attr], + id: ast::DUMMY_NODE_ID, + span, + vis, + kind: static_item_kind, + tokens: None, + }; + + let block_expr = ecx.expr_block(Box::new(ast::Block { + stmts: thin_vec![ecx.stmt_item(span, P(static_item))], + id: ast::DUMMY_NODE_ID, + rules: ast::BlockCheckMode::Default, + span, + tokens: None, + })); + + let const_item = ecx.item_const( + span, + Ident::from_str_and_span("_", span), + ecx.ty(span, ast::TyKind::Tup(thin_vec![])), + block_expr, + ); + + Annotatable::Item(const_item) + } + // Will generate a body of the type: // ``` // { diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 829b3c513c2..66c34fbcfb1 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -3,22 +3,22 @@ use std::ptr; use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode}; use rustc_codegen_ssa::ModuleCodegen; use rustc_codegen_ssa::common::TypeKind; -use rustc_codegen_ssa::traits::BaseTypeCodegenMethods; +use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; use rustc_errors::FatalError; use rustc_middle::bug; use tracing::{debug, trace}; use crate::back::write::llvm_err; -use crate::builder::{SBuilder, UNNAMED}; +use crate::builder::{Builder, PlaceRef, UNNAMED}; use crate::context::SimpleCx; use crate::declare::declare_simple_fn; use crate::errors::{AutoDiffWithoutEnable, LlvmError}; use crate::llvm::AttributePlace::Function; -use crate::llvm::{Metadata, True}; +use crate::llvm::{Metadata, True, Type}; use crate::value::Value; use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm}; -fn get_params(fnc: &Value) -> Vec<&Value> { +fn _get_params(fnc: &Value) -> Vec<&Value> { let param_num = llvm::LLVMCountParams(fnc) as usize; let mut fnc_args: Vec<&Value> = vec![]; fnc_args.reserve(param_num); @@ -29,7 +29,7 @@ fn get_params(fnc: &Value) -> Vec<&Value> { fnc_args } -fn has_sret(fnc: &Value) -> bool { +fn _has_sret(fnc: &Value) -> bool { let num_args = llvm::LLVMCountParams(fnc) as usize; if num_args == 0 { false @@ -48,14 +48,13 @@ fn has_sret(fnc: &Value) -> bool { // need to match those. // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it // using iterators and peek()? -fn match_args_from_caller_to_enzyme<'ll>( +fn match_args_from_caller_to_enzyme<'ll, 'tcx>( cx: &SimpleCx<'ll>, - builder: &SBuilder<'ll, 'll>, + builder: &mut Builder<'_, 'll, 'tcx>, width: u32, args: &mut Vec<&'ll llvm::Value>, inputs: &[DiffActivity], outer_args: &[&'ll llvm::Value], - has_sret: bool, ) { debug!("matching autodiff arguments"); // We now handle the issue that Rust level arguments not always match the llvm-ir level @@ -67,20 +66,12 @@ fn match_args_from_caller_to_enzyme<'ll>( let mut outer_pos: usize = 0; let mut activity_pos = 0; - if has_sret { - // Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the - // inner function will still return something. We increase our outer_pos by one, - // and once we're done with all other args we will take the return of the inner call and - // update the sret pointer with it - outer_pos = 1; - } - - let enzyme_const = cx.create_metadata(b"enzyme_const"); - let enzyme_out = cx.create_metadata(b"enzyme_out"); - let enzyme_dup = cx.create_metadata(b"enzyme_dup"); - let enzyme_dupv = cx.create_metadata(b"enzyme_dupv"); - let enzyme_dupnoneed = cx.create_metadata(b"enzyme_dupnoneed"); - let enzyme_dupnoneedv = cx.create_metadata(b"enzyme_dupnoneedv"); + let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap(); + let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap(); + let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap(); + let enzyme_dupv = cx.create_metadata("enzyme_dupv".to_string()).unwrap(); + let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap(); + let enzyme_dupnoneedv = cx.create_metadata("enzyme_dupnoneedv".to_string()).unwrap(); while activity_pos < inputs.len() { let diff_activity = inputs[activity_pos as usize]; @@ -193,92 +184,6 @@ fn match_args_from_caller_to_enzyme<'ll>( } } -// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input -// arguments. We do however need to declare them with their correct return type. -// We already figured the correct return type out in our frontend, when generating the outer_fn, -// so we can now just go ahead and use that. This is not always trivial, e.g. because sret. -// Beyond sret, this article describes our challenges nicely: -// <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/> -// I.e. (i32, f32) will get merged into i64, but we don't handle that yet. -fn compute_enzyme_fn_ty<'ll>( - cx: &SimpleCx<'ll>, - attrs: &AutoDiffAttrs, - fn_to_diff: &'ll Value, - outer_fn: &'ll Value, -) -> &'ll llvm::Type { - let fn_ty = cx.get_type_of_global(outer_fn); - let mut ret_ty = cx.get_return_type(fn_ty); - - let has_sret = has_sret(outer_fn); - - if has_sret { - // Now we don't just forward the return type, so we have to figure it out based on the - // primal return type, in combination with the autodiff settings. - let fn_ty = cx.get_type_of_global(fn_to_diff); - let inner_ret_ty = cx.get_return_type(fn_ty); - - let void_ty = unsafe { llvm::LLVMVoidTypeInContext(cx.llcx) }; - if inner_ret_ty == void_ty { - // This indicates that even the inner function has an sret. - // Right now I only look for an sret in the outer function. - // This *probably* needs some extra handling, but I never ran - // into such a case. So I'll wait for user reports to have a test case. - bug!("sret in inner function"); - } - - if attrs.width == 1 { - // Enzyme returns a struct of style: - // `{ original_ret(if requested), float, float, ... }` - let mut struct_elements = vec![]; - if attrs.has_primal_ret() { - struct_elements.push(inner_ret_ty); - } - // Next, we push the list of active floats, since they will be lowered to `enzyme_out`, - // and therefore part of the return struct. - let param_tys = cx.func_params_types(fn_ty); - for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) { - if matches!(act, DiffActivity::Active) { - // Now find the float type at position i based on the fn_ty, - // to know what (f16/f32/f64/...) to add to the struct. - struct_elements.push(param_ty); - } - } - ret_ty = cx.type_struct(&struct_elements, false); - } else { - // First we check if we also have to deal with the primal return. - match attrs.mode { - DiffMode::Forward => match attrs.ret_activity { - DiffActivity::Dual => { - let arr_ty = - unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64 + 1) }; - ret_ty = arr_ty; - } - DiffActivity::DualOnly => { - let arr_ty = - unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64) }; - ret_ty = arr_ty; - } - DiffActivity::Const => { - todo!("Not sure, do we need to do something here?"); - } - _ => { - bug!("unreachable"); - } - }, - DiffMode::Reverse => { - todo!("Handle sret for reverse mode"); - } - _ => { - bug!("unreachable"); - } - } - } - } - - // LLVM can figure out the input types on it's own, so we take a shortcut here. - unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) } -} - /// When differentiating `fn_to_diff`, take a `outer_fn` and generate another /// function with expected naming and calling conventions[^1] which will be /// discovered by the enzyme LLVM pass and its body populated with the differentiated @@ -288,11 +193,15 @@ fn compute_enzyme_fn_ty<'ll>( /// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/> // FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to // cover some assumptions of enzyme/autodiff, which could lead to UB otherwise. -fn generate_enzyme_call<'ll>( +pub(crate) fn generate_enzyme_call<'ll, 'tcx>( + builder: &mut Builder<'_, 'll, 'tcx>, cx: &SimpleCx<'ll>, fn_to_diff: &'ll Value, - outer_fn: &'ll Value, + outer_name: &str, + ret_ty: &'ll Type, + fn_args: &[&'ll Value], attrs: AutoDiffAttrs, + dest: PlaceRef<'tcx, &'ll Value>, ) { // We have to pick the name depending on whether we want forward or reverse mode autodiff. let mut ad_name: String = match attrs.mode { @@ -302,11 +211,9 @@ fn generate_enzyme_call<'ll>( } .to_string(); - // add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple + // add outer_name to ad_name to make it unique, in case users apply autodiff to multiple // functions. Unwrap will only panic, if LLVM gave us an invalid string. - let name = llvm::get_value_name(outer_fn); - let outer_fn_name = std::str::from_utf8(&name).unwrap(); - ad_name.push_str(outer_fn_name); + ad_name.push_str(outer_name); // Let us assume the user wrote the following function square: // @@ -317,13 +224,7 @@ fn generate_enzyme_call<'ll>( // ret double %0 // } // ``` - // - // The user now applies autodiff to the function square, in which case fn_to_diff will be `square`. - // Our macro generates the following placeholder code (slightly simplified): - // - // ```llvm // define double @dsquare(double %x) { - // ; placeholder code // return 0.0; // } // ``` @@ -340,120 +241,52 @@ fn generate_enzyme_call<'ll>( // ret double %0 // } // ``` - unsafe { - let enzyme_ty = compute_enzyme_fn_ty(cx, &attrs, fn_to_diff, outer_fn); - - // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and - // think a bit more about what should go here. - let cc = llvm::LLVMGetFunctionCallConv(outer_fn); - let ad_fn = declare_simple_fn( - cx, - &ad_name, - llvm::CallConv::try_from(cc).expect("invalid callconv"), - llvm::UnnamedAddr::No, - llvm::Visibility::Default, - enzyme_ty, - ); - - // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to - // do it's work. - let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx); - attributes::apply_to_llfn(ad_fn, Function, &[attr]); - - // We add a made-up attribute just such that we can recognize it after AD to update - // (no)-inline attributes. We'll then also remove this attribute. - let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker"); - attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]); - - // first, remove all calls from fnc - let entry = llvm::LLVMGetFirstBasicBlock(outer_fn); - let br = llvm::LLVMRustGetTerminator(entry); - llvm::LLVMRustEraseInstFromParent(br); - - let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap(); - let mut builder = SBuilder::build(cx, entry); - - let num_args = llvm::LLVMCountParams(&fn_to_diff); - let mut args = Vec::with_capacity(num_args as usize + 1); - args.push(fn_to_diff); - - let enzyme_primal_ret = cx.create_metadata(b"enzyme_primal_return"); - if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) { - args.push(cx.get_metadata_value(enzyme_primal_ret)); - } - if attrs.width > 1 { - let enzyme_width = cx.create_metadata(b"enzyme_width"); - args.push(cx.get_metadata_value(enzyme_width)); - args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64)); - } - - let has_sret = has_sret(outer_fn); - let outer_args: Vec<&llvm::Value> = get_params(outer_fn); - match_args_from_caller_to_enzyme( - &cx, - &builder, - attrs.width, - &mut args, - &attrs.input_activity, - &outer_args, - has_sret, - ); - - let call = builder.call(enzyme_ty, ad_fn, &args, None); - - // This part is a bit iffy. LLVM requires that a call to an inlineable function has some - // metadata attached to it, but we just created this code oota. Given that the - // differentiated function already has partly confusing metadata, and given that this - // affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the - // dummy code which we inserted at a higher level. - // FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have, - // and how to best improve it for enzyme core and rust-enzyme. - let md_ty = cx.get_md_kind_id("dbg"); - if llvm::LLVMRustHasMetadata(last_inst, md_ty) { - let md = llvm::LLVMRustDIGetInstMetadata(last_inst) - .expect("failed to get instruction metadata"); - let md_todiff = cx.get_metadata_value(md); - llvm::LLVMSetMetadata(call, md_ty, md_todiff); - } else { - // We don't panic, since depending on whether we are in debug or release mode, we might - // have no debug info to copy, which would then be ok. - trace!("no dbg info"); - } - - // Now that we copied the metadata, get rid of dummy code. - llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst); + let enzyme_ty = unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) }; + + // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and + // think a bit more about what should go here. + // FIXME(Sa4dUs): have to find a way to get the cc, using `FastCallConv` for now + let cc = 8; + let ad_fn = declare_simple_fn( + cx, + &ad_name, + llvm::CallConv::try_from(cc).expect("invalid callconv"), + llvm::UnnamedAddr::No, + llvm::Visibility::Default, + enzyme_ty, + ); + + // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to + // do it's work. + let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx); + attributes::apply_to_llfn(ad_fn, Function, &[attr]); + + let num_args = llvm::LLVMCountParams(&fn_to_diff); + let mut args = Vec::with_capacity(num_args as usize + 1); + args.push(fn_to_diff); + + let enzyme_primal_ret = cx.create_metadata("enzyme_primal_return".to_string()).unwrap(); + if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) { + args.push(cx.get_metadata_value(enzyme_primal_ret)); + } + if attrs.width > 1 { + let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap(); + args.push(cx.get_metadata_value(enzyme_width)); + args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64)); + } - if cx.val_ty(call) == cx.type_void() || has_sret { - if has_sret { - // This is what we already have in our outer_fn (shortened): - // define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) { - // %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>) - // <Here we are, we want to add the following two lines> - // store [4 x double] %7, ptr %0, align 8 - // ret void - // } + match_args_from_caller_to_enzyme( + &cx, + builder, + attrs.width, + &mut args, + &attrs.input_activity, + fn_args, + ); - // now store the result of the enzyme call into the sret pointer. - let sret_ptr = outer_args[0]; - let call_ty = cx.val_ty(call); - if attrs.width == 1 { - assert_eq!(cx.type_kind(call_ty), TypeKind::Struct); - } else { - assert_eq!(cx.type_kind(call_ty), TypeKind::Array); - } - llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr); - } - builder.ret_void(); - } else { - builder.ret(call); - } + let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); - // Let's crash in case that we messed something up above and generated invalid IR. - llvm::LLVMRustVerifyFunction( - outer_fn, - llvm::LLVMRustVerifierFailureAction::LLVMAbortProcessAction, - ); - } + builder.store_to_place(call, dest.val); } pub(crate) fn differentiate<'ll>( @@ -461,6 +294,7 @@ pub(crate) fn differentiate<'ll>( cgcx: &CodegenContext<LlvmCodegenBackend>, diff_items: Vec<AutoDiffItem>, ) -> Result<(), FatalError> { + // TODO(Sa4dUs): delete all this logic for item in &diff_items { trace!("{}", item); } @@ -480,7 +314,7 @@ pub(crate) fn differentiate<'ll>( for item in diff_items.iter() { let name = item.source.clone(); let fn_def: Option<&llvm::Value> = cx.get_function(&name); - let Some(fn_def) = fn_def else { + let Some(_fn_def) = fn_def else { return Err(llvm_err( diag_handler.handle(), LlvmError::PrepareAutoDiff { @@ -492,7 +326,7 @@ pub(crate) fn differentiate<'ll>( }; debug!(?item.target); let fn_target: Option<&llvm::Value> = cx.get_function(&item.target); - let Some(fn_target) = fn_target else { + let Some(_fn_target) = fn_target else { return Err(llvm_err( diag_handler.handle(), LlvmError::PrepareAutoDiff { @@ -503,7 +337,7 @@ pub(crate) fn differentiate<'ll>( )); }; - generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone()); + // generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone()); } // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index 27ae729a531..8eb15571e82 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -660,7 +660,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> { } } impl<'ll> SimpleCx<'ll> { - pub(crate) fn get_return_type(&self, ty: &'ll Type) -> &'ll Type { + pub(crate) fn _get_return_type(&self, ty: &'ll Type) -> &'ll Type { assert_eq!(self.type_kind(ty), TypeKind::Function); unsafe { llvm::LLVMGetReturnType(ty) } } diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 7b27e496986..1102fc1d0c8 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -3,23 +3,26 @@ use std::cmp::Ordering; use rustc_abi::{Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size}; use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh}; +use rustc_codegen_ssa::codegen_attrs::autodiff_attrs; use rustc_codegen_ssa::common::{IntPredicate, TypeKind}; use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization}; use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue}; use rustc_codegen_ssa::traits::*; use rustc_hir as hir; +use rustc_hir::def_id::LOCAL_CRATE; use rustc_middle::mir::BinOp; use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf}; -use rustc_middle::ty::{self, GenericArgsRef, Ty}; +use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty}; use rustc_middle::{bug, span_bug}; use rustc_span::{Span, Symbol, sym}; -use rustc_symbol_mangling::mangle_internal_symbol; +use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate}; use rustc_target::spec::PanicStrategy; use tracing::debug; use crate::abi::FnAbiLlvmExt; use crate::builder::Builder; +use crate::builder::autodiff::generate_enzyme_call; use crate::context::CodegenCx; use crate::llvm::{self, Metadata}; use crate::type_::Type; @@ -174,10 +177,17 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { span: Span, ) -> Result<(), ty::Instance<'tcx>> { let tcx = self.tcx; + let callee_ty = instance.ty(tcx, self.typing_env()); - let name = tcx.item_name(instance.def_id()); let fn_args = instance.args; + let sig = callee_ty.fn_sig(tcx); + let sig = tcx.normalize_erasing_late_bound_regions(self.typing_env(), sig); + let ret_ty = sig.output(); + let name = tcx.item_name(instance.def_id()); + + let llret_ty = self.layout_of(ret_ty).llvm_type(self); + let simple = call_simple_intrinsic(self, name, args); let llval = match name { _ if simple.is_some() => simple.unwrap(), @@ -189,6 +199,66 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { &[ptr, args[1].immediate()], ) } + sym::autodiff => { + let val_arr: Vec<&'ll Value> = match args[2].val { + crate::intrinsic::OperandValue::Ref(ref place_value) => { + let mut ret_arr = vec![]; + let tuple_place = PlaceRef { val: *place_value, layout: args[2].layout }; + + for i in 0..tuple_place.layout.layout.0.fields.count() { + let field_place = tuple_place.project_field(self, i); + let field_layout = tuple_place.layout.field(self, i); + let llvm_ty = field_layout.llvm_type(self.cx); + + let field_val = + self.load(llvm_ty, field_place.val.llval, field_place.val.align); + + ret_arr.push(field_val) + } + + ret_arr + } + crate::intrinsic::OperandValue::Pair(v1, v2) => vec![v1, v2], + OperandValue::Immediate(v) => vec![v], + OperandValue::ZeroSized => bug!("unexpected `ZeroSized` arg"), + }; + + // Get source, diff, and attrs + let source_id = match fn_args.into_type_list(tcx)[0].kind() { + ty::FnDef(def_id, _) => def_id, + _ => bug!("invalid args"), + }; + let fn_source = Instance::mono(tcx, *source_id); + let source_symbol = + symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); + let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol); + let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; + + let diff_id = match fn_args.into_type_list(tcx)[1].kind() { + ty::FnDef(def_id, _) => def_id, + _ => bug!("invalid args"), + }; + let fn_diff = Instance::mono(tcx, *diff_id); + let diff_symbol = + symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); + + let diff_attrs = autodiff_attrs(tcx, *diff_id); + let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") }; + + // Build body + generate_enzyme_call( + self, + self.cx, + fn_to_diff, + &diff_symbol, + llret_ty, + &val_arr, + diff_attrs.clone(), + result, + ); + + return Ok(()); + } sym::is_val_statically_known => { if let OperandValue::Immediate(imm) = args[0].val { self.call_intrinsic( diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index 4441dd6ebd6..46371cfe591 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -135,6 +135,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi | sym::round_ties_even_f32 | sym::round_ties_even_f64 | sym::round_ties_even_f128 + | sym::autodiff | sym::const_eval_select => hir::Safety::Safe, _ => hir::Safety::Unsafe, }; @@ -171,6 +172,8 @@ pub(crate) fn check_intrinsic_type( } }; + let has_autodiff = tcx.has_attr(intrinsic_id, sym::rustc_autodiff); + let bound_vars = tcx.mk_bound_variable_kinds(&[ ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon), ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon), @@ -195,9 +198,9 @@ pub(crate) fn check_intrinsic_type( (Ty::new_ref(tcx, env_region, va_list_ty, mutbl), va_list_ty) }; - let safety = intrinsic_operation_unsafety(tcx, intrinsic_id); let n_lts = 0; let (n_tps, n_cts, inputs, output) = match intrinsic_name { + sym::autodiff => (4, 0, vec![param(0), param(1), param(2)], param(3)), sym::abort => (0, 0, vec![], tcx.types.never), sym::unreachable => (0, 0, vec![], tcx.types.never), sym::breakpoint => (0, 0, vec![], tcx.types.unit), diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 416ce27367e..f7a8258a9d8 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -542,6 +542,7 @@ symbols! { audit_that, augmented_assignments, auto_traits, + autodiff, autodiff_forward, autodiff_reverse, automatically_derived, diff --git a/library/core/src/intrinsics/mod.rs b/library/core/src/intrinsics/mod.rs index 7228ad0ed6d..6c389c55a5f 100644 --- a/library/core/src/intrinsics/mod.rs +++ b/library/core/src/intrinsics/mod.rs @@ -3157,6 +3157,10 @@ pub const unsafe fn copysignf64(x: f64, y: f64) -> f64; #[rustc_intrinsic] pub const unsafe fn copysignf128(x: f128, y: f128) -> f128; +#[rustc_nounwind] +#[rustc_intrinsic] +pub const fn autodiff<F, G, T: crate::marker::Tuple, R>(f: F, df: G, args: T) -> R; + /// Inform Miri that a given pointer definitely has a certain alignment. #[cfg(miri)] #[rustc_allow_const_fn_unstable(const_eval_select)] |
