about summary refs log tree commit diff
path: root/compiler/rustc_builtin_macros/src
diff options
context:
space:
mode:
authorMarcelo Domínguez <dmmarcelo27@gmail.com>2025-08-14 15:27:57 +0000
committerMarcelo Domínguez <dmmarcelo27@gmail.com>2025-08-14 16:30:15 +0000
commit250d77e5d72fde69a6406050a3b037635f685378 (patch)
tree67749136fca27852b5fb784c864f7d3564a42a09 /compiler/rustc_builtin_macros/src
parent5c631041aa0b0ad9e161b966b78e6dfdb8011023 (diff)
downloadrust-250d77e5d72fde69a6406050a3b037635f685378.tar.gz
rust-250d77e5d72fde69a6406050a3b037635f685378.zip
Complete functionality and general cleanup
Diffstat (limited to 'compiler/rustc_builtin_macros/src')
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs492
1 files changed, 91 insertions, 401 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index 3f8585d35bc..c260dca87c0 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -15,11 +15,12 @@ mod llvm_enzyme {
     use rustc_ast::tokenstream::*;
     use rustc_ast::visit::AssocCtxt::*;
     use rustc_ast::{
-        self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
-        MetaItemInner, PatKind, QSelf, TyKind, Visibility,
+        self as ast, AngleBracketedArg, AngleBracketedArgs, AnonConst, AssocItemKind, BindingMode,
+        FnRetTy, FnSig, GenericArg, GenericArgs, GenericParamKind, Generics, ItemKind,
+        MetaItemInner, PatKind, Path, PathSegment, TyKind, Visibility,
     };
     use rustc_expand::base::{Annotatable, ExtCtxt};
-    use rustc_span::{Ident, Span, Symbol, kw, sym};
+    use rustc_span::{Ident, Span, Symbol, sym};
     use thin_vec::{ThinVec, thin_vec};
     use tracing::{debug, trace};
 
@@ -179,11 +180,8 @@ mod llvm_enzyme {
     }
 
     /// We expand the autodiff macro to generate a new placeholder function which passes
-    /// type-checking and can be called by users. The function body of the placeholder function will
-    /// later be replaced on LLVM-IR level, so the design of the body is less important and for now
-    /// should just prevent early inlining and optimizations which alter the function signature.
-    /// The exact signature of the generated function depends on the configuration provided by the
-    /// user, but here is an example:
+    /// type-checking and can be called by users. The exact signature of the generated function
+    /// depends on the configuration provided by the user, but here is an example:
     ///
     /// ```
     /// #[autodiff(cos_box, Reverse, Duplicated, Active)]
@@ -199,14 +197,8 @@ mod llvm_enzyme {
     ///     f32::sin(**x)
     /// }
     /// #[rustc_autodiff(Reverse, Duplicated, Active)]
-    /// #[inline(never)]
     /// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
-    ///     unsafe {
-    ///         asm!("NOP");
-    ///     };
-    ///     ::core::hint::black_box(sin(x));
-    ///     ::core::hint::black_box((dx, dret));
-    ///     ::core::hint::black_box(sin(x))
+    ///     std::intrinsics::autodiff(sin::<>, cos_box::<>, (x, dx, dret))
     /// }
     /// ```
     /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
@@ -227,16 +219,24 @@ mod llvm_enzyme {
         // first get information about the annotable item: visibility, signature, name and generic
         // parameters.
         // these will be used to generate the differentiated version of the function
-        let Some((vis, sig, primal, generics)) = (match &item {
-            Annotatable::Item(iitem) => extract_item_info(iitem),
+        let Some((vis, sig, primal, generics, impl_of_trait)) = (match &item {
+            Annotatable::Item(iitem) => {
+                extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))
+            }
             Annotatable::Stmt(stmt) => match &stmt.kind {
-                ast::StmtKind::Item(iitem) => extract_item_info(iitem),
+                ast::StmtKind::Item(iitem) => {
+                    extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))
+                }
                 _ => None,
             },
-            Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind {
-                ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
-                    Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
-                }
+            Annotatable::AssocItem(assoc_item, Impl { of_trait }) => match &assoc_item.kind {
+                ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => Some((
+                    assoc_item.vis.clone(),
+                    sig.clone(),
+                    ident.clone(),
+                    generics.clone(),
+                    *of_trait,
+                )),
                 _ => None,
             },
             _ => None,
@@ -254,7 +254,6 @@ mod llvm_enzyme {
         };
 
         let has_ret = has_ret(&sig.decl.output);
-        let sig_span = ecx.with_call_site_ctxt(sig.span);
 
         // create TokenStream from vec elemtents:
         // meta_item doesn't have a .tokens field
@@ -323,28 +322,27 @@ mod llvm_enzyme {
         }
         let span = ecx.with_def_site_ctxt(expand_span);
 
-        let n_active: u32 = x
-            .input_activity
-            .iter()
-            .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
-            .count() as u32;
-        let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
-
-        // TODO(Sa4dUs): Remove this and all the related logic
-        let _d_body = gen_enzyme_body(
-            ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
-            &generics,
-        );
+        let d_sig = gen_enzyme_decl(ecx, &sig, &x, span);
 
-        let d_body =
-            call_autodiff(ecx, primal, first_ident(&meta_item_vec[0]), span, &d_sig);
+        let d_body = ecx.block(
+            span,
+            thin_vec![call_autodiff(
+                ecx,
+                primal,
+                first_ident(&meta_item_vec[0]),
+                span,
+                &d_sig,
+                &generics,
+                impl_of_trait,
+            )],
+        );
 
         // The first element of it is the name of the function to be generated
-        let asdf = Box::new(ast::Fn {
+        let d_fn = Box::new(ast::Fn {
             defaultness: ast::Defaultness::Final,
             sig: d_sig,
             ident: first_ident(&meta_item_vec[0]),
-            generics: generics.clone(),
+            generics,
             contract: None,
             body: Some(d_body),
             define_opaque: None,
@@ -433,13 +431,11 @@ mod llvm_enzyme {
             tokens: ts,
         });
 
-        let vis_clone = vis.clone();
-
         let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
         let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
         let d_annotatable = match &item {
             Annotatable::AssocItem(_, _) => {
-                let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
+                let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn);
                 let d_fn = Box::new(ast::AssocItem {
                     attrs: thin_vec![d_attr],
                     id: ast::DUMMY_NODE_ID,
@@ -451,13 +447,13 @@ mod llvm_enzyme {
                 Annotatable::AssocItem(d_fn, Impl { of_trait: false })
             }
             Annotatable::Item(_) => {
-                let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
+                let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn));
                 d_fn.vis = vis;
 
                 Annotatable::Item(d_fn)
             }
             Annotatable::Stmt(_) => {
-                let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
+                let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn));
                 d_fn.vis = vis;
 
                 Annotatable::Stmt(Box::new(ast::Stmt {
@@ -471,9 +467,7 @@ mod llvm_enzyme {
             }
         };
 
-        let dummy_const_annotatable = gen_dummy_const(ecx, span, primal, sig, generics, vis_clone);
-
-        return vec![orig_annotatable, dummy_const_annotatable, d_annotatable];
+        return vec![orig_annotatable, d_annotatable];
     }
 
     // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
@@ -504,9 +498,11 @@ mod llvm_enzyme {
         diff: Ident,
         span: Span,
         d_sig: &FnSig,
-    ) -> P<ast::Block> {
-        let primal_path_expr = ecx.expr_path(ecx.path_ident(span, primal));
-        let diff_path_expr = ecx.expr_path(ecx.path_ident(span, diff));
+        generics: &Generics,
+        is_impl: bool,
+    ) -> rustc_ast::Stmt {
+        let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span, is_impl);
+        let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span, is_impl);
 
         let tuple_expr = ecx.expr_tuple(
             span,
@@ -522,371 +518,65 @@ mod llvm_enzyme {
                 .into(),
         );
 
-        let enzyme_path = ecx.path(
-            span,
-            vec![
-                Ident::from_str("std"),
-                Ident::from_str("intrinsics"),
-                Ident::from_str("autodiff"),
-            ],
-        );
+        let enzyme_path_idents = ecx.std_path(&[sym::intrinsics, sym::autodiff]);
+        let enzyme_path = ecx.path(span, enzyme_path_idents);
         let call_expr = ecx.expr_call(
             span,
             ecx.expr_path(enzyme_path),
             vec![primal_path_expr, diff_path_expr, tuple_expr].into(),
         );
 
-        let block = ecx.block_expr(call_expr);
-
-        block
-    }
-
-    // Generate dummy const to prevent primal function
-    // from being optimized away before applying enzyme
-    // ```
-    // const _: () =
-    // {
-    //     #[used]
-    //     pub static DUMMY_PTR: fn_type = primal_fn;
-    // };
-    // ```
-    fn gen_dummy_const(
-        ecx: &ExtCtxt<'_>,
-        span: Span,
-        primal: Ident,
-        sig: FnSig,
-        generics: Generics,
-        vis: Visibility,
-    ) -> Annotatable {
-        // #[used]
-        let used_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::used)));
-        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
-        let used_attr = outer_normal_attr(&used_attr, new_id, span);
-
-        // static DUMMY_PTR: <fn_type> = <primal_ident>
-        let static_ident = Ident::from_str_and_span("DUMMY_PTR", span);
-        let fn_ptr_ty = ast::TyKind::BareFn(Box::new(ast::BareFnTy {
-            safety: sig.header.safety,
-            ext: sig.header.ext,
-            generic_params: generics.params,
-            decl: sig.decl,
-            decl_span: sig.span,
-        }));
-        let static_ty = ecx.ty(span, fn_ptr_ty);
-
-        let static_expr = ecx.expr_path(ecx.path(span, vec![primal]));
-        let static_item_kind = ast::ItemKind::Static(Box::new(ast::StaticItem {
-            ident: static_ident,
-            ty: static_ty,
-            safety: ast::Safety::Default,
-            mutability: ast::Mutability::Not,
-            expr: Some(static_expr),
-            define_opaque: None,
-        }));
-
-        let static_item = ast::Item {
-            attrs: thin_vec![used_attr],
-            id: ast::DUMMY_NODE_ID,
-            span,
-            vis,
-            kind: static_item_kind,
-            tokens: None,
-        };
-
-        let block_expr = ecx.expr_block(Box::new(ast::Block {
-            stmts: thin_vec![ecx.stmt_item(span, P(static_item))],
-            id: ast::DUMMY_NODE_ID,
-            rules: ast::BlockCheckMode::Default,
-            span,
-            tokens: None,
-        }));
-
-        let const_item = ecx.item_const(
-            span,
-            Ident::from_str_and_span("_", span),
-            ecx.ty(span, ast::TyKind::Tup(thin_vec![])),
-            block_expr,
-        );
-
-        Annotatable::Item(const_item)
+        ecx.stmt_expr(call_expr)
     }
 
-    // Will generate a body of the type:
-    // ```
-    // {
-    //   unsafe {
-    //   asm!("NOP");
-    //   }
-    //   ::core::hint::black_box(primal(args));
-    //   ::core::hint::black_box((args, ret));
-    //   <This part remains to be done by following function>
-    // }
-    // ```
-    fn init_body_helper(
+    // Generate turbofish expression from fn name and generics
+    // Given `foo` and `<A, B, C>` params, gen `foo::<A, B, C>`
+    // We use this expression when passing primal and diff function to the autodiff intrinsic
+    fn gen_turbofish_expr(
         ecx: &ExtCtxt<'_>,
-        span: Span,
-        primal: Ident,
-        new_names: &[String],
-        sig_span: Span,
-        new_decl_span: Span,
-        idents: &[Ident],
-        errored: bool,
+        ident: Ident,
         generics: &Generics,
-    ) -> (Box<ast::Block>, Box<ast::Expr>, Box<ast::Expr>, Box<ast::Expr>) {
-        let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
-        let noop = ast::InlineAsm {
-            asm_macro: ast::AsmMacro::Asm,
-            template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())],
-            template_strs: Box::new([]),
-            operands: vec![],
-            clobber_abis: vec![],
-            options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM,
-            line_spans: vec![],
-        };
-        let noop_expr = ecx.expr_asm(span, Box::new(noop));
-        let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated);
-        let unsf_block = ast::Block {
-            stmts: thin_vec![ecx.stmt_semi(noop_expr)],
-            id: ast::DUMMY_NODE_ID,
-            tokens: None,
-            rules: unsf,
-            span,
-        };
-        let unsf_expr = ecx.expr_block(Box::new(unsf_block));
-        let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
-        let primal_call = gen_primal_call(ecx, span, primal, idents, generics);
-        let black_box_primal_call = ecx.expr_call(
-            new_decl_span,
-            blackbox_call_expr.clone(),
-            thin_vec![primal_call.clone()],
-        );
-        let tup_args = new_names
-            .iter()
-            .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
-            .collect();
-
-        let black_box_remaining_args = ecx.expr_call(
-            sig_span,
-            blackbox_call_expr.clone(),
-            thin_vec![ecx.expr_tuple(sig_span, tup_args)],
-        );
-
-        let mut body = ecx.block(span, ThinVec::new());
-        body.stmts.push(ecx.stmt_semi(unsf_expr));
-
-        // This uses primal args which won't be available if we errored before
-        if !errored {
-            body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone()));
-        }
-        body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
-
-        (body, primal_call, black_box_primal_call, blackbox_call_expr)
-    }
-
-    /// We only want this function to type-check, since we will replace the body
-    /// later on llvm level. Using `loop {}` does not cover all return types anymore,
-    /// so instead we manually build something that should pass the type checker.
-    /// We also add a inline_asm line, as one more barrier for rustc to prevent inlining
-    /// or const propagation. inline_asm will also triggers an Enzyme crash if due to another
-    /// bug would ever try to accidentally differentiate this placeholder function body.
-    /// Finally, we also add back_box usages of all input arguments, to prevent rustc
-    /// from optimizing any arguments away.
-    fn gen_enzyme_body(
-        ecx: &ExtCtxt<'_>,
-        x: &AutoDiffAttrs,
-        n_active: u32,
-        sig: &ast::FnSig,
-        d_sig: &ast::FnSig,
-        primal: Ident,
-        new_names: &[String],
         span: Span,
-        sig_span: Span,
-        idents: Vec<Ident>,
-        errored: bool,
-        generics: &Generics,
-    ) -> Box<ast::Block> {
-        let new_decl_span = d_sig.span;
-
-        // Just adding some default inline-asm and black_box usages to prevent early inlining
-        // and optimizations which alter the function signature.
-        //
-        // The bb_primal_call is the black_box call of the primal function. We keep it around,
-        // since it has the convenient property of returning the type of the primal function,
-        // Remember, we only care to match types here.
-        // No matter which return we pick, we always wrap it into a std::hint::black_box call,
-        // to prevent rustc from propagating it into the caller.
-        let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper(
-            ecx,
-            span,
-            primal,
-            new_names,
-            sig_span,
-            new_decl_span,
-            &idents,
-            errored,
-            generics,
-        );
-
-        if !has_ret(&d_sig.decl.output) {
-            // there is no return type that we have to match, () works fine.
-            return body;
-        }
-
-        // Everything from here onwards just tries to fulfil the return type. Fun!
-
-        // having an active-only return means we'll drop the original return type.
-        // So that can be treated identical to not having one in the first place.
-        let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
-
-        if primal_ret && n_active == 0 && x.mode.is_rev() {
-            // We only have the primal ret.
-            body.stmts.push(ecx.stmt_expr(bb_primal_call));
-            return body;
-        }
-
-        if !primal_ret && n_active == 1 {
-            // Again no tuple return, so return default float val.
-            let ty = match d_sig.decl.output {
-                FnRetTy::Ty(ref ty) => ty.clone(),
-                FnRetTy::Default(span) => {
-                    panic!("Did not expect Default ret ty: {:?}", span);
+        is_impl: bool,
+    ) -> Box<ast::Expr> {
+        let generic_args = generics
+            .params
+            .iter()
+            .filter_map(|p| match &p.kind {
+                GenericParamKind::Type { .. } => {
+                    let path = ast::Path::from_ident(p.ident);
+                    let ty = ecx.ty_path(path);
+                    Some(AngleBracketedArg::Arg(GenericArg::Type(ty)))
                 }
-            };
-            let arg = ty.kind.is_simple_path().unwrap();
-            let tmp = ecx.def_site_path(&[arg, kw::Default]);
-            let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
-            let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
-            body.stmts.push(ecx.stmt_expr(default_call_expr));
-            return body;
-        }
-
-        let mut exprs: Box<ast::Expr> = primal_call;
-        let d_ret_ty = match d_sig.decl.output {
-            FnRetTy::Ty(ref ty) => ty.clone(),
-            FnRetTy::Default(span) => {
-                panic!("Did not expect Default ret ty: {:?}", span);
-            }
-        };
-        if x.mode.is_fwd() {
-            // Fwd mode is easy. If the return activity is Const, we support arbitrary types.
-            // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
-            // We checked that (on a best-effort base) in the preceding gen_enzyme_decl function.
-            // In all three cases, we can return `std::hint::black_box(<T>::default())`.
-            if x.ret_activity == DiffActivity::Const {
-                // Here we call the primal function, since our dummy function has the same return
-                // type due to the Const return activity.
-                exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
-            } else {
-                let q = QSelf { ty: d_ret_ty, path_span: span, position: 0 };
-                let y = ExprKind::Path(
-                    Some(Box::new(q)),
-                    ecx.path_ident(span, Ident::with_dummy_span(kw::Default)),
-                );
-                let default_call_expr = ecx.expr(span, y);
-                let default_call_expr =
-                    ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
-                exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]);
-            }
-        } else if x.mode.is_rev() {
-            if x.width == 1 {
-                // We either have `-> ArbitraryType` or `-> (ArbitraryType, repeated_float_scalars)`.
-                match d_ret_ty.kind {
-                    TyKind::Tup(ref args) => {
-                        // We have a tuple return type. We need to create a tuple of the same size
-                        // and fill it with default values.
-                        let mut exprs2 = thin_vec![exprs];
-                        for arg in args.iter().skip(1) {
-                            let arg = arg.kind.is_simple_path().unwrap();
-                            let tmp = ecx.def_site_path(&[arg, kw::Default]);
-                            let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
-                            let default_call_expr =
-                                ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
-                            exprs2.push(default_call_expr);
-                        }
-                        exprs = ecx.expr_tuple(new_decl_span, exprs2);
-                    }
-                    _ => {
-                        // Interestingly, even the `-> ArbitraryType` case
-                        // ends up getting matched and handled correctly above,
-                        // so we don't have to handle any other case for now.
-                        panic!("Unsupported return type: {:?}", d_ret_ty);
-                    }
+                GenericParamKind::Const { .. } => {
+                    let expr = ecx.expr_path(ast::Path::from_ident(p.ident));
+                    let anon_const = AnonConst { id: ast::DUMMY_NODE_ID, value: expr };
+                    Some(AngleBracketedArg::Arg(GenericArg::Const(anon_const)))
                 }
-            }
-            exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
-        } else {
-            unreachable!("Unsupported mode: {:?}", x.mode);
-        }
-
-        body.stmts.push(ecx.stmt_expr(exprs));
+                GenericParamKind::Lifetime { .. } => None,
+            })
+            .collect::<ThinVec<_>>();
 
-        body
-    }
+        let args: AngleBracketedArgs = AngleBracketedArgs { span, args: generic_args };
 
-    fn gen_primal_call(
-        ecx: &ExtCtxt<'_>,
-        span: Span,
-        primal: Ident,
-        idents: &[Ident],
-        generics: &Generics,
-    ) -> Box<ast::Expr> {
-        let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
+        let segment = PathSegment {
+            ident,
+            id: ast::DUMMY_NODE_ID,
+            args: Some(Box::new(GenericArgs::AngleBracketed(args))),
+        };
 
-        if has_self {
-            let args: ThinVec<_> =
-                idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
-            let self_expr = ecx.expr_self(span);
-            ecx.expr_method_call(span, self_expr, primal, args)
+        let segments = if is_impl {
+            thin_vec![
+                PathSegment { ident: Ident::from_str("Self"), id: ast::DUMMY_NODE_ID, args: None },
+                segment,
+            ]
         } else {
-            let args: ThinVec<_> =
-                idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
-            let mut primal_path = ecx.path_ident(span, primal);
-
-            let is_generic = !generics.params.is_empty();
-
-            match (is_generic, primal_path.segments.last_mut()) {
-                (true, Some(function_path)) => {
-                    let primal_generic_types = generics
-                        .params
-                        .iter()
-                        .filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));
-
-                    let generated_generic_types = primal_generic_types
-                        .map(|type_param| {
-                            let generic_param = TyKind::Path(
-                                None,
-                                ast::Path {
-                                    span,
-                                    segments: thin_vec![ast::PathSegment {
-                                        ident: type_param.ident,
-                                        args: None,
-                                        id: ast::DUMMY_NODE_ID,
-                                    }],
-                                    tokens: None,
-                                },
-                            );
-
-                            ast::AngleBracketedArg::Arg(ast::GenericArg::Type(Box::new(ast::Ty {
-                                id: type_param.id,
-                                span,
-                                kind: generic_param,
-                                tokens: None,
-                            })))
-                        })
-                        .collect();
-
-                    function_path.args =
-                        Some(Box::new(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
-                            span,
-                            args: generated_generic_types,
-                        })));
-                }
-                _ => {}
-            }
+            thin_vec![segment]
+        };
 
-            let primal_call_expr = ecx.expr_path(primal_path);
-            ecx.expr_call(span, primal_call_expr, args)
-        }
+        let path = Path { span, segments, tokens: None };
+
+        ecx.expr_path(path)
     }
 
     // Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
@@ -905,7 +595,7 @@ mod llvm_enzyme {
         sig: &ast::FnSig,
         x: &AutoDiffAttrs,
         span: Span,
-    ) -> (ast::FnSig, Vec<String>, Vec<Ident>, bool) {
+    ) -> ast::FnSig {
         let dcx = ecx.sess.dcx();
         let has_ret = has_ret(&sig.decl.output);
         let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
@@ -917,7 +607,7 @@ mod llvm_enzyme {
                 found: num_activities,
             });
             // This is not the right signature, but we can continue parsing.
-            return (sig.clone(), vec![], vec![], true);
+            return sig.clone();
         }
         assert!(sig.decl.inputs.len() == x.input_activity.len());
         assert!(has_ret == x.has_ret_activity());
@@ -960,7 +650,7 @@ mod llvm_enzyme {
 
         if errors {
             // This is not the right signature, but we can continue parsing.
-            return (sig.clone(), new_inputs, idents, true);
+            return sig.clone();
         }
 
         let unsafe_activities = x
@@ -1174,7 +864,7 @@ mod llvm_enzyme {
         }
         let d_sig = FnSig { header: d_header, decl: d_decl, span };
         trace!("Generated signature: {:?}", d_sig);
-        (d_sig, new_inputs, idents, false)
+        d_sig
     }
 }