about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs139
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs302
-rw-r--r--compiler/rustc_codegen_llvm/src/context.rs2
-rw-r--r--compiler/rustc_codegen_llvm/src/intrinsic.rs76
-rw-r--r--compiler/rustc_hir_analysis/src/check/intrinsic.rs5
-rw-r--r--compiler/rustc_span/src/symbol.rs1
-rw-r--r--library/core/src/intrinsics/mod.rs4
7 files changed, 284 insertions, 245 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index a662840eda5..3f8585d35bc 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -329,17 +329,22 @@ mod llvm_enzyme {
             .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
             .count() as u32;
         let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
-        let d_body = gen_enzyme_body(
+
+        // TODO(Sa4dUs): Remove this and all the related logic
+        let _d_body = gen_enzyme_body(
             ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
             &generics,
         );
 
+        let d_body =
+            call_autodiff(ecx, primal, first_ident(&meta_item_vec[0]), span, &d_sig);
+
         // The first element of it is the name of the function to be generated
         let asdf = Box::new(ast::Fn {
             defaultness: ast::Defaultness::Final,
             sig: d_sig,
             ident: first_ident(&meta_item_vec[0]),
-            generics,
+            generics: generics.clone(),
             contract: None,
             body: Some(d_body),
             define_opaque: None,
@@ -428,12 +433,15 @@ mod llvm_enzyme {
             tokens: ts,
         });
 
+        let vis_clone = vis.clone();
+
+        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
         let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
         let d_annotatable = match &item {
             Annotatable::AssocItem(_, _) => {
                 let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
                 let d_fn = Box::new(ast::AssocItem {
-                    attrs: thin_vec![d_attr, inline_never],
+                    attrs: thin_vec![d_attr],
                     id: ast::DUMMY_NODE_ID,
                     span,
                     vis,
@@ -443,13 +451,13 @@ mod llvm_enzyme {
                 Annotatable::AssocItem(d_fn, Impl { of_trait: false })
             }
             Annotatable::Item(_) => {
-                let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
+                let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
                 d_fn.vis = vis;
 
                 Annotatable::Item(d_fn)
             }
             Annotatable::Stmt(_) => {
-                let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
+                let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
                 d_fn.vis = vis;
 
                 Annotatable::Stmt(Box::new(ast::Stmt {
@@ -463,7 +471,9 @@ mod llvm_enzyme {
             }
         };
 
-        return vec![orig_annotatable, d_annotatable];
+        let dummy_const_annotatable = gen_dummy_const(ecx, span, primal, sig, generics, vis_clone);
+
+        return vec![orig_annotatable, dummy_const_annotatable, d_annotatable];
     }
 
     // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
@@ -484,6 +494,123 @@ mod llvm_enzyme {
         ty
     }
 
+    // Generate `autodiff` intrinsic call
+    // ```
+    // std::intrinsics::autodiff(source, diff, (args))
+    // ```
+    fn call_autodiff(
+        ecx: &ExtCtxt<'_>,
+        primal: Ident,
+        diff: Ident,
+        span: Span,
+        d_sig: &FnSig,
+    ) -> P<ast::Block> {
+        let primal_path_expr = ecx.expr_path(ecx.path_ident(span, primal));
+        let diff_path_expr = ecx.expr_path(ecx.path_ident(span, diff));
+
+        let tuple_expr = ecx.expr_tuple(
+            span,
+            d_sig
+                .decl
+                .inputs
+                .iter()
+                .map(|arg| match arg.pat.kind {
+                    PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)),
+                    _ => todo!(),
+                })
+                .collect::<ThinVec<_>>()
+                .into(),
+        );
+
+        let enzyme_path = ecx.path(
+            span,
+            vec![
+                Ident::from_str("std"),
+                Ident::from_str("intrinsics"),
+                Ident::from_str("autodiff"),
+            ],
+        );
+        let call_expr = ecx.expr_call(
+            span,
+            ecx.expr_path(enzyme_path),
+            vec![primal_path_expr, diff_path_expr, tuple_expr].into(),
+        );
+
+        let block = ecx.block_expr(call_expr);
+
+        block
+    }
+
+    // Generate dummy const to prevent primal function
+    // from being optimized away before applying enzyme
+    // ```
+    // const _: () =
+    // {
+    //     #[used]
+    //     pub static DUMMY_PTR: fn_type = primal_fn;
+    // };
+    // ```
+    fn gen_dummy_const(
+        ecx: &ExtCtxt<'_>,
+        span: Span,
+        primal: Ident,
+        sig: FnSig,
+        generics: Generics,
+        vis: Visibility,
+    ) -> Annotatable {
+        // #[used]
+        let used_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::used)));
+        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
+        let used_attr = outer_normal_attr(&used_attr, new_id, span);
+
+        // static DUMMY_PTR: <fn_type> = <primal_ident>
+        let static_ident = Ident::from_str_and_span("DUMMY_PTR", span);
+        let fn_ptr_ty = ast::TyKind::BareFn(Box::new(ast::BareFnTy {
+            safety: sig.header.safety,
+            ext: sig.header.ext,
+            generic_params: generics.params,
+            decl: sig.decl,
+            decl_span: sig.span,
+        }));
+        let static_ty = ecx.ty(span, fn_ptr_ty);
+
+        let static_expr = ecx.expr_path(ecx.path(span, vec![primal]));
+        let static_item_kind = ast::ItemKind::Static(Box::new(ast::StaticItem {
+            ident: static_ident,
+            ty: static_ty,
+            safety: ast::Safety::Default,
+            mutability: ast::Mutability::Not,
+            expr: Some(static_expr),
+            define_opaque: None,
+        }));
+
+        let static_item = ast::Item {
+            attrs: thin_vec![used_attr],
+            id: ast::DUMMY_NODE_ID,
+            span,
+            vis,
+            kind: static_item_kind,
+            tokens: None,
+        };
+
+        let block_expr = ecx.expr_block(Box::new(ast::Block {
+            stmts: thin_vec![ecx.stmt_item(span, P(static_item))],
+            id: ast::DUMMY_NODE_ID,
+            rules: ast::BlockCheckMode::Default,
+            span,
+            tokens: None,
+        }));
+
+        let const_item = ecx.item_const(
+            span,
+            Ident::from_str_and_span("_", span),
+            ecx.ty(span, ast::TyKind::Tup(thin_vec![])),
+            block_expr,
+        );
+
+        Annotatable::Item(const_item)
+    }
+
     // Will generate a body of the type:
     // ```
     // {
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index 829b3c513c2..66c34fbcfb1 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -3,22 +3,22 @@ use std::ptr;
 use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
 use rustc_codegen_ssa::ModuleCodegen;
 use rustc_codegen_ssa::common::TypeKind;
-use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
+use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
 use rustc_errors::FatalError;
 use rustc_middle::bug;
 use tracing::{debug, trace};
 
 use crate::back::write::llvm_err;
-use crate::builder::{SBuilder, UNNAMED};
+use crate::builder::{Builder, PlaceRef, UNNAMED};
 use crate::context::SimpleCx;
 use crate::declare::declare_simple_fn;
 use crate::errors::{AutoDiffWithoutEnable, LlvmError};
 use crate::llvm::AttributePlace::Function;
-use crate::llvm::{Metadata, True};
+use crate::llvm::{Metadata, True, Type};
 use crate::value::Value;
 use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
 
-fn get_params(fnc: &Value) -> Vec<&Value> {
+fn _get_params(fnc: &Value) -> Vec<&Value> {
     let param_num = llvm::LLVMCountParams(fnc) as usize;
     let mut fnc_args: Vec<&Value> = vec![];
     fnc_args.reserve(param_num);
@@ -29,7 +29,7 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
     fnc_args
 }
 
-fn has_sret(fnc: &Value) -> bool {
+fn _has_sret(fnc: &Value) -> bool {
     let num_args = llvm::LLVMCountParams(fnc) as usize;
     if num_args == 0 {
         false
@@ -48,14 +48,13 @@ fn has_sret(fnc: &Value) -> bool {
 // need to match those.
 // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
 // using iterators and peek()?
-fn match_args_from_caller_to_enzyme<'ll>(
+fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
     cx: &SimpleCx<'ll>,
-    builder: &SBuilder<'ll, 'll>,
+    builder: &mut Builder<'_, 'll, 'tcx>,
     width: u32,
     args: &mut Vec<&'ll llvm::Value>,
     inputs: &[DiffActivity],
     outer_args: &[&'ll llvm::Value],
-    has_sret: bool,
 ) {
     debug!("matching autodiff arguments");
     // We now handle the issue that Rust level arguments not always match the llvm-ir level
@@ -67,20 +66,12 @@ fn match_args_from_caller_to_enzyme<'ll>(
     let mut outer_pos: usize = 0;
     let mut activity_pos = 0;
 
-    if has_sret {
-        // Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
-        // inner function will still return something. We increase our outer_pos by one,
-        // and once we're done with all other args we will take the return of the inner call and
-        // update the sret pointer with it
-        outer_pos = 1;
-    }
-
-    let enzyme_const = cx.create_metadata(b"enzyme_const");
-    let enzyme_out = cx.create_metadata(b"enzyme_out");
-    let enzyme_dup = cx.create_metadata(b"enzyme_dup");
-    let enzyme_dupv = cx.create_metadata(b"enzyme_dupv");
-    let enzyme_dupnoneed = cx.create_metadata(b"enzyme_dupnoneed");
-    let enzyme_dupnoneedv = cx.create_metadata(b"enzyme_dupnoneedv");
+    let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap();
+    let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap();
+    let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
+    let enzyme_dupv = cx.create_metadata("enzyme_dupv".to_string()).unwrap();
+    let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap();
+    let enzyme_dupnoneedv = cx.create_metadata("enzyme_dupnoneedv".to_string()).unwrap();
 
     while activity_pos < inputs.len() {
         let diff_activity = inputs[activity_pos as usize];
@@ -193,92 +184,6 @@ fn match_args_from_caller_to_enzyme<'ll>(
     }
 }
 
-// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
-// arguments. We do however need to declare them with their correct return type.
-// We already figured the correct return type out in our frontend, when generating the outer_fn,
-// so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
-// Beyond sret, this article describes our challenges nicely:
-// <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
-// I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
-fn compute_enzyme_fn_ty<'ll>(
-    cx: &SimpleCx<'ll>,
-    attrs: &AutoDiffAttrs,
-    fn_to_diff: &'ll Value,
-    outer_fn: &'ll Value,
-) -> &'ll llvm::Type {
-    let fn_ty = cx.get_type_of_global(outer_fn);
-    let mut ret_ty = cx.get_return_type(fn_ty);
-
-    let has_sret = has_sret(outer_fn);
-
-    if has_sret {
-        // Now we don't just forward the return type, so we have to figure it out based on the
-        // primal return type, in combination with the autodiff settings.
-        let fn_ty = cx.get_type_of_global(fn_to_diff);
-        let inner_ret_ty = cx.get_return_type(fn_ty);
-
-        let void_ty = unsafe { llvm::LLVMVoidTypeInContext(cx.llcx) };
-        if inner_ret_ty == void_ty {
-            // This indicates that even the inner function has an sret.
-            // Right now I only look for an sret in the outer function.
-            // This *probably* needs some extra handling, but I never ran
-            // into such a case. So I'll wait for user reports to have a test case.
-            bug!("sret in inner function");
-        }
-
-        if attrs.width == 1 {
-            // Enzyme returns a struct of style:
-            // `{ original_ret(if requested), float, float, ... }`
-            let mut struct_elements = vec![];
-            if attrs.has_primal_ret() {
-                struct_elements.push(inner_ret_ty);
-            }
-            // Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
-            // and therefore part of the return struct.
-            let param_tys = cx.func_params_types(fn_ty);
-            for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) {
-                if matches!(act, DiffActivity::Active) {
-                    // Now find the float type at position i based on the fn_ty,
-                    // to know what (f16/f32/f64/...) to add to the struct.
-                    struct_elements.push(param_ty);
-                }
-            }
-            ret_ty = cx.type_struct(&struct_elements, false);
-        } else {
-            // First we check if we also have to deal with the primal return.
-            match attrs.mode {
-                DiffMode::Forward => match attrs.ret_activity {
-                    DiffActivity::Dual => {
-                        let arr_ty =
-                            unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64 + 1) };
-                        ret_ty = arr_ty;
-                    }
-                    DiffActivity::DualOnly => {
-                        let arr_ty =
-                            unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64) };
-                        ret_ty = arr_ty;
-                    }
-                    DiffActivity::Const => {
-                        todo!("Not sure, do we need to do something here?");
-                    }
-                    _ => {
-                        bug!("unreachable");
-                    }
-                },
-                DiffMode::Reverse => {
-                    todo!("Handle sret for reverse mode");
-                }
-                _ => {
-                    bug!("unreachable");
-                }
-            }
-        }
-    }
-
-    // LLVM can figure out the input types on it's own, so we take a shortcut here.
-    unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) }
-}
-
 /// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
 /// function with expected naming and calling conventions[^1] which will be
 /// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -288,11 +193,15 @@ fn compute_enzyme_fn_ty<'ll>(
 /// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
 // FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
 // cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
-fn generate_enzyme_call<'ll>(
+pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
+    builder: &mut Builder<'_, 'll, 'tcx>,
     cx: &SimpleCx<'ll>,
     fn_to_diff: &'ll Value,
-    outer_fn: &'ll Value,
+    outer_name: &str,
+    ret_ty: &'ll Type,
+    fn_args: &[&'ll Value],
     attrs: AutoDiffAttrs,
+    dest: PlaceRef<'tcx, &'ll Value>,
 ) {
     // We have to pick the name depending on whether we want forward or reverse mode autodiff.
     let mut ad_name: String = match attrs.mode {
@@ -302,11 +211,9 @@ fn generate_enzyme_call<'ll>(
     }
     .to_string();
 
-    // add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple
+    // add outer_name to ad_name to make it unique, in case users apply autodiff to multiple
     // functions. Unwrap will only panic, if LLVM gave us an invalid string.
-    let name = llvm::get_value_name(outer_fn);
-    let outer_fn_name = std::str::from_utf8(&name).unwrap();
-    ad_name.push_str(outer_fn_name);
+    ad_name.push_str(outer_name);
 
     // Let us assume the user wrote the following function square:
     //
@@ -317,13 +224,7 @@ fn generate_enzyme_call<'ll>(
     //  ret double %0
     // }
     // ```
-    //
-    // The user now applies autodiff to the function square, in which case fn_to_diff will be `square`.
-    // Our macro generates the following placeholder code (slightly simplified):
-    //
-    // ```llvm
     // define double @dsquare(double %x) {
-    //  ; placeholder code
     //  return 0.0;
     // }
     // ```
@@ -340,120 +241,52 @@ fn generate_enzyme_call<'ll>(
     //   ret double %0
     // }
     // ```
-    unsafe {
-        let enzyme_ty = compute_enzyme_fn_ty(cx, &attrs, fn_to_diff, outer_fn);
-
-        // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
-        // think a bit more about what should go here.
-        let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
-        let ad_fn = declare_simple_fn(
-            cx,
-            &ad_name,
-            llvm::CallConv::try_from(cc).expect("invalid callconv"),
-            llvm::UnnamedAddr::No,
-            llvm::Visibility::Default,
-            enzyme_ty,
-        );
-
-        // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
-        // do it's work.
-        let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx);
-        attributes::apply_to_llfn(ad_fn, Function, &[attr]);
-
-        // We add a made-up attribute just such that we can recognize it after AD to update
-        // (no)-inline attributes. We'll then also remove this attribute.
-        let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker");
-        attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]);
-
-        // first, remove all calls from fnc
-        let entry = llvm::LLVMGetFirstBasicBlock(outer_fn);
-        let br = llvm::LLVMRustGetTerminator(entry);
-        llvm::LLVMRustEraseInstFromParent(br);
-
-        let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap();
-        let mut builder = SBuilder::build(cx, entry);
-
-        let num_args = llvm::LLVMCountParams(&fn_to_diff);
-        let mut args = Vec::with_capacity(num_args as usize + 1);
-        args.push(fn_to_diff);
-
-        let enzyme_primal_ret = cx.create_metadata(b"enzyme_primal_return");
-        if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
-            args.push(cx.get_metadata_value(enzyme_primal_ret));
-        }
-        if attrs.width > 1 {
-            let enzyme_width = cx.create_metadata(b"enzyme_width");
-            args.push(cx.get_metadata_value(enzyme_width));
-            args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
-        }
-
-        let has_sret = has_sret(outer_fn);
-        let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
-        match_args_from_caller_to_enzyme(
-            &cx,
-            &builder,
-            attrs.width,
-            &mut args,
-            &attrs.input_activity,
-            &outer_args,
-            has_sret,
-        );
-
-        let call = builder.call(enzyme_ty, ad_fn, &args, None);
-
-        // This part is a bit iffy. LLVM requires that a call to an inlineable function has some
-        // metadata attached to it, but we just created this code oota. Given that the
-        // differentiated function already has partly confusing metadata, and given that this
-        // affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
-        // dummy code which we inserted at a higher level.
-        // FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have,
-        // and how to best improve it for enzyme core and rust-enzyme.
-        let md_ty = cx.get_md_kind_id("dbg");
-        if llvm::LLVMRustHasMetadata(last_inst, md_ty) {
-            let md = llvm::LLVMRustDIGetInstMetadata(last_inst)
-                .expect("failed to get instruction metadata");
-            let md_todiff = cx.get_metadata_value(md);
-            llvm::LLVMSetMetadata(call, md_ty, md_todiff);
-        } else {
-            // We don't panic, since depending on whether we are in debug or release mode, we might
-            // have no debug info to copy, which would then be ok.
-            trace!("no dbg info");
-        }
-
-        // Now that we copied the metadata, get rid of dummy code.
-        llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);
+    let enzyme_ty = unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) };
+
+    // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
+    // think a bit more about what should go here.
+    // FIXME(Sa4dUs): have to find a way to get the cc, using `FastCallConv` for now
+    let cc = 8;
+    let ad_fn = declare_simple_fn(
+        cx,
+        &ad_name,
+        llvm::CallConv::try_from(cc).expect("invalid callconv"),
+        llvm::UnnamedAddr::No,
+        llvm::Visibility::Default,
+        enzyme_ty,
+    );
+
+    // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
+    // do it's work.
+    let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx);
+    attributes::apply_to_llfn(ad_fn, Function, &[attr]);
+
+    let num_args = llvm::LLVMCountParams(&fn_to_diff);
+    let mut args = Vec::with_capacity(num_args as usize + 1);
+    args.push(fn_to_diff);
+
+    let enzyme_primal_ret = cx.create_metadata("enzyme_primal_return".to_string()).unwrap();
+    if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
+        args.push(cx.get_metadata_value(enzyme_primal_ret));
+    }
+    if attrs.width > 1 {
+        let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap();
+        args.push(cx.get_metadata_value(enzyme_width));
+        args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
+    }
 
-        if cx.val_ty(call) == cx.type_void() || has_sret {
-            if has_sret {
-                // This is what we already have in our outer_fn (shortened):
-                // define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
-                //   %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
-                //   <Here we are, we want to add the following two lines>
-                //   store [4 x double] %7, ptr %0, align 8
-                //   ret void
-                // }
+    match_args_from_caller_to_enzyme(
+        &cx,
+        builder,
+        attrs.width,
+        &mut args,
+        &attrs.input_activity,
+        fn_args,
+    );
 
-                // now store the result of the enzyme call into the sret pointer.
-                let sret_ptr = outer_args[0];
-                let call_ty = cx.val_ty(call);
-                if attrs.width == 1 {
-                    assert_eq!(cx.type_kind(call_ty), TypeKind::Struct);
-                } else {
-                    assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
-                }
-                llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
-            }
-            builder.ret_void();
-        } else {
-            builder.ret(call);
-        }
+    let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
 
-        // Let's crash in case that we messed something up above and generated invalid IR.
-        llvm::LLVMRustVerifyFunction(
-            outer_fn,
-            llvm::LLVMRustVerifierFailureAction::LLVMAbortProcessAction,
-        );
-    }
+    builder.store_to_place(call, dest.val);
 }
 
 pub(crate) fn differentiate<'ll>(
@@ -461,6 +294,7 @@ pub(crate) fn differentiate<'ll>(
     cgcx: &CodegenContext<LlvmCodegenBackend>,
     diff_items: Vec<AutoDiffItem>,
 ) -> Result<(), FatalError> {
+    // TODO(Sa4dUs): delete all this logic
     for item in &diff_items {
         trace!("{}", item);
     }
@@ -480,7 +314,7 @@ pub(crate) fn differentiate<'ll>(
     for item in diff_items.iter() {
         let name = item.source.clone();
         let fn_def: Option<&llvm::Value> = cx.get_function(&name);
-        let Some(fn_def) = fn_def else {
+        let Some(_fn_def) = fn_def else {
             return Err(llvm_err(
                 diag_handler.handle(),
                 LlvmError::PrepareAutoDiff {
@@ -492,7 +326,7 @@ pub(crate) fn differentiate<'ll>(
         };
         debug!(?item.target);
         let fn_target: Option<&llvm::Value> = cx.get_function(&item.target);
-        let Some(fn_target) = fn_target else {
+        let Some(_fn_target) = fn_target else {
             return Err(llvm_err(
                 diag_handler.handle(),
                 LlvmError::PrepareAutoDiff {
@@ -503,7 +337,7 @@ pub(crate) fn differentiate<'ll>(
             ));
         };
 
-        generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
+        // generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
     }
 
     // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs
index 27ae729a531..8eb15571e82 100644
--- a/compiler/rustc_codegen_llvm/src/context.rs
+++ b/compiler/rustc_codegen_llvm/src/context.rs
@@ -660,7 +660,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
     }
 }
 impl<'ll> SimpleCx<'ll> {
-    pub(crate) fn get_return_type(&self, ty: &'ll Type) -> &'ll Type {
+    pub(crate) fn _get_return_type(&self, ty: &'ll Type) -> &'ll Type {
         assert_eq!(self.type_kind(ty), TypeKind::Function);
         unsafe { llvm::LLVMGetReturnType(ty) }
     }
diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs
index 7b27e496986..1102fc1d0c8 100644
--- a/compiler/rustc_codegen_llvm/src/intrinsic.rs
+++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs
@@ -3,23 +3,26 @@ use std::cmp::Ordering;
 
 use rustc_abi::{Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size};
 use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh};
+use rustc_codegen_ssa::codegen_attrs::autodiff_attrs;
 use rustc_codegen_ssa::common::{IntPredicate, TypeKind};
 use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization};
 use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
 use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
 use rustc_codegen_ssa::traits::*;
 use rustc_hir as hir;
+use rustc_hir::def_id::LOCAL_CRATE;
 use rustc_middle::mir::BinOp;
 use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf};
-use rustc_middle::ty::{self, GenericArgsRef, Ty};
+use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty};
 use rustc_middle::{bug, span_bug};
 use rustc_span::{Span, Symbol, sym};
-use rustc_symbol_mangling::mangle_internal_symbol;
+use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate};
 use rustc_target::spec::PanicStrategy;
 use tracing::debug;
 
 use crate::abi::FnAbiLlvmExt;
 use crate::builder::Builder;
+use crate::builder::autodiff::generate_enzyme_call;
 use crate::context::CodegenCx;
 use crate::llvm::{self, Metadata};
 use crate::type_::Type;
@@ -174,10 +177,17 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
         span: Span,
     ) -> Result<(), ty::Instance<'tcx>> {
         let tcx = self.tcx;
+        let callee_ty = instance.ty(tcx, self.typing_env());
 
-        let name = tcx.item_name(instance.def_id());
         let fn_args = instance.args;
 
+        let sig = callee_ty.fn_sig(tcx);
+        let sig = tcx.normalize_erasing_late_bound_regions(self.typing_env(), sig);
+        let ret_ty = sig.output();
+        let name = tcx.item_name(instance.def_id());
+
+        let llret_ty = self.layout_of(ret_ty).llvm_type(self);
+
         let simple = call_simple_intrinsic(self, name, args);
         let llval = match name {
             _ if simple.is_some() => simple.unwrap(),
@@ -189,6 +199,66 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
                     &[ptr, args[1].immediate()],
                 )
             }
+            sym::autodiff => {
+                let val_arr: Vec<&'ll Value> = match args[2].val {
+                    crate::intrinsic::OperandValue::Ref(ref place_value) => {
+                        let mut ret_arr = vec![];
+                        let tuple_place = PlaceRef { val: *place_value, layout: args[2].layout };
+
+                        for i in 0..tuple_place.layout.layout.0.fields.count() {
+                            let field_place = tuple_place.project_field(self, i);
+                            let field_layout = tuple_place.layout.field(self, i);
+                            let llvm_ty = field_layout.llvm_type(self.cx);
+
+                            let field_val =
+                                self.load(llvm_ty, field_place.val.llval, field_place.val.align);
+
+                            ret_arr.push(field_val)
+                        }
+
+                        ret_arr
+                    }
+                    crate::intrinsic::OperandValue::Pair(v1, v2) => vec![v1, v2],
+                    OperandValue::Immediate(v) => vec![v],
+                    OperandValue::ZeroSized => bug!("unexpected `ZeroSized` arg"),
+                };
+
+                // Get source, diff, and attrs
+                let source_id = match fn_args.into_type_list(tcx)[0].kind() {
+                    ty::FnDef(def_id, _) => def_id,
+                    _ => bug!("invalid args"),
+                };
+                let fn_source = Instance::mono(tcx, *source_id);
+                let source_symbol =
+                    symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE);
+                let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol);
+                let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") };
+
+                let diff_id = match fn_args.into_type_list(tcx)[1].kind() {
+                    ty::FnDef(def_id, _) => def_id,
+                    _ => bug!("invalid args"),
+                };
+                let fn_diff = Instance::mono(tcx, *diff_id);
+                let diff_symbol =
+                    symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE);
+
+                let diff_attrs = autodiff_attrs(tcx, *diff_id);
+                let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") };
+
+                // Build body
+                generate_enzyme_call(
+                    self,
+                    self.cx,
+                    fn_to_diff,
+                    &diff_symbol,
+                    llret_ty,
+                    &val_arr,
+                    diff_attrs.clone(),
+                    result,
+                );
+
+                return Ok(());
+            }
             sym::is_val_statically_known => {
                 if let OperandValue::Immediate(imm) = args[0].val {
                     self.call_intrinsic(
diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs
index 4441dd6ebd6..46371cfe591 100644
--- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs
+++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs
@@ -135,6 +135,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi
         | sym::round_ties_even_f32
         | sym::round_ties_even_f64
         | sym::round_ties_even_f128
+        | sym::autodiff
         | sym::const_eval_select => hir::Safety::Safe,
         _ => hir::Safety::Unsafe,
     };
@@ -171,6 +172,8 @@ pub(crate) fn check_intrinsic_type(
         }
     };
 
+    let has_autodiff = tcx.has_attr(intrinsic_id, sym::rustc_autodiff);
+
     let bound_vars = tcx.mk_bound_variable_kinds(&[
         ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon),
         ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon),
@@ -195,9 +198,9 @@ pub(crate) fn check_intrinsic_type(
         (Ty::new_ref(tcx, env_region, va_list_ty, mutbl), va_list_ty)
     };
 
-    let safety = intrinsic_operation_unsafety(tcx, intrinsic_id);
     let n_lts = 0;
     let (n_tps, n_cts, inputs, output) = match intrinsic_name {
+        sym::autodiff => (4, 0, vec![param(0), param(1), param(2)], param(3)),
         sym::abort => (0, 0, vec![], tcx.types.never),
         sym::unreachable => (0, 0, vec![], tcx.types.never),
         sym::breakpoint => (0, 0, vec![], tcx.types.unit),
diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs
index 416ce27367e..f7a8258a9d8 100644
--- a/compiler/rustc_span/src/symbol.rs
+++ b/compiler/rustc_span/src/symbol.rs
@@ -542,6 +542,7 @@ symbols! {
         audit_that,
         augmented_assignments,
         auto_traits,
+        autodiff,
         autodiff_forward,
         autodiff_reverse,
         automatically_derived,
diff --git a/library/core/src/intrinsics/mod.rs b/library/core/src/intrinsics/mod.rs
index 7228ad0ed6d..6c389c55a5f 100644
--- a/library/core/src/intrinsics/mod.rs
+++ b/library/core/src/intrinsics/mod.rs
@@ -3157,6 +3157,10 @@ pub const unsafe fn copysignf64(x: f64, y: f64) -> f64;
 #[rustc_intrinsic]
 pub const unsafe fn copysignf128(x: f128, y: f128) -> f128;
 
+#[rustc_nounwind]
+#[rustc_intrinsic]
+pub const fn autodiff<F, G, T: crate::marker::Tuple, R>(f: F, df: G, args: T) -> R;
+
 /// Inform Miri that a given pointer definitely has a certain alignment.
 #[cfg(miri)]
 #[rustc_allow_const_fn_unstable(const_eval_select)]