about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_ast/src/expand/autodiff_attrs.rs13
-rw-r--r--compiler/rustc_builtin_macros/messages.ftl1
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs339
-rw-r--r--compiler/rustc_builtin_macros/src/errors.rs8
-rw-r--r--compiler/rustc_codegen_llvm/src/back/lto.rs12
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs199
-rw-r--r--compiler/rustc_codegen_llvm/src/consts.rs2
-rw-r--r--compiler/rustc_codegen_llvm/src/context.rs23
-rw-r--r--compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs4
-rw-r--r--compiler/rustc_codegen_llvm/src/llvm/ffi.rs2
-rw-r--r--compiler/rustc_codegen_ssa/src/codegen_attrs.rs32
-rw-r--r--compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp10
-rw-r--r--compiler/rustc_session/src/config.rs6
-rw-r--r--compiler/rustc_session/src/options.rs4
-rw-r--r--tests/codegen/autodiff.rs4
-rw-r--r--tests/codegen/autodiffv.rs116
-rw-r--r--tests/pretty/autodiff_forward.pp100
-rw-r--r--tests/pretty/autodiff_forward.rs18
-rw-r--r--tests/pretty/autodiff_reverse.pp22
-rw-r--r--tests/ui/autodiff/autodiff_illegal.rs7
-rw-r--r--tests/ui/autodiff/autodiff_illegal.stderr38
21 files changed, 727 insertions, 233 deletions
diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs
index c8ec185ee5e..f01c781f46c 100644
--- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs
+++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs
@@ -77,6 +77,17 @@ pub struct AutoDiffAttrs {
     /// e.g. in the [JAX
     /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
     pub mode: DiffMode,
+    /// A user-provided, batching width. If not given, we will default to 1 (no batching).
+    /// Calling a differentiated, non-batched function through a loop 100 times is equivalent to:
+    /// - Calling the function 50 times with a batch size of 2
+    /// - Calling the function 25 times with a batch size of 4,
+    /// etc. A batched function takes more (or longer) arguments, and might be able to benefit from
+    /// cache locality, better re-usal of primal values, and other optimizations.
+    /// We will (before LLVM's vectorizer runs) just generate most LLVM-IR instructions `width`
+    /// times, so this massively increases code size. As such, values like 1024 are unlikely to
+    /// work. We should consider limiting this to u8 or u16, but will leave it at u32 for
+    /// experiments for now and focus on documenting the implications of a large width.
+    pub width: u32,
     pub ret_activity: DiffActivity,
     pub input_activity: Vec<DiffActivity>,
 }
@@ -222,6 +233,7 @@ impl AutoDiffAttrs {
     pub const fn error() -> Self {
         AutoDiffAttrs {
             mode: DiffMode::Error,
+            width: 0,
             ret_activity: DiffActivity::None,
             input_activity: Vec::new(),
         }
@@ -229,6 +241,7 @@ impl AutoDiffAttrs {
     pub fn source() -> Self {
         AutoDiffAttrs {
             mode: DiffMode::Source,
+            width: 0,
             ret_activity: DiffActivity::None,
             input_activity: Vec::new(),
         }
diff --git a/compiler/rustc_builtin_macros/messages.ftl b/compiler/rustc_builtin_macros/messages.ftl
index 3f03834f8d7..603dc90bafc 100644
--- a/compiler/rustc_builtin_macros/messages.ftl
+++ b/compiler/rustc_builtin_macros/messages.ftl
@@ -79,6 +79,7 @@ builtin_macros_autodiff_ret_activity = invalid return activity {$act} in {$mode}
 builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
 builtin_macros_autodiff_unknown_activity = did not recognize Activity: `{$act}`
 
+builtin_macros_autodiff_width = autodiff width must fit u32, but is {$width}
 builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s
     .label = not applicable here
     .label2 = not a `struct`, `enum` or `union`
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index 8937d35d53a..7f99f75b2b9 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -12,12 +12,12 @@ mod llvm_enzyme {
         valid_ty_for_activity,
     };
     use rustc_ast::ptr::P;
-    use rustc_ast::token::{Token, TokenKind};
+    use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
     use rustc_ast::tokenstream::*;
     use rustc_ast::visit::AssocCtxt::*;
     use rustc_ast::{
-        self as ast, AssocItemKind, BindingMode, FnRetTy, FnSig, Generics, ItemKind, MetaItemInner,
-        PatKind, TyKind,
+        self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
+        MetaItemInner, PatKind, QSelf, TyKind,
     };
     use rustc_expand::base::{Annotatable, ExtCtxt};
     use rustc_span::{Ident, Span, Symbol, kw, sym};
@@ -45,6 +45,16 @@ mod llvm_enzyme {
         }
     }
     fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
+        if let Some(l) = x.lit() {
+            match l.kind {
+                ast::LitKind::Int(val, _) => {
+                    // get an Ident from a lit
+                    return rustc_span::Ident::from_str(val.get().to_string().as_str());
+                }
+                _ => {}
+            }
+        }
+
         let segments = &x.meta_item().unwrap().path.segments;
         assert!(segments.len() == 1);
         segments[0].ident
@@ -54,6 +64,14 @@ mod llvm_enzyme {
         first_ident(x).name.to_string()
     }
 
+    fn width(x: &MetaItemInner) -> Option<u128> {
+        let lit = x.lit()?;
+        match lit.kind {
+            ast::LitKind::Int(x, _) => Some(x.get()),
+            _ => return None,
+        }
+    }
+
     pub(crate) fn from_ast(
         ecx: &mut ExtCtxt<'_>,
         meta_item: &ThinVec<MetaItemInner>,
@@ -65,9 +83,32 @@ mod llvm_enzyme {
             dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
             return AutoDiffAttrs::error();
         };
+
+        // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
+        // If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
+        let mut first_activity = 2;
+
+        let width = if let [_, _, x, ..] = &meta_item[..]
+            && let Some(x) = width(x)
+        {
+            first_activity = 3;
+            match x.try_into() {
+                Ok(x) => x,
+                Err(_) => {
+                    dcx.emit_err(errors::AutoDiffInvalidWidth {
+                        span: meta_item[2].span(),
+                        width: x,
+                    });
+                    return AutoDiffAttrs::error();
+                }
+            }
+        } else {
+            1
+        };
+
         let mut activities: Vec<DiffActivity> = vec![];
         let mut errors = false;
-        for x in &meta_item[2..] {
+        for x in &meta_item[first_activity..] {
             let activity_str = name(&x);
             let res = DiffActivity::from_str(&activity_str);
             match res {
@@ -98,7 +139,20 @@ mod llvm_enzyme {
             (&DiffActivity::None, activities.as_slice())
         };
 
-        AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() }
+        AutoDiffAttrs {
+            mode,
+            width,
+            ret_activity: *ret_activity,
+            input_activity: input_activity.to_vec(),
+        }
+    }
+
+    fn meta_item_inner_to_ts(t: &MetaItemInner, ts: &mut Vec<TokenTree>) {
+        let comma: Token = Token::new(TokenKind::Comma, Span::default());
+        let val = first_ident(t);
+        let t = Token::from_ast_ident(val);
+        ts.push(TokenTree::Token(t, Spacing::Joint));
+        ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
     }
 
     /// We expand the autodiff macro to generate a new placeholder function which passes
@@ -195,27 +249,49 @@ mod llvm_enzyme {
 
         // create TokenStream from vec elemtents:
         // meta_item doesn't have a .tokens field
-        let comma: Token = Token::new(TokenKind::Comma, Span::default());
         let mut ts: Vec<TokenTree> = vec![];
         if meta_item_vec.len() < 2 {
             // At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
             // input and output args.
             dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
             return vec![item];
+        }
+
+        meta_item_inner_to_ts(&meta_item_vec[1], &mut ts);
+
+        // Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
+        // If it is not given, we default to 1 (scalar mode).
+        let start_position;
+        let kind: LitKind = LitKind::Integer;
+        let symbol;
+        if meta_item_vec.len() >= 3
+            && let Some(width) = width(&meta_item_vec[2])
+        {
+            start_position = 3;
+            symbol = Symbol::intern(&width.to_string());
         } else {
-            for t in meta_item_vec.clone()[1..].iter() {
-                let val = first_ident(t);
-                let t = Token::from_ast_ident(val);
-                ts.push(TokenTree::Token(t, Spacing::Joint));
-                ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
-            }
+            start_position = 2;
+            symbol = sym::integer(1);
         }
+        let l: Lit = Lit { kind, symbol, suffix: None };
+        let t = Token::new(TokenKind::Literal(l), Span::default());
+        let comma = Token::new(TokenKind::Comma, Span::default());
+        ts.push(TokenTree::Token(t, Spacing::Joint));
+        ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
+
+        for t in meta_item_vec.clone()[start_position..].iter() {
+            meta_item_inner_to_ts(t, &mut ts);
+        }
+
         if !has_ret {
             // We don't want users to provide a return activity if the function doesn't return anything.
             // For simplicity, we just add a dummy token to the end of the list.
             let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
             ts.push(TokenTree::Token(t, Spacing::Joint));
+            ts.push(TokenTree::Token(comma, Spacing::Alone));
         }
+        // We remove the last, trailing comma.
+        ts.pop();
         let ts: TokenStream = TokenStream::from_iter(ts);
 
         let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
@@ -470,6 +546,8 @@ mod llvm_enzyme {
             return body;
         }
 
+        // Everything from here onwards just tries to fullfil the return type. Fun!
+
         // having an active-only return means we'll drop the original return type.
         // So that can be treated identical to not having one in the first place.
         let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
@@ -497,86 +575,65 @@ mod llvm_enzyme {
             return body;
         }
 
-        let mut exprs = ThinVec::<P<ast::Expr>>::new();
-        if primal_ret {
-            // We have both primal ret and active floats.
-            // primal ret is first, by construction.
-            exprs.push(primal_call);
-        }
-
-        // Now construct default placeholder for each active float.
-        // Is there something nicer than f32::default() and f64::default()?
+        let mut exprs: P<ast::Expr> = primal_call.clone();
         let d_ret_ty = match d_sig.decl.output {
             FnRetTy::Ty(ref ty) => ty.clone(),
             FnRetTy::Default(span) => {
                 panic!("Did not expect Default ret ty: {:?}", span);
             }
         };
-        let mut d_ret_ty = match d_ret_ty.kind.clone() {
-            TyKind::Tup(ref tys) => tys.clone(),
-            TyKind::Path(_, rustc_ast::Path { segments, .. }) => {
-                if let [segment] = &segments[..]
-                    && segment.args.is_none()
-                {
-                    let id = vec![segments[0].ident];
-                    let kind = TyKind::Path(None, ecx.path(span, id));
-                    let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
-                    thin_vec![ty]
-                } else {
-                    panic!("Expected tuple or simple path return type");
-                }
-            }
-            _ => {
-                // We messed up construction of d_sig
-                panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty);
-            }
-        };
-
-        if x.mode.is_fwd() && x.ret_activity == DiffActivity::Dual {
-            assert!(d_ret_ty.len() == 2);
-            // both should be identical, by construction
-            let arg = d_ret_ty[0].kind.is_simple_path().unwrap();
-            let arg2 = d_ret_ty[1].kind.is_simple_path().unwrap();
-            assert!(arg == arg2);
-            let sl: Vec<Symbol> = vec![arg, kw::Default];
-            let tmp = ecx.def_site_path(&sl);
-            let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
-            let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
-            exprs.push(default_call_expr);
-        } else if x.mode.is_rev() {
-            if primal_ret {
-                // We have extra handling above for the primal ret
-                d_ret_ty = d_ret_ty[1..].to_vec().into();
-            }
 
-            for arg in d_ret_ty.iter() {
-                let arg = arg.kind.is_simple_path().unwrap();
-                let sl: Vec<Symbol> = vec![arg, kw::Default];
-                let tmp = ecx.def_site_path(&sl);
-                let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
+        if x.mode.is_fwd() {
+            // Fwd mode is easy. If the return activity is Const, we support arbitrary types.
+            // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
+            // We checked that (on a best-effort base) in the preceding gen_enzyme_decl function.
+            // In all three cases, we can return `std::hint::black_box(<T>::default())`.
+            if x.ret_activity == DiffActivity::Const {
+                // Here we call the primal function, since our dummy function has the same return
+                // type due to the Const return activity.
+                exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
+            } else {
+                let q = QSelf { ty: d_ret_ty.clone(), path_span: span, position: 0 };
+                let y =
+                    ExprKind::Path(Some(P(q)), ecx.path_ident(span, Ident::from_str("default")));
+                let default_call_expr = ecx.expr(span, y);
                 let default_call_expr =
                     ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
-                exprs.push(default_call_expr);
-            }
-        }
-
-        let ret: P<ast::Expr>;
-        match &exprs[..] {
-            [] => {
-                assert!(!has_ret(&d_sig.decl.output));
-                // We don't have to match the return type.
-                return body;
-            }
-            [arg] => {
-                ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![arg.clone()]);
+                exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]);
             }
-            args => {
-                let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, args.into());
-                ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![ret_tuple]);
+        } else if x.mode.is_rev() {
+            if x.width == 1 {
+                // We either have `-> ArbitraryType` or `-> (ArbitraryType, repeated_float_scalars)`.
+                match d_ret_ty.kind {
+                    TyKind::Tup(ref args) => {
+                        // We have a tuple return type. We need to create a tuple of the same size
+                        // and fill it with default values.
+                        let mut exprs2 = thin_vec![exprs];
+                        for arg in args.iter().skip(1) {
+                            let arg = arg.kind.is_simple_path().unwrap();
+                            let sl: Vec<Symbol> = vec![arg, kw::Default];
+                            let tmp = ecx.def_site_path(&sl);
+                            let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
+                            let default_call_expr =
+                                ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
+                            exprs2.push(default_call_expr);
+                        }
+                        exprs = ecx.expr_tuple(new_decl_span, exprs2);
+                    }
+                    _ => {
+                        // Interestingly, even the `-> ArbitraryType` case
+                        // ends up getting matched and handled correctly above,
+                        // so we don't have to handle any other case for now.
+                        panic!("Unsupported return type: {:?}", d_ret_ty);
+                    }
+                }
             }
+            exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
+        } else {
+            unreachable!("Unsupported mode: {:?}", x.mode);
         }
-        assert!(has_ret(&d_sig.decl.output));
-        body.stmts.push(ecx.stmt_expr(ret));
+
+        body.stmts.push(ecx.stmt_expr(exprs));
 
         body
     }
@@ -684,50 +741,55 @@ mod llvm_enzyme {
             match activity {
                 DiffActivity::Active => {
                     act_ret.push(arg.ty.clone());
+                    // if width =/= 1, then push [arg.ty; width] to act_ret
                 }
                 DiffActivity::ActiveOnly => {
                     // We will add the active scalar to the return type.
                     // This is handled later.
                 }
                 DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
-                    let mut shadow_arg = arg.clone();
-                    // We += into the shadow in reverse mode.
-                    shadow_arg.ty = P(assure_mut_ref(&arg.ty));
-                    let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
-                        ident.name
-                    } else {
-                        debug!("{:#?}", &shadow_arg.pat);
-                        panic!("not an ident?");
-                    };
-                    let name: String = format!("d{}", old_name);
-                    new_inputs.push(name.clone());
-                    let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
-                    shadow_arg.pat = P(ast::Pat {
-                        id: ast::DUMMY_NODE_ID,
-                        kind: PatKind::Ident(BindingMode::NONE, ident, None),
-                        span: shadow_arg.pat.span,
-                        tokens: shadow_arg.pat.tokens.clone(),
-                    });
-                    d_inputs.push(shadow_arg);
+                    for i in 0..x.width {
+                        let mut shadow_arg = arg.clone();
+                        // We += into the shadow in reverse mode.
+                        shadow_arg.ty = P(assure_mut_ref(&arg.ty));
+                        let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
+                            ident.name
+                        } else {
+                            debug!("{:#?}", &shadow_arg.pat);
+                            panic!("not an ident?");
+                        };
+                        let name: String = format!("d{}_{}", old_name, i);
+                        new_inputs.push(name.clone());
+                        let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
+                        shadow_arg.pat = P(ast::Pat {
+                            id: ast::DUMMY_NODE_ID,
+                            kind: PatKind::Ident(BindingMode::NONE, ident, None),
+                            span: shadow_arg.pat.span,
+                            tokens: shadow_arg.pat.tokens.clone(),
+                        });
+                        d_inputs.push(shadow_arg.clone());
+                    }
                 }
                 DiffActivity::Dual | DiffActivity::DualOnly => {
-                    let mut shadow_arg = arg.clone();
-                    let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
-                        ident.name
-                    } else {
-                        debug!("{:#?}", &shadow_arg.pat);
-                        panic!("not an ident?");
-                    };
-                    let name: String = format!("b{}", old_name);
-                    new_inputs.push(name.clone());
-                    let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
-                    shadow_arg.pat = P(ast::Pat {
-                        id: ast::DUMMY_NODE_ID,
-                        kind: PatKind::Ident(BindingMode::NONE, ident, None),
-                        span: shadow_arg.pat.span,
-                        tokens: shadow_arg.pat.tokens.clone(),
-                    });
-                    d_inputs.push(shadow_arg);
+                    for i in 0..x.width {
+                        let mut shadow_arg = arg.clone();
+                        let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
+                            ident.name
+                        } else {
+                            debug!("{:#?}", &shadow_arg.pat);
+                            panic!("not an ident?");
+                        };
+                        let name: String = format!("b{}_{}", old_name, i);
+                        new_inputs.push(name.clone());
+                        let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
+                        shadow_arg.pat = P(ast::Pat {
+                            id: ast::DUMMY_NODE_ID,
+                            kind: PatKind::Ident(BindingMode::NONE, ident, None),
+                            span: shadow_arg.pat.span,
+                            tokens: shadow_arg.pat.tokens.clone(),
+                        });
+                        d_inputs.push(shadow_arg.clone());
+                    }
                 }
                 DiffActivity::Const => {
                     // Nothing to do here.
@@ -783,23 +845,48 @@ mod llvm_enzyme {
         d_decl.inputs = d_inputs.into();
 
         if x.mode.is_fwd() {
+            let ty = match d_decl.output {
+                FnRetTy::Ty(ref ty) => ty.clone(),
+                FnRetTy::Default(span) => {
+                    // We want to return std::hint::black_box(()).
+                    let kind = TyKind::Tup(ThinVec::new());
+                    let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
+                    d_decl.output = FnRetTy::Ty(ty.clone());
+                    assert!(matches!(x.ret_activity, DiffActivity::None));
+                    // this won't be used below, so any type would be fine.
+                    ty
+                }
+            };
+
             if let DiffActivity::Dual = x.ret_activity {
-                let ty = match d_decl.output {
-                    FnRetTy::Ty(ref ty) => ty.clone(),
-                    FnRetTy::Default(span) => {
-                        panic!("Did not expect Default ret ty: {:?}", span);
-                    }
+                let kind = if x.width == 1 {
+                    // Dual can only be used for f32/f64 ret.
+                    // In that case we return now a tuple with two floats.
+                    TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
+                } else {
+                    // We have to return [T; width+1], +1 for the primal return.
+                    let anon_const = rustc_ast::AnonConst {
+                        id: ast::DUMMY_NODE_ID,
+                        value: ecx.expr_usize(span, 1 + x.width as usize),
+                    };
+                    TyKind::Array(ty.clone(), anon_const)
                 };
-                // Dual can only be used for f32/f64 ret.
-                // In that case we return now a tuple with two floats.
-                let kind = TyKind::Tup(thin_vec![ty.clone(), ty.clone()]);
                 let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
                 d_decl.output = FnRetTy::Ty(ty);
             }
             if let DiffActivity::DualOnly = x.ret_activity {
                 // No need to change the return type,
-                // we will just return the shadow in place
-                // of the primal return.
+                // we will just return the shadow in place of the primal return.
+                // However, if we have a width > 1, then we don't return -> T, but -> [T; width]
+                if x.width > 1 {
+                    let anon_const = rustc_ast::AnonConst {
+                        id: ast::DUMMY_NODE_ID,
+                        value: ecx.expr_usize(span, x.width as usize),
+                    };
+                    let kind = TyKind::Array(ty.clone(), anon_const);
+                    let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
+                    d_decl.output = FnRetTy::Ty(ty);
+                }
             }
         }
 
diff --git a/compiler/rustc_builtin_macros/src/errors.rs b/compiler/rustc_builtin_macros/src/errors.rs
index 30597944124..4bbe212f429 100644
--- a/compiler/rustc_builtin_macros/src/errors.rs
+++ b/compiler/rustc_builtin_macros/src/errors.rs
@@ -203,6 +203,14 @@ mod autodiff {
     }
 
     #[derive(Diagnostic)]
+    #[diag(builtin_macros_autodiff_width)]
+    pub(crate) struct AutoDiffInvalidWidth {
+        #[primary_span]
+        pub(crate) span: Span,
+        pub(crate) width: u128,
+    }
+
+    #[derive(Diagnostic)]
     #[diag(builtin_macros_autodiff)]
     pub(crate) struct AutoDiffInvalidApplication {
         #[primary_span]
diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs
index f083cfbd7d3..a8b49e9552c 100644
--- a/compiler/rustc_codegen_llvm/src/back/lto.rs
+++ b/compiler/rustc_codegen_llvm/src/back/lto.rs
@@ -610,6 +610,8 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<
             }
             // We handle this below
             config::AutoDiff::PrintModAfter => {}
+            // We handle this below
+            config::AutoDiff::PrintModFinal => {}
             // This is required and already checked
             config::AutoDiff::Enable => {}
         }
@@ -657,14 +659,20 @@ pub(crate) fn run_pass_manager(
     }
 
     if cfg!(llvm_enzyme) && enable_ad {
+        // This is the post-autodiff IR, mainly used for testing and educational purposes.
+        if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
+            unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
+        }
+
         let opt_stage = llvm::OptStage::FatLTO;
         let stage = write::AutodiffStage::PostAD;
         unsafe {
             write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
         }
 
-        // This is the final IR, so people should be able to inspect the optimized autodiff output.
-        if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
+        // This is the final IR, so people should be able to inspect the optimized autodiff output,
+        // for manual inspection.
+        if config.autodiff.contains(&config::AutoDiff::PrintModFinal) {
             unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
         }
     }
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index 7cd4ee539d8..7d264ba4d00 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -3,8 +3,10 @@ use std::ptr;
 use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
 use rustc_codegen_ssa::ModuleCodegen;
 use rustc_codegen_ssa::back::write::ModuleConfig;
-use rustc_codegen_ssa::traits::BaseTypeCodegenMethods as _;
+use rustc_codegen_ssa::common::TypeKind;
+use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
 use rustc_errors::FatalError;
+use rustc_middle::bug;
 use tracing::{debug, trace};
 
 use crate::back::write::llvm_err;
@@ -18,21 +20,42 @@ use crate::value::Value;
 use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
 
 fn get_params(fnc: &Value) -> Vec<&Value> {
+    let param_num = llvm::LLVMCountParams(fnc) as usize;
+    let mut fnc_args: Vec<&Value> = vec![];
+    fnc_args.reserve(param_num);
     unsafe {
-        let param_num = llvm::LLVMCountParams(fnc) as usize;
-        let mut fnc_args: Vec<&Value> = vec![];
-        fnc_args.reserve(param_num);
         llvm::LLVMGetParams(fnc, fnc_args.as_mut_ptr());
         fnc_args.set_len(param_num);
-        fnc_args
     }
+    fnc_args
 }
 
+fn has_sret(fnc: &Value) -> bool {
+    let num_args = llvm::LLVMCountParams(fnc) as usize;
+    if num_args == 0 {
+        false
+    } else {
+        unsafe { llvm::LLVMRustHasAttributeAtIndex(fnc, 0, llvm::AttributeKind::StructRet) }
+    }
+}
+
+// When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
+// original inputs, as well as metadata and the additional shadow arguments.
+// This function matches the arguments from the outer function to the inner enzyme call.
+//
+// This function also considers that Rust level arguments not always match the llvm-ir level
+// arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
+// llvm-ir level. The number of activities matches the number of Rust level arguments, so we
+// need to match those.
+// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
+// using iterators and peek()?
 fn match_args_from_caller_to_enzyme<'ll>(
     cx: &SimpleCx<'ll>,
+    width: u32,
     args: &mut Vec<&'ll llvm::Value>,
     inputs: &[DiffActivity],
     outer_args: &[&'ll llvm::Value],
+    has_sret: bool,
 ) {
     debug!("matching autodiff arguments");
     // We now handle the issue that Rust level arguments not always match the llvm-ir level
@@ -44,6 +67,14 @@ fn match_args_from_caller_to_enzyme<'ll>(
     let mut outer_pos: usize = 0;
     let mut activity_pos = 0;
 
+    if has_sret {
+        // Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
+        // inner function will still return something. We increase our outer_pos by one,
+        // and once we're done with all other args we will take the return of the inner call and
+        // update the sret pointer with it
+        outer_pos = 1;
+    }
+
     let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap();
     let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap();
     let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
@@ -92,23 +123,20 @@ fn match_args_from_caller_to_enzyme<'ll>(
                 // (..., metadata! enzyme_dup, ptr, ptr, int1, ...).
                 // FIXME(ZuseZ4): We will upstream a safety check later which asserts that
                 // int2 >= int1, which means the shadow vector is large enough to store the gradient.
-                assert!(unsafe {
-                    llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Integer
-                });
-                let next_outer_arg2 = outer_args[outer_pos + 2];
-                let next_outer_ty2 = cx.val_ty(next_outer_arg2);
-                assert!(unsafe {
-                    llvm::LLVMRustGetTypeKind(next_outer_ty2) == llvm::TypeKind::Pointer
-                });
-                let next_outer_arg3 = outer_args[outer_pos + 3];
-                let next_outer_ty3 = cx.val_ty(next_outer_arg3);
-                assert!(unsafe {
-                    llvm::LLVMRustGetTypeKind(next_outer_ty3) == llvm::TypeKind::Integer
-                });
-                args.push(next_outer_arg2);
+                assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);
+
+                for i in 0..(width as usize) {
+                    let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];
+                    let next_outer_ty2 = cx.val_ty(next_outer_arg2);
+                    assert_eq!(cx.type_kind(next_outer_ty2), TypeKind::Pointer);
+                    let next_outer_arg3 = outer_args[outer_pos + 2 * (i + 1) + 1];
+                    let next_outer_ty3 = cx.val_ty(next_outer_arg3);
+                    assert_eq!(cx.type_kind(next_outer_ty3), TypeKind::Integer);
+                    args.push(next_outer_arg2);
+                }
                 args.push(cx.get_metadata_value(enzyme_const));
                 args.push(next_outer_arg);
-                outer_pos += 4;
+                outer_pos += 2 + 2 * width as usize;
                 activity_pos += 2;
             } else {
                 // A duplicated pointer will have the following two outer_fn arguments:
@@ -116,15 +144,19 @@ fn match_args_from_caller_to_enzyme<'ll>(
                 // (..., metadata! enzyme_dup, ptr, ptr, ...).
                 if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly)
                 {
-                    assert!(
-                        unsafe { llvm::LLVMRustGetTypeKind(next_outer_ty) }
-                            == llvm::TypeKind::Pointer
-                    );
+                    assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Pointer);
                 }
                 // In the case of Dual we don't have assumptions, e.g. f32 would be valid.
                 args.push(next_outer_arg);
                 outer_pos += 2;
                 activity_pos += 1;
+
+                // Now, if width > 1, we need to account for that
+                for _ in 1..width {
+                    let next_outer_arg = outer_args[outer_pos];
+                    args.push(next_outer_arg);
+                    outer_pos += 1;
+                }
             }
         } else {
             // We do not differentiate with resprect to this argument.
@@ -135,6 +167,76 @@ fn match_args_from_caller_to_enzyme<'ll>(
     }
 }
 
+// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
+// arguments. We do however need to declare them with their correct return type.
+// We already figured the correct return type out in our frontend, when generating the outer_fn,
+// so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
+// Beyond sret, this article describes our challenges nicely:
+// <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
+// I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
+fn compute_enzyme_fn_ty<'ll>(
+    cx: &SimpleCx<'ll>,
+    attrs: &AutoDiffAttrs,
+    fn_to_diff: &'ll Value,
+    outer_fn: &'ll Value,
+) -> &'ll llvm::Type {
+    let fn_ty = cx.get_type_of_global(outer_fn);
+    let mut ret_ty = cx.get_return_type(fn_ty);
+
+    let has_sret = has_sret(outer_fn);
+
+    if has_sret {
+        // Now we don't just forward the return type, so we have to figure it out based on the
+        // primal return type, in combination with the autodiff settings.
+        let fn_ty = cx.get_type_of_global(fn_to_diff);
+        let inner_ret_ty = cx.get_return_type(fn_ty);
+
+        let void_ty = unsafe { llvm::LLVMVoidTypeInContext(cx.llcx) };
+        if inner_ret_ty == void_ty {
+            // This indicates that even the inner function has an sret.
+            // Right now I only look for an sret in the outer function.
+            // This *probably* needs some extra handling, but I never ran
+            // into such a case. So I'll wait for user reports to have a test case.
+            bug!("sret in inner function");
+        }
+
+        if attrs.width == 1 {
+            todo!("Handle sret for scalar ad");
+        } else {
+            // First we check if we also have to deal with the primal return.
+            match attrs.mode {
+                DiffMode::Forward => match attrs.ret_activity {
+                    DiffActivity::Dual => {
+                        let arr_ty =
+                            unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64 + 1) };
+                        ret_ty = arr_ty;
+                    }
+                    DiffActivity::DualOnly => {
+                        let arr_ty =
+                            unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64) };
+                        ret_ty = arr_ty;
+                    }
+                    DiffActivity::Const => {
+                        todo!("Not sure, do we need to do something here?");
+                    }
+                    _ => {
+                        bug!("unreachable");
+                    }
+                },
+                DiffMode::Reverse => {
+                    todo!("Handle sret for reverse mode");
+                }
+                _ => {
+                    bug!("unreachable");
+                }
+            }
+        }
+    }
+
+    // LLVM can figure out the input types on it's own, so we take a shortcut here.
+    unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) }
+}
+
 /// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
 /// function with expected naming and calling conventions[^1] which will be
 /// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -197,17 +299,9 @@ fn generate_enzyme_call<'ll>(
     // }
     // ```
     unsafe {
-        // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
-        // arguments. We do however need to declare them with their correct return type.
-        // We already figured the correct return type out in our frontend, when generating the outer_fn,
-        // so we can now just go ahead and use that. FIXME(ZuseZ4): This doesn't handle sret yet.
-        let fn_ty = llvm::LLVMGlobalGetValueType(outer_fn);
-        let ret_ty = llvm::LLVMGetReturnType(fn_ty);
-
-        // LLVM can figure out the input types on it's own, so we take a shortcut here.
-        let enzyme_ty = llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True);
+        let enzyme_ty = compute_enzyme_fn_ty(cx, &attrs, fn_to_diff, outer_fn);
 
-        //FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
+        // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
         // think a bit more about what should go here.
         let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
         let ad_fn = declare_simple_fn(
@@ -240,14 +334,27 @@ fn generate_enzyme_call<'ll>(
         if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
             args.push(cx.get_metadata_value(enzyme_primal_ret));
         }
+        if attrs.width > 1 {
+            let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap();
+            args.push(cx.get_metadata_value(enzyme_width));
+            args.push(cx.get_const_i64(attrs.width as u64));
+        }
 
+        let has_sret = has_sret(outer_fn);
         let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
-        match_args_from_caller_to_enzyme(&cx, &mut args, &attrs.input_activity, &outer_args);
+        match_args_from_caller_to_enzyme(
+            &cx,
+            attrs.width,
+            &mut args,
+            &attrs.input_activity,
+            &outer_args,
+            has_sret,
+        );
 
         let call = builder.call(enzyme_ty, ad_fn, &args, None);
 
         // This part is a bit iffy. LLVM requires that a call to an inlineable function has some
-        // metadata attachted to it, but we just created this code oota. Given that the
+        // metadata attached to it, but we just created this code oota. Given that the
         // differentiated function already has partly confusing metadata, and given that this
         // affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
         // dummy code which we inserted at a higher level.
@@ -268,7 +375,22 @@ fn generate_enzyme_call<'ll>(
         // Now that we copied the metadata, get rid of dummy code.
         llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);
 
-        if cx.val_ty(call) == cx.type_void() {
+        if cx.val_ty(call) == cx.type_void() || has_sret {
+            if has_sret {
+                // This is what we already have in our outer_fn (shortened):
+                // define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
+                //   %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
+                //   <Here we are, we want to add the following two lines>
+                //   store [4 x double] %7, ptr %0, align 8
+                //   ret void
+                // }
+
+                // now store the result of the enzyme call into the sret pointer.
+                let sret_ptr = outer_args[0];
+                let call_ty = cx.val_ty(call);
+                assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
+                llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
+            }
             builder.ret_void();
         } else {
             builder.ret(call);
@@ -300,8 +422,7 @@ pub(crate) fn differentiate<'ll>(
     if !diff_items.is_empty()
         && !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable)
     {
-        let dcx = cgcx.create_dcx();
-        return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutEnable));
+        return Err(diag_handler.handle().emit_almost_fatal(AutoDiffWithoutEnable));
     }
 
     // Before dumping the module, we want all the TypeTrees to become part of the module.
diff --git a/compiler/rustc_codegen_llvm/src/consts.rs b/compiler/rustc_codegen_llvm/src/consts.rs
index 7675e75338a..bf81eb648f8 100644
--- a/compiler/rustc_codegen_llvm/src/consts.rs
+++ b/compiler/rustc_codegen_llvm/src/consts.rs
@@ -430,7 +430,7 @@ impl<'ll> CodegenCx<'ll, '_> {
             let val_llty = self.val_ty(v);
 
             let g = self.get_static_inner(def_id, val_llty);
-            let llty = llvm::LLVMGlobalGetValueType(g);
+            let llty = self.get_type_of_global(g);
 
             let g = if val_llty == llty {
                 g
diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs
index f7b096ff976..3be8cd5f6ac 100644
--- a/compiler/rustc_codegen_llvm/src/context.rs
+++ b/compiler/rustc_codegen_llvm/src/context.rs
@@ -8,6 +8,7 @@ use std::str;
 use rustc_abi::{HasDataLayout, Size, TargetDataLayout, VariantIdx};
 use rustc_codegen_ssa::back::versioned_llvm_target;
 use rustc_codegen_ssa::base::{wants_msvc_seh, wants_wasm_eh};
+use rustc_codegen_ssa::common::TypeKind;
 use rustc_codegen_ssa::errors as ssa_errors;
 use rustc_codegen_ssa::traits::*;
 use rustc_data_structures::base_n::{ALPHANUMERIC_ONLY, ToBaseN};
@@ -38,7 +39,7 @@ use crate::debuginfo::metadata::apply_vcall_visibility_metadata;
 use crate::llvm::Metadata;
 use crate::type_::Type;
 use crate::value::Value;
-use crate::{attributes, coverageinfo, debuginfo, llvm, llvm_util};
+use crate::{attributes, common, coverageinfo, debuginfo, llvm, llvm_util};
 
 /// `TyCtxt` (and related cache datastructures) can't be move between threads.
 /// However, there are various cx related functions which we want to be available to the builder and
@@ -643,7 +644,18 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
         llvm::set_section(g, c"llvm.metadata");
     }
 }
-
+impl<'ll> SimpleCx<'ll> {
+    pub(crate) fn get_return_type(&self, ty: &'ll Type) -> &'ll Type {
+        assert_eq!(self.type_kind(ty), TypeKind::Function);
+        unsafe { llvm::LLVMGetReturnType(ty) }
+    }
+    pub(crate) fn get_type_of_global(&self, val: &'ll Value) -> &'ll Type {
+        unsafe { llvm::LLVMGlobalGetValueType(val) }
+    }
+    pub(crate) fn val_ty(&self, v: &'ll Value) -> &'ll Type {
+        common::val_ty(v)
+    }
+}
 impl<'ll> SimpleCx<'ll> {
     pub(crate) fn new(
         llmod: &'ll llvm::Module,
@@ -660,6 +672,13 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
         llvm::LLVMMetadataAsValue(self.llcx(), metadata)
     }
 
+    // FIXME(autodiff): We should split `ConstCodegenMethods` to pull the reusable parts
+    // onto a trait that is also implemented for GenericCx.
+    pub(crate) fn get_const_i64(&self, n: u64) -> &'ll Value {
+        let ty = unsafe { llvm::LLVMInt64TypeInContext(self.llcx()) };
+        unsafe { llvm::LLVMConstInt(ty, n, llvm::False) }
+    }
+
     pub(crate) fn get_function(&self, name: &str) -> Option<&'ll Value> {
         let name = SmallCStr::new(name);
         unsafe { llvm::LLVMGetNamedFunction((**self).borrow().llmod, name.as_ptr()) }
diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
index 79e4cc8aa77..a9b3bdf7344 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
@@ -4,7 +4,7 @@
 use libc::{c_char, c_uint};
 
 use super::MetadataKindId;
-use super::ffi::{BasicBlock, Metadata, Module, Type, Value};
+use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value};
 use crate::llvm::Bool;
 
 #[link(name = "llvm-wrapper", kind = "static")]
@@ -17,6 +17,8 @@ unsafe extern "C" {
     pub(crate) fn LLVMRustEraseInstFromParent(V: &Value);
     pub(crate) fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value;
     pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
+    pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool;
+    pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64;
 }
 
 unsafe extern "C" {
diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
index 3ce3761944b..9ff04f72903 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
@@ -1180,7 +1180,7 @@ unsafe extern "C" {
 
     // Operations on parameters
     pub(crate) fn LLVMIsAArgument(Val: &Value) -> Option<&Value>;
-    pub(crate) fn LLVMCountParams(Fn: &Value) -> c_uint;
+    pub(crate) safe fn LLVMCountParams(Fn: &Value) -> c_uint;
     pub(crate) fn LLVMGetParam(Fn: &Value, Index: c_uint) -> &Value;
 
     // Operations on basic blocks
diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs
index 8a184fc0bef..ddb61188983 100644
--- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs
+++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs
@@ -2,7 +2,7 @@ use std::str::FromStr;
 
 use rustc_abi::ExternAbi;
 use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
-use rustc_ast::{MetaItem, MetaItemInner, attr};
+use rustc_ast::{LitKind, MetaItem, MetaItemInner, attr};
 use rustc_attr_parsing::ReprAttr::ReprAlign;
 use rustc_attr_parsing::{AttributeKind, InlineAttr, InstructionSetAttr, OptimizeAttr};
 use rustc_data_structures::fx::FxHashMap;
@@ -805,8 +805,8 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
         return Some(AutoDiffAttrs::source());
     }
 
-    let [mode, input_activities @ .., ret_activity] = &list[..] else {
-        span_bug!(attr.span(), "rustc_autodiff attribute must contain mode and activities");
+    let [mode, width_meta, input_activities @ .., ret_activity] = &list[..] else {
+        span_bug!(attr.span(), "rustc_autodiff attribute must contain mode, width and activities");
     };
     let mode = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = mode {
         p1.segments.first().unwrap().ident
@@ -823,6 +823,30 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
         }
     };
 
+    let width: u32 = match width_meta {
+        MetaItemInner::MetaItem(MetaItem { path: p1, .. }) => {
+            let w = p1.segments.first().unwrap().ident;
+            match w.as_str().parse() {
+                Ok(val) => val,
+                Err(_) => {
+                    span_bug!(w.span, "rustc_autodiff width should fit u32");
+                }
+            }
+        }
+        MetaItemInner::Lit(lit) => {
+            if let LitKind::Int(val, _) = lit.kind {
+                match val.get().try_into() {
+                    Ok(val) => val,
+                    Err(_) => {
+                        span_bug!(lit.span, "rustc_autodiff width should fit u32");
+                    }
+                }
+            } else {
+                span_bug!(lit.span, "rustc_autodiff width should be an integer");
+            }
+        }
+    };
+
     // First read the ret symbol from the attribute
     let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = ret_activity {
         p1.segments.first().unwrap().ident
@@ -860,7 +884,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
         }
     }
 
-    Some(AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities })
+    Some(AutoDiffAttrs { mode, width, ret_activity, input_activity: arg_activities })
 }
 
 pub(crate) fn provide(providers: &mut Providers) {
diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
index 53df59930f4..32e6da446d7 100644
--- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
+++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
@@ -384,6 +384,12 @@ static inline void AddAttributes(T *t, unsigned Index, LLVMAttributeRef *Attrs,
   t->setAttributes(PALNew);
 }
 
+extern "C" bool LLVMRustHasAttributeAtIndex(LLVMValueRef Fn, unsigned Index,
+                                            LLVMRustAttributeKind RustAttr) {
+  Function *F = unwrap<Function>(Fn);
+  return F->hasParamAttribute(Index, fromRust(RustAttr));
+}
+
 extern "C" void LLVMRustAddFunctionAttributes(LLVMValueRef Fn, unsigned Index,
                                               LLVMAttributeRef *Attrs,
                                               size_t AttrsLen) {
@@ -636,6 +642,10 @@ static InlineAsm::AsmDialect fromRust(LLVMRustAsmDialect Dialect) {
   }
 }
 
+extern "C" uint64_t LLVMRustGetArrayNumElements(LLVMTypeRef Ty) {
+  return unwrap(Ty)->getArrayNumElements();
+}
+
 extern "C" LLVMValueRef
 LLVMRustInlineAsm(LLVMTypeRef Ty, char *AsmString, size_t AsmStringLen,
                   char *Constraints, size_t ConstraintsLen,
diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs
index 1f18950feac..b5f24c77953 100644
--- a/compiler/rustc_session/src/config.rs
+++ b/compiler/rustc_session/src/config.rs
@@ -237,10 +237,12 @@ pub enum AutoDiff {
     PrintPerf,
     /// Print intermediate IR generation steps
     PrintSteps,
-    /// Print the whole module, before running opts.
+    /// Print the module, before running autodiff.
     PrintModBefore,
-    /// Print the module after Enzyme differentiated everything.
+    /// Print the module after running autodiff.
     PrintModAfter,
+    /// Print the module after running autodiff and optimizations.
+    PrintModFinal,
 
     /// Enzyme's loose type debug helper (can cause incorrect gradients!!)
     /// Usable in cases where Enzyme errors with `can not deduce type of X`.
diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs
index 4f544d2c16b..5ed8cc17886 100644
--- a/compiler/rustc_session/src/options.rs
+++ b/compiler/rustc_session/src/options.rs
@@ -711,7 +711,7 @@ mod desc {
     pub(crate) const parse_list: &str = "a space-separated list of strings";
     pub(crate) const parse_list_with_polarity: &str =
         "a comma-separated list of strings, with elements beginning with + or -";
-    pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `LooseTypes`, `Inline`";
+    pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `LooseTypes`, `Inline`";
     pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
     pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
     pub(crate) const parse_number: &str = "a number";
@@ -1359,6 +1359,7 @@ pub mod parse {
                 "PrintSteps" => AutoDiff::PrintSteps,
                 "PrintModBefore" => AutoDiff::PrintModBefore,
                 "PrintModAfter" => AutoDiff::PrintModAfter,
+                "PrintModFinal" => AutoDiff::PrintModFinal,
                 "LooseTypes" => AutoDiff::LooseTypes,
                 "Inline" => AutoDiff::Inline,
                 _ => {
@@ -2093,6 +2094,7 @@ options! {
         `=PrintSteps`
         `=PrintModBefore`
         `=PrintModAfter`
+        `=PrintModFinal`
         `=LooseTypes`
         `=Inline`
         Multiple options can be combined with commas."),
diff --git a/tests/codegen/autodiff.rs b/tests/codegen/autodiff.rs
index cace0edb2b5..85358f5fcb6 100644
--- a/tests/codegen/autodiff.rs
+++ b/tests/codegen/autodiff.rs
@@ -11,7 +11,7 @@ fn square(x: &f64) -> f64 {
     x * x
 }
 
-// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture align 8 %"x'"
+// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture nonnull align 8 %"x'"
 // CHECK-NEXT:invertstart:
 // CHECK-NEXT:  %_0 = fmul double %x.0.val, %x.0.val
 // CHECK-NEXT:  %0 = fadd fast double %x.0.val, %x.0.val
@@ -22,7 +22,7 @@ fn square(x: &f64) -> f64 {
 // CHECK-NEXT:}
 
 fn main() {
-    let x = 3.0;
+    let x = std::hint::black_box(3.0);
     let output = square(&x);
     assert_eq!(9.0, output);
 
diff --git a/tests/codegen/autodiffv.rs b/tests/codegen/autodiffv.rs
new file mode 100644
index 00000000000..e0047116405
--- /dev/null
+++ b/tests/codegen/autodiffv.rs
@@ -0,0 +1,116 @@
+//@ compile-flags: -Zautodiff=Enable -C opt-level=3  -Clto=fat
+//@ no-prefer-dynamic
+//@ needs-enzyme
+//
+// In Enzyme, we test against a large range of LLVM versions (5+) and don't have overly many
+// breakages. One benefit is that we match the IR generated by Enzyme only after running it
+// through LLVM's O3 pipeline, which will remove most of the noise.
+// However, our integration test could also be affected by changes in how rustc lowers MIR into
+// LLVM-IR, which could cause additional noise and thus breakages. If that's the case, we should
+// reduce this test to only match the first lines and the ret instructions.
+
+#![feature(autodiff)]
+
+use std::autodiff::autodiff;
+
+#[autodiff(d_square3, Forward, Dual, DualOnly)]
+#[autodiff(d_square2, Forward, 4, Dual, DualOnly)]
+#[autodiff(d_square1, Forward, 4, Dual, Dual)]
+#[no_mangle]
+fn square(x: &f32) -> f32 {
+    x * x
+}
+
+// d_sqaure2
+// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
+// CHECK-NEXT: start:
+// CHECK-NEXT:   %0 = extractvalue [4 x ptr] %"x'", 0
+// CHECK-NEXT:   %"_2'ipl" = load float, ptr %0, align 4
+// CHECK-NEXT:   %1 = extractvalue [4 x ptr] %"x'", 1
+// CHECK-NEXT:   %"_2'ipl1" = load float, ptr %1, align 4
+// CHECK-NEXT:   %2 = extractvalue [4 x ptr] %"x'", 2
+// CHECK-NEXT:   %"_2'ipl2" = load float, ptr %2, align 4
+// CHECK-NEXT:   %3 = extractvalue [4 x ptr] %"x'", 3
+// CHECK-NEXT:   %"_2'ipl3" = load float, ptr %3, align 4
+// CHECK-NEXT:   %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
+// CHECK-NEXT:   %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
+// CHECK-NEXT:   %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
+// CHECK-NEXT:   %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
+// CHECK-NEXT:   %8 = fadd fast <4 x float> %7, %7
+// CHECK-NEXT:   %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
+// CHECK-NEXT:   %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT:   %11 = fmul fast <4 x float> %8, %10
+// CHECK-NEXT:   %12 = extractelement <4 x float> %11, i64 0
+// CHECK-NEXT:   %13 = insertvalue [4 x float] undef, float %12, 0
+// CHECK-NEXT:   %14 = extractelement <4 x float> %11, i64 1
+// CHECK-NEXT:   %15 = insertvalue [4 x float] %13, float %14, 1
+// CHECK-NEXT:   %16 = extractelement <4 x float> %11, i64 2
+// CHECK-NEXT:   %17 = insertvalue [4 x float] %15, float %16, 2
+// CHECK-NEXT:   %18 = extractelement <4 x float> %11, i64 3
+// CHECK-NEXT:   %19 = insertvalue [4 x float] %17, float %18, 3
+// CHECK-NEXT:   ret [4 x float] %19
+// CHECK-NEXT: }
+
+// d_square3, the extra float is the original return value (x * x)
+// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'")
+// CHECK-NEXT: start:
+// CHECK-NEXT:   %0 = extractvalue [4 x ptr] %"x'", 0
+// CHECK-NEXT:   %"_2'ipl" = load float, ptr %0, align 4
+// CHECK-NEXT:   %1 = extractvalue [4 x ptr] %"x'", 1
+// CHECK-NEXT:   %"_2'ipl1" = load float, ptr %1, align 4
+// CHECK-NEXT:   %2 = extractvalue [4 x ptr] %"x'", 2
+// CHECK-NEXT:   %"_2'ipl2" = load float, ptr %2, align 4
+// CHECK-NEXT:   %3 = extractvalue [4 x ptr] %"x'", 3
+// CHECK-NEXT:   %"_2'ipl3" = load float, ptr %3, align 4
+// CHECK-NEXT:   %_0 = fmul float %x.0.val, %x.0.val
+// CHECK-NEXT:   %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
+// CHECK-NEXT:   %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
+// CHECK-NEXT:   %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
+// CHECK-NEXT:   %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
+// CHECK-NEXT:   %8 = fadd fast <4 x float> %7, %7
+// CHECK-NEXT:   %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
+// CHECK-NEXT:   %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT:   %11 = fmul fast <4 x float> %8, %10
+// CHECK-NEXT:   %12 = extractelement <4 x float> %11, i64 0
+// CHECK-NEXT:   %13 = insertvalue [4 x float] undef, float %12, 0
+// CHECK-NEXT:   %14 = extractelement <4 x float> %11, i64 1
+// CHECK-NEXT:   %15 = insertvalue [4 x float] %13, float %14, 1
+// CHECK-NEXT:   %16 = extractelement <4 x float> %11, i64 2
+// CHECK-NEXT:   %17 = insertvalue [4 x float] %15, float %16, 2
+// CHECK-NEXT:   %18 = extractelement <4 x float> %11, i64 3
+// CHECK-NEXT:   %19 = insertvalue [4 x float] %17, float %18, 3
+// CHECK-NEXT:   %20 = insertvalue { float, [4 x float] } undef, float %_0, 0
+// CHECK-NEXT:   %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1
+// CHECK-NEXT:   ret { float, [4 x float] } %21
+// CHECK-NEXT: }
+
+fn main() {
+    let x = std::hint::black_box(3.0);
+    let output = square(&x);
+    dbg!(&output);
+    assert_eq!(9.0, output);
+    dbg!(square(&x));
+
+    let mut df_dx1 = 1.0;
+    let mut df_dx2 = 2.0;
+    let mut df_dx3 = 3.0;
+    let mut df_dx4 = 0.0;
+    let [o1, o2, o3, o4] = d_square2(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
+    dbg!(o1, o2, o3, o4);
+    let [output2, o1, o2, o3, o4] =
+        d_square1(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
+    dbg!(o1, o2, o3, o4);
+    assert_eq!(output, output2);
+    assert!((6.0 - o1).abs() < 1e-10);
+    assert!((12.0 - o2).abs() < 1e-10);
+    assert!((18.0 - o3).abs() < 1e-10);
+    assert!((0.0 - o4).abs() < 1e-10);
+    assert_eq!(1.0, df_dx1);
+    assert_eq!(2.0, df_dx2);
+    assert_eq!(3.0, df_dx3);
+    assert_eq!(0.0, df_dx4);
+    assert_eq!(d_square3(&x, &mut df_dx1), 2.0 * o1);
+    assert_eq!(d_square3(&x, &mut df_dx2), 2.0 * o2);
+    assert_eq!(d_square3(&x, &mut df_dx3), 2.0 * o3);
+    assert_eq!(d_square3(&x, &mut df_dx4), 2.0 * o4);
+}
diff --git a/tests/pretty/autodiff_forward.pp b/tests/pretty/autodiff_forward.pp
index dc7a2712f42..4b2fb6166ff 100644
--- a/tests/pretty/autodiff_forward.pp
+++ b/tests/pretty/autodiff_forward.pp
@@ -25,27 +25,31 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
 
     // We want to be sure that the same function can be differentiated in different ways
 
+
+    // Make sure, that we add the None for the default return.
+
+
     ::core::panicking::panic("not implemented")
 }
-#[rustc_autodiff(Forward, Dual, Const, Dual,)]
+#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
 #[inline(never)]
-pub fn df1(x: &[f64], bx: &[f64], y: f64) -> (f64, f64) {
+pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f1(x, y));
-    ::core::hint::black_box((bx,));
-    ::core::hint::black_box((f1(x, y), f64::default()))
+    ::core::hint::black_box((bx_0,));
+    ::core::hint::black_box(<(f64, f64)>::default())
 }
 #[rustc_autodiff]
 #[inline(never)]
 pub fn f2(x: &[f64], y: f64) -> f64 {
     ::core::panicking::panic("not implemented")
 }
-#[rustc_autodiff(Forward, Dual, Const, Const,)]
+#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
 #[inline(never)]
-pub fn df2(x: &[f64], bx: &[f64], y: f64) -> f64 {
+pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f2(x, y));
-    ::core::hint::black_box((bx,));
+    ::core::hint::black_box((bx_0,));
     ::core::hint::black_box(f2(x, y))
 }
 #[rustc_autodiff]
@@ -53,20 +57,20 @@ pub fn df2(x: &[f64], bx: &[f64], y: f64) -> f64 {
 pub fn f3(x: &[f64], y: f64) -> f64 {
     ::core::panicking::panic("not implemented")
 }
-#[rustc_autodiff(Forward, Dual, Const, Const,)]
+#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
 #[inline(never)]
-pub fn df3(x: &[f64], bx: &[f64], y: f64) -> f64 {
+pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f3(x, y));
-    ::core::hint::black_box((bx,));
+    ::core::hint::black_box((bx_0,));
     ::core::hint::black_box(f3(x, y))
 }
 #[rustc_autodiff]
 #[inline(never)]
 pub fn f4() {}
-#[rustc_autodiff(Forward, None)]
+#[rustc_autodiff(Forward, 1, None)]
 #[inline(never)]
-pub fn df4() {
+pub fn df4() -> () {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f4());
     ::core::hint::black_box(());
@@ -76,28 +80,82 @@ pub fn df4() {
 pub fn f5(x: &[f64], y: f64) -> f64 {
     ::core::panicking::panic("not implemented")
 }
-#[rustc_autodiff(Forward, Const, Dual, Const,)]
+#[rustc_autodiff(Forward, 1, Const, Dual, Const)]
 #[inline(never)]
-pub fn df5_y(x: &[f64], y: f64, by: f64) -> f64 {
+pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f5(x, y));
-    ::core::hint::black_box((by,));
+    ::core::hint::black_box((by_0,));
     ::core::hint::black_box(f5(x, y))
 }
-#[rustc_autodiff(Forward, Dual, Const, Const,)]
+#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
 #[inline(never)]
-pub fn df5_x(x: &[f64], bx: &[f64], y: f64) -> f64 {
+pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f5(x, y));
-    ::core::hint::black_box((bx,));
+    ::core::hint::black_box((bx_0,));
     ::core::hint::black_box(f5(x, y))
 }
-#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
+#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
 #[inline(never)]
-pub fn df5_rev(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
+pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f5(x, y));
-    ::core::hint::black_box((dx, dret));
+    ::core::hint::black_box((dx_0, dret));
     ::core::hint::black_box(f5(x, y))
 }
+struct DoesNotImplDefault;
+#[rustc_autodiff]
+#[inline(never)]
+pub fn f6() -> DoesNotImplDefault {
+    ::core::panicking::panic("not implemented")
+}
+#[rustc_autodiff(Forward, 1, Const)]
+#[inline(never)]
+pub fn df6() -> DoesNotImplDefault {
+    unsafe { asm!("NOP", options(pure, nomem)); };
+    ::core::hint::black_box(f6());
+    ::core::hint::black_box(());
+    ::core::hint::black_box(f6())
+}
+#[rustc_autodiff]
+#[inline(never)]
+pub fn f7(x: f32) -> () {}
+#[rustc_autodiff(Forward, 1, Const, None)]
+#[inline(never)]
+pub fn df7(x: f32) -> () {
+    unsafe { asm!("NOP", options(pure, nomem)); };
+    ::core::hint::black_box(f7(x));
+    ::core::hint::black_box(());
+}
+#[no_mangle]
+#[rustc_autodiff]
+#[inline(never)]
+fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") }
+#[rustc_autodiff(Forward, 4, Dual, Dual)]
+#[inline(never)]
+fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
+    -> [f32; 5usize] {
+    unsafe { asm!("NOP", options(pure, nomem)); };
+    ::core::hint::black_box(f8(x));
+    ::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
+    ::core::hint::black_box(<[f32; 5usize]>::default())
+}
+#[rustc_autodiff(Forward, 4, Dual, DualOnly)]
+#[inline(never)]
+fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
+    -> [f32; 4usize] {
+    unsafe { asm!("NOP", options(pure, nomem)); };
+    ::core::hint::black_box(f8(x));
+    ::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
+    ::core::hint::black_box(<[f32; 4usize]>::default())
+}
+#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
+#[inline(never)]
+fn f8_1(x: &f32, bx_0: &f32) -> f32 {
+    unsafe { asm!("NOP", options(pure, nomem)); };
+    ::core::hint::black_box(f8(x));
+    ::core::hint::black_box((bx_0,));
+    ::core::hint::black_box(<f32>::default())
+}
 fn main() {}
diff --git a/tests/pretty/autodiff_forward.rs b/tests/pretty/autodiff_forward.rs
index bc558211632..a765738c2a8 100644
--- a/tests/pretty/autodiff_forward.rs
+++ b/tests/pretty/autodiff_forward.rs
@@ -36,4 +36,22 @@ pub fn f5(x: &[f64], y: f64) -> f64 {
     unimplemented!()
 }
 
+struct DoesNotImplDefault;
+#[autodiff(df6, Forward, Const)]
+pub fn f6() -> DoesNotImplDefault {
+    unimplemented!()
+}
+
+// Make sure, that we add the None for the default return.
+#[autodiff(df7, Forward, Const)]
+pub fn f7(x: f32) -> () {}
+
+#[autodiff(f8_1, Forward, Dual, DualOnly)]
+#[autodiff(f8_2, Forward, 4, Dual, DualOnly)]
+#[autodiff(f8_3, Forward, 4, Dual, Dual)]
+#[no_mangle]
+fn f8(x: &f32) -> f32 {
+    unimplemented!()
+}
+
 fn main() {}
diff --git a/tests/pretty/autodiff_reverse.pp b/tests/pretty/autodiff_reverse.pp
index b2cf0244af4..31920694a3a 100644
--- a/tests/pretty/autodiff_reverse.pp
+++ b/tests/pretty/autodiff_reverse.pp
@@ -28,18 +28,18 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
 
     ::core::panicking::panic("not implemented")
 }
-#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
+#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
 #[inline(never)]
-pub fn df1(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
+pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f1(x, y));
-    ::core::hint::black_box((dx, dret));
+    ::core::hint::black_box((dx_0, dret));
     ::core::hint::black_box(f1(x, y))
 }
 #[rustc_autodiff]
 #[inline(never)]
 pub fn f2() {}
-#[rustc_autodiff(Reverse, None)]
+#[rustc_autodiff(Reverse, 1, None)]
 #[inline(never)]
 pub fn df2() {
     unsafe { asm!("NOP", options(pure, nomem)); };
@@ -51,12 +51,12 @@ pub fn df2() {
 pub fn f3(x: &[f64], y: f64) -> f64 {
     ::core::panicking::panic("not implemented")
 }
-#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
+#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
 #[inline(never)]
-pub fn df3(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
+pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f3(x, y));
-    ::core::hint::black_box((dx, dret));
+    ::core::hint::black_box((dx_0, dret));
     ::core::hint::black_box(f3(x, y))
 }
 enum Foo { Reverse, }
@@ -64,7 +64,7 @@ use Foo::Reverse;
 #[rustc_autodiff]
 #[inline(never)]
 pub fn f4(x: f32) { ::core::panicking::panic("not implemented") }
-#[rustc_autodiff(Reverse, Const, None)]
+#[rustc_autodiff(Reverse, 1, Const, None)]
 #[inline(never)]
 pub fn df4(x: f32) {
     unsafe { asm!("NOP", options(pure, nomem)); };
@@ -76,11 +76,11 @@ pub fn df4(x: f32) {
 pub fn f5(x: *const f32, y: &f32) {
     ::core::panicking::panic("not implemented")
 }
-#[rustc_autodiff(Reverse, DuplicatedOnly, Duplicated, None)]
+#[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)]
 #[inline(never)]
-pub unsafe fn df5(x: *const f32, dx: *mut f32, y: &f32, dy: &mut f32) {
+pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f5(x, y));
-    ::core::hint::black_box((dx, dy));
+    ::core::hint::black_box((dx_0, dy_0));
 }
 fn main() {}
diff --git a/tests/ui/autodiff/autodiff_illegal.rs b/tests/ui/autodiff/autodiff_illegal.rs
index e810b9ba565..2f2cd8d9353 100644
--- a/tests/ui/autodiff/autodiff_illegal.rs
+++ b/tests/ui/autodiff/autodiff_illegal.rs
@@ -177,4 +177,11 @@ fn f21(x: f32) -> f32 {
     unimplemented!()
 }
 
+struct DoesNotImplDefault;
+#[autodiff(df22, Forward, Dual)]
+pub fn f22() -> DoesNotImplDefault {
+    //~^^ ERROR the function or associated item `default` exists for tuple `(DoesNotImplDefault, DoesNotImplDefault)`, but its trait bounds were not satisfied
+    unimplemented!()
+}
+
 fn main() {}
diff --git a/tests/ui/autodiff/autodiff_illegal.stderr b/tests/ui/autodiff/autodiff_illegal.stderr
index 47d53492700..3752b27e7dd 100644
--- a/tests/ui/autodiff/autodiff_illegal.stderr
+++ b/tests/ui/autodiff/autodiff_illegal.stderr
@@ -19,32 +19,24 @@ error: expected 1 activities, but found 2
    |
 LL | #[autodiff(df3, Reverse, Duplicated, Const)]
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-   |
-   = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
 
 error: expected 1 activities, but found 0
   --> $DIR/autodiff_illegal.rs:27:1
    |
 LL | #[autodiff(df4, Reverse)]
    | ^^^^^^^^^^^^^^^^^^^^^^^^^
-   |
-   = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
 
 error: Dual can not be used in Reverse Mode
   --> $DIR/autodiff_illegal.rs:34:1
    |
 LL | #[autodiff(df5, Reverse, Dual)]
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-   |
-   = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
 
 error: Duplicated can not be used in Forward Mode
   --> $DIR/autodiff_illegal.rs:41:1
    |
 LL | #[autodiff(df6, Forward, Duplicated)]
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-   |
-   = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
 
 error: Duplicated can not be used for this type
   --> $DIR/autodiff_illegal.rs:42:14
@@ -107,7 +99,6 @@ LL | #[autodiff(fn_exists, Reverse, Active)]
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `fn_exists` redefined here
    |
    = note: `fn_exists` must be defined only once in the value namespace of this module
-   = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
 
 error: autodiff requires at least a name and mode
   --> $DIR/autodiff_illegal.rs:95:1
@@ -135,42 +126,49 @@ error: invalid return activity Active in Forward Mode
    |
 LL | #[autodiff(df19, Forward, Dual, Active)]
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-   |
-   = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
 
 error: invalid return activity Dual in Reverse Mode
   --> $DIR/autodiff_illegal.rs:167:1
    |
 LL | #[autodiff(df20, Reverse, Active, Dual)]
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-   |
-   = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
 
 error: invalid return activity Duplicated in Reverse Mode
   --> $DIR/autodiff_illegal.rs:174:1
    |
 LL | #[autodiff(df21, Reverse, Active, Duplicated)]
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-   |
-   = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
 
 error[E0433]: failed to resolve: use of undeclared type `MyFloat`
   --> $DIR/autodiff_illegal.rs:130:1
    |
 LL | #[autodiff(df15, Reverse, Active, Active)]
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ use of undeclared type `MyFloat`
-   |
-   = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
 
 error[E0433]: failed to resolve: use of undeclared type `F64Trans`
   --> $DIR/autodiff_illegal.rs:154:1
    |
 LL | #[autodiff(df18, Reverse, Active, Active)]
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ use of undeclared type `F64Trans`
+
+error[E0599]: the function or associated item `default` exists for tuple `(DoesNotImplDefault, DoesNotImplDefault)`, but its trait bounds were not satisfied
+  --> $DIR/autodiff_illegal.rs:181:1
+   |
+LL | struct DoesNotImplDefault;
+   | ------------------------- doesn't satisfy `DoesNotImplDefault: Default`
+LL | #[autodiff(df22, Forward, Dual)]
+   | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ function or associated item cannot be called on `(DoesNotImplDefault, DoesNotImplDefault)` due to unsatisfied trait bounds
+   |
+   = note: the following trait bounds were not satisfied:
+           `DoesNotImplDefault: Default`
+           which is required by `(DoesNotImplDefault, DoesNotImplDefault): Default`
+help: consider annotating `DoesNotImplDefault` with `#[derive(Default)]`
+   |
+LL + #[derive(Default)]
+LL | struct DoesNotImplDefault;
    |
-   = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
 
-error: aborting due to 22 previous errors
+error: aborting due to 23 previous errors
 
-Some errors have detailed explanations: E0428, E0433, E0658.
+Some errors have detailed explanations: E0428, E0433, E0599, E0658.
 For more information about an error, try `rustc --explain E0428`.