about summary refs log tree commit diff
diff options
context:
space:
mode:
authorManuel Drehwald <git@manuel.drehwald.info>2024-10-11 19:13:31 +0200
committerManuel Drehwald <git@manuel.drehwald.info>2024-10-11 19:13:31 +0200
commit624c071b997f3bab0ef086af24f59ebc95f37c59 (patch)
tree44091cd44e62db9aa4d0f0b6053de7b84769a5cb
parent52fd9983996d9fcfb719749838336be66dee68f9 (diff)
downloadrust-624c071b997f3bab0ef086af24f59ebc95f37c59.tar.gz
rust-624c071b997f3bab0ef086af24f59ebc95f37c59.zip
Single commit implementing the enzyme/autodiff frontend
Co-authored-by: Lorenz Schmidt <bytesnake@mailbox.org>
-rw-r--r--compiler/rustc_ast/src/expand/autodiff_attrs.rs283
-rw-r--r--compiler/rustc_ast/src/expand/mod.rs2
-rw-r--r--compiler/rustc_ast/src/expand/typetree.rs90
-rw-r--r--compiler/rustc_builtin_macros/Cargo.toml4
-rw-r--r--compiler/rustc_builtin_macros/messages.ftl9
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs820
-rw-r--r--compiler/rustc_builtin_macros/src/errors.rs72
-rw-r--r--compiler/rustc_builtin_macros/src/lib.rs3
-rw-r--r--compiler/rustc_expand/src/build.rs29
-rw-r--r--compiler/rustc_feature/src/builtin_attrs.rs5
-rw-r--r--compiler/rustc_passes/messages.ftl4
-rw-r--r--compiler/rustc_passes/src/check_attr.rs15
-rw-r--r--compiler/rustc_passes/src/errors.rs8
-rw-r--r--compiler/rustc_span/src/symbol.rs5
-rw-r--r--library/core/src/lib.rs9
-rw-r--r--library/core/src/macros/mod.rs18
-rw-r--r--library/std/src/lib.rs9
17 files changed, 1384 insertions, 1 deletions
diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs
new file mode 100644
index 00000000000..05714731b9d
--- /dev/null
+++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs
@@ -0,0 +1,283 @@
+//! This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,
+//! we create an [`AutoDiffItem`] which contains the source and target function names. The source
+//! is the function to which the autodiff attribute is applied, and the target is the function
+//! getting generated by us (with a name given by the user as the first autodiff arg).
+
+use std::fmt::{self, Display, Formatter};
+use std::str::FromStr;
+
+use crate::expand::typetree::TypeTree;
+use crate::expand::{Decodable, Encodable, HashStable_Generic};
+use crate::ptr::P;
+use crate::{Ty, TyKind};
+
+/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
+/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
+/// are a hack to support higher order derivatives. We need to compute first order derivatives
+/// before we compute second order derivatives, otherwise we would differentiate our placeholder
+/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
+/// as it's already done in the C++ and Julia frontend of Enzyme.
+///
+/// (FIXME) remove *First variants.
+/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
+/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
+#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
+pub enum DiffMode {
+    /// No autodiff is applied (used during error handling).
+    Error,
+    /// The primal function which we will differentiate.
+    Source,
+    /// The target function, to be created using forward mode AD.
+    Forward,
+    /// The target function, to be created using reverse mode AD.
+    Reverse,
+    /// The target function, to be created using forward mode AD.
+    /// This target function will also be used as a source for higher order derivatives,
+    /// so compute it before all Forward/Reverse targets and optimize it through llvm.
+    ForwardFirst,
+    /// The target function, to be created using reverse mode AD.
+    /// This target function will also be used as a source for higher order derivatives,
+    /// so compute it before all Forward/Reverse targets and optimize it through llvm.
+    ReverseFirst,
+}
+
+/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
+/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
+/// we add to the previous shadow value. To not surprise users, we picked different names.
+/// Dual numbers is also a quite well known name for forward mode AD types.
+#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
+pub enum DiffActivity {
+    /// Implicit or Explicit () return type, so a special case of Const.
+    None,
+    /// Don't compute derivatives with respect to this input/output.
+    Const,
+    /// Reverse Mode, Compute derivatives for this scalar input/output.
+    Active,
+    /// Reverse Mode, Compute derivatives for this scalar output, but don't compute
+    /// the original return value.
+    ActiveOnly,
+    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
+    /// with it.
+    Dual,
+    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
+    /// with it. Drop the code which updates the original input/output for maximum performance.
+    DualOnly,
+    /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
+    Duplicated,
+    /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
+    /// Drop the code which updates the original input for maximum performance.
+    DuplicatedOnly,
+    /// All Integers must be Const, but these are used to mark the integer which represents the
+    /// length of a slice/vec. This is used for safety checks on slices.
+    FakeActivitySize,
+}
+/// We generate one of these structs for each `#[autodiff(...)]` attribute.
+#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
+pub struct AutoDiffItem {
+    /// The name of the function getting differentiated
+    pub source: String,
+    /// The name of the function being generated
+    pub target: String,
+    pub attrs: AutoDiffAttrs,
+    /// Describe the memory layout of input types
+    pub inputs: Vec<TypeTree>,
+    /// Describe the memory layout of the output type
+    pub output: TypeTree,
+}
+#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
+pub struct AutoDiffAttrs {
+    /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
+    /// e.g. in the [JAX
+    /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
+    pub mode: DiffMode,
+    pub ret_activity: DiffActivity,
+    pub input_activity: Vec<DiffActivity>,
+}
+
+impl DiffMode {
+    pub fn is_rev(&self) -> bool {
+        matches!(self, DiffMode::Reverse | DiffMode::ReverseFirst)
+    }
+    pub fn is_fwd(&self) -> bool {
+        matches!(self, DiffMode::Forward | DiffMode::ForwardFirst)
+    }
+}
+
+impl Display for DiffMode {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        match self {
+            DiffMode::Error => write!(f, "Error"),
+            DiffMode::Source => write!(f, "Source"),
+            DiffMode::Forward => write!(f, "Forward"),
+            DiffMode::Reverse => write!(f, "Reverse"),
+            DiffMode::ForwardFirst => write!(f, "ForwardFirst"),
+            DiffMode::ReverseFirst => write!(f, "ReverseFirst"),
+        }
+    }
+}
+
+/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
+/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
+/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
+/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
+/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
+pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
+    if activity == DiffActivity::None {
+        // Only valid if primal returns (), but we can't check that here.
+        return true;
+    }
+    match mode {
+        DiffMode::Error => false,
+        DiffMode::Source => false,
+        DiffMode::Forward | DiffMode::ForwardFirst => {
+            activity == DiffActivity::Dual
+                || activity == DiffActivity::DualOnly
+                || activity == DiffActivity::Const
+        }
+        DiffMode::Reverse | DiffMode::ReverseFirst => {
+            activity == DiffActivity::Const
+                || activity == DiffActivity::Active
+                || activity == DiffActivity::ActiveOnly
+        }
+    }
+}
+
+/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
+/// for the given argument, but we generally can't know the size of such a type.
+/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
+/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
+/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
+/// users here from marking scalars as Duplicated, due to type aliases.
+pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
+    use DiffActivity::*;
+    // It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
+    if matches!(activity, Const) {
+        return true;
+    }
+    if matches!(activity, Dual | DualOnly) {
+        return true;
+    }
+    // FIXME(ZuseZ4) We should make this more robust to also
+    // handle type aliases. Once that is done, we can be more restrictive here.
+    if matches!(activity, Active | ActiveOnly) {
+        return true;
+    }
+    matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
+        && matches!(activity, Duplicated | DuplicatedOnly)
+}
+pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
+    use DiffActivity::*;
+    return match mode {
+        DiffMode::Error => false,
+        DiffMode::Source => false,
+        DiffMode::Forward | DiffMode::ForwardFirst => {
+            matches!(activity, Dual | DualOnly | Const)
+        }
+        DiffMode::Reverse | DiffMode::ReverseFirst => {
+            matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
+        }
+    };
+}
+
+impl Display for DiffActivity {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            DiffActivity::None => write!(f, "None"),
+            DiffActivity::Const => write!(f, "Const"),
+            DiffActivity::Active => write!(f, "Active"),
+            DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
+            DiffActivity::Dual => write!(f, "Dual"),
+            DiffActivity::DualOnly => write!(f, "DualOnly"),
+            DiffActivity::Duplicated => write!(f, "Duplicated"),
+            DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
+            DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
+        }
+    }
+}
+
+impl FromStr for DiffMode {
+    type Err = ();
+
+    fn from_str(s: &str) -> Result<DiffMode, ()> {
+        match s {
+            "Error" => Ok(DiffMode::Error),
+            "Source" => Ok(DiffMode::Source),
+            "Forward" => Ok(DiffMode::Forward),
+            "Reverse" => Ok(DiffMode::Reverse),
+            "ForwardFirst" => Ok(DiffMode::ForwardFirst),
+            "ReverseFirst" => Ok(DiffMode::ReverseFirst),
+            _ => Err(()),
+        }
+    }
+}
+impl FromStr for DiffActivity {
+    type Err = ();
+
+    fn from_str(s: &str) -> Result<DiffActivity, ()> {
+        match s {
+            "None" => Ok(DiffActivity::None),
+            "Active" => Ok(DiffActivity::Active),
+            "ActiveOnly" => Ok(DiffActivity::ActiveOnly),
+            "Const" => Ok(DiffActivity::Const),
+            "Dual" => Ok(DiffActivity::Dual),
+            "DualOnly" => Ok(DiffActivity::DualOnly),
+            "Duplicated" => Ok(DiffActivity::Duplicated),
+            "DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
+            _ => Err(()),
+        }
+    }
+}
+
+impl AutoDiffAttrs {
+    pub fn has_ret_activity(&self) -> bool {
+        self.ret_activity != DiffActivity::None
+    }
+    pub fn has_active_only_ret(&self) -> bool {
+        self.ret_activity == DiffActivity::ActiveOnly
+    }
+
+    pub fn error() -> Self {
+        AutoDiffAttrs {
+            mode: DiffMode::Error,
+            ret_activity: DiffActivity::None,
+            input_activity: Vec::new(),
+        }
+    }
+    pub fn source() -> Self {
+        AutoDiffAttrs {
+            mode: DiffMode::Source,
+            ret_activity: DiffActivity::None,
+            input_activity: Vec::new(),
+        }
+    }
+
+    pub fn is_active(&self) -> bool {
+        self.mode != DiffMode::Error
+    }
+
+    pub fn is_source(&self) -> bool {
+        self.mode == DiffMode::Source
+    }
+    pub fn apply_autodiff(&self) -> bool {
+        !matches!(self.mode, DiffMode::Error | DiffMode::Source)
+    }
+
+    pub fn into_item(
+        self,
+        source: String,
+        target: String,
+        inputs: Vec<TypeTree>,
+        output: TypeTree,
+    ) -> AutoDiffItem {
+        AutoDiffItem { source, target, inputs, output, attrs: self }
+    }
+}
+
+impl fmt::Display for AutoDiffItem {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "Differentiating {} -> {}", self.source, self.target)?;
+        write!(f, " with attributes: {:?}", self.attrs)?;
+        write!(f, " with inputs: {:?}", self.inputs)?;
+        write!(f, " with output: {:?}", self.output)
+    }
+}
diff --git a/compiler/rustc_ast/src/expand/mod.rs b/compiler/rustc_ast/src/expand/mod.rs
index 13413281bc7..d259677e98e 100644
--- a/compiler/rustc_ast/src/expand/mod.rs
+++ b/compiler/rustc_ast/src/expand/mod.rs
@@ -7,6 +7,8 @@ use rustc_span::symbol::Ident;
 use crate::MetaItem;
 
 pub mod allocator;
+pub mod autodiff_attrs;
+pub mod typetree;
 
 #[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)]
 pub struct StrippedCfgItem<ModId = DefId> {
diff --git a/compiler/rustc_ast/src/expand/typetree.rs b/compiler/rustc_ast/src/expand/typetree.rs
new file mode 100644
index 00000000000..9a2dd2e85e0
--- /dev/null
+++ b/compiler/rustc_ast/src/expand/typetree.rs
@@ -0,0 +1,90 @@
+//! This module contains the definition of the `TypeTree` and `Type` structs.
+//! They are thin Rust wrappers around the TypeTrees used by Enzyme as the LLVM based autodiff
+//! backend. The Enzyme TypeTrees currently have various limitations and should be rewritten, so the
+//! Rust frontend obviously has the same limitations. The main motivation of TypeTrees is to
+//! represent how a type looks like "in memory". Enzyme can deduce this based on usage patterns in
+//! the user code, but this is extremely slow and not even always sufficient. As such we lower some
+//! information from rustc to help Enzyme. For a full explanation of their design it is necessary to
+//! analyze the implementation in Enzyme core itself. As a rough summary, `-1` in Enzyme speech means
+//! everywhere. That is `{0:-1: Float}` means at index 0 you have a ptr, if you dereference it it
+//! will be floats everywhere. Thus `* f32`. If you have `{-1:int}` it means int's everywhere,
+//! e.g. [i32; N]. `{0:-1:-1 float}` then means one pointer at offset 0, if you dereference it there
+//! will be only pointers, if you dereference these new pointers they will point to array of floats.
+//! Generally, it allows byte-specific descriptions.
+//! FIXME: This description might be partly inaccurate and should be extended, along with
+//! adding documentation to the corresponding Enzyme core code.
+//! FIXME: Rewrite the TypeTree logic in Enzyme core to reduce the need for the rustc frontend to
+//! provide typetree information.
+//! FIXME: We should also re-evaluate where we create TypeTrees from Rust types, since MIR
+//! representations of some types might not be accurate. For example a vector of floats might be
+//! represented as a vector of u8s in MIR in some cases.
+
+use std::fmt;
+
+use crate::expand::{Decodable, Encodable, HashStable_Generic};
+
+#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
+pub enum Kind {
+    Anything,
+    Integer,
+    Pointer,
+    Half,
+    Float,
+    Double,
+    Unknown,
+}
+
+#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
+pub struct TypeTree(pub Vec<Type>);
+
+impl TypeTree {
+    pub fn new() -> Self {
+        Self(Vec::new())
+    }
+    pub fn all_ints() -> Self {
+        Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }])
+    }
+    pub fn int(size: usize) -> Self {
+        let mut ints = Vec::with_capacity(size);
+        for i in 0..size {
+            ints.push(Type {
+                offset: i as isize,
+                size: 1,
+                kind: Kind::Integer,
+                child: TypeTree::new(),
+            });
+        }
+        Self(ints)
+    }
+}
+
+#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
+pub struct FncTree {
+    pub args: Vec<TypeTree>,
+    pub ret: TypeTree,
+}
+
+#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
+pub struct Type {
+    pub offset: isize,
+    pub size: usize,
+    pub kind: Kind,
+    pub child: TypeTree,
+}
+
+impl Type {
+    pub fn add_offset(self, add: isize) -> Self {
+        let offset = match self.offset {
+            -1 => add,
+            x => add + x,
+        };
+
+        Self { size: self.size, kind: self.kind, child: self.child, offset }
+    }
+}
+
+impl fmt::Display for Type {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        <Self as fmt::Debug>::fmt(self, f)
+    }
+}
diff --git a/compiler/rustc_builtin_macros/Cargo.toml b/compiler/rustc_builtin_macros/Cargo.toml
index 21b87be4b81..ef48486f6f1 100644
--- a/compiler/rustc_builtin_macros/Cargo.toml
+++ b/compiler/rustc_builtin_macros/Cargo.toml
@@ -3,6 +3,10 @@ name = "rustc_builtin_macros"
 version = "0.0.0"
 edition = "2021"
 
+
+[lints.rust]
+unexpected_cfgs = { level = "warn", check-cfg = ['cfg(llvm_enzyme)'] }
+
 [lib]
 doctest = false
 
diff --git a/compiler/rustc_builtin_macros/messages.ftl b/compiler/rustc_builtin_macros/messages.ftl
index 77cb8dc63c4..6ebc2fd870c 100644
--- a/compiler/rustc_builtin_macros/messages.ftl
+++ b/compiler/rustc_builtin_macros/messages.ftl
@@ -69,6 +69,15 @@ builtin_macros_assert_requires_boolean = macro requires a boolean expression as
 builtin_macros_assert_requires_expression = macro requires an expression as an argument
     .suggestion = try removing semicolon
 
+builtin_macros_autodiff = autodiff must be applied to function
+builtin_macros_autodiff_missing_config = autodiff requires at least a name and mode
+builtin_macros_autodiff_mode = unknown Mode: `{$mode}`. Use `Forward` or `Reverse`
+builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode
+builtin_macros_autodiff_not_build = this rustc version does not support autodiff
+builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found}
+builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
+
+builtin_macros_autodiff_unknown_activity = did not recognize Activity: `{$act}`
 builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s
     .label = not applicable here
     .label2 = not a `struct`, `enum` or `union`
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
new file mode 100644
index 00000000000..66bb11ca522
--- /dev/null
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -0,0 +1,820 @@
+//! This module contains the implementation of the `#[autodiff]` attribute.
+//! Currently our linter isn't smart enough to see that each import is used in one of the two
+//! configs (autodiff enabled or disabled), so we have to add cfg's to each import.
+//! FIXME(ZuseZ4): Remove this once we have a smarter linter.
+
+#[cfg(llvm_enzyme)]
+mod llvm_enzyme {
+    use std::str::FromStr;
+    use std::string::String;
+
+    use rustc_ast::expand::autodiff_attrs::{
+        AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ty_for_activity,
+    };
+    use rustc_ast::ptr::P;
+    use rustc_ast::token::{Token, TokenKind};
+    use rustc_ast::tokenstream::*;
+    use rustc_ast::visit::AssocCtxt::*;
+    use rustc_ast::{
+        self as ast, AssocItemKind, BindingMode, FnRetTy, FnSig, Generics, ItemKind, MetaItemInner,
+        PatKind, TyKind,
+    };
+    use rustc_expand::base::{Annotatable, ExtCtxt};
+    use rustc_span::symbol::{Ident, kw, sym};
+    use rustc_span::{Span, Symbol};
+    use thin_vec::{ThinVec, thin_vec};
+    use tracing::{debug, trace};
+
+    use crate::errors;
+
+    // If we have a default `()` return type or explicitley `()` return type,
+    // then we often can skip doing some work.
+    fn has_ret(ty: &FnRetTy) -> bool {
+        match ty {
+            FnRetTy::Ty(ty) => !ty.kind.is_unit(),
+            FnRetTy::Default(_) => false,
+        }
+    }
+    fn first_ident(x: &MetaItemInner) -> rustc_span::symbol::Ident {
+        let segments = &x.meta_item().unwrap().path.segments;
+        assert!(segments.len() == 1);
+        segments[0].ident
+    }
+
+    fn name(x: &MetaItemInner) -> String {
+        first_ident(x).name.to_string()
+    }
+
+    pub(crate) fn from_ast(
+        ecx: &mut ExtCtxt<'_>,
+        meta_item: &ThinVec<MetaItemInner>,
+        has_ret: bool,
+    ) -> AutoDiffAttrs {
+        let dcx = ecx.sess.dcx();
+        let mode = name(&meta_item[1]);
+        let Ok(mode) = DiffMode::from_str(&mode) else {
+            dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
+            return AutoDiffAttrs::error();
+        };
+        let mut activities: Vec<DiffActivity> = vec![];
+        let mut errors = false;
+        for x in &meta_item[2..] {
+            let activity_str = name(&x);
+            let res = DiffActivity::from_str(&activity_str);
+            match res {
+                Ok(x) => activities.push(x),
+                Err(_) => {
+                    dcx.emit_err(errors::AutoDiffUnknownActivity {
+                        span: x.span(),
+                        act: activity_str,
+                    });
+                    errors = true;
+                }
+            };
+        }
+        if errors {
+            return AutoDiffAttrs::error();
+        }
+
+        // If a return type exist, we need to split the last activity,
+        // otherwise we return None as placeholder.
+        let (ret_activity, input_activity) = if has_ret {
+            let Some((last, rest)) = activities.split_last() else {
+                unreachable!(
+                    "should not be reachable because we counted the number of activities previously"
+                );
+            };
+            (last, rest)
+        } else {
+            (&DiffActivity::None, activities.as_slice())
+        };
+
+        AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() }
+    }
+
+    /// We expand the autodiff macro to generate a new placeholder function which passes
+    /// type-checking and can be called by users. The function body of the placeholder function will
+    /// later be replaced on LLVM-IR level, so the design of the body is less important and for now
+    /// should just prevent early inlining and optimizations which alter the function signature.
+    /// The exact signature of the generated function depends on the configuration provided by the
+    /// user, but here is an example:
+    ///
+    /// ```
+    /// #[autodiff(cos_box, Reverse, Duplicated, Active)]
+    /// fn sin(x: &Box<f32>) -> f32 {
+    ///     f32::sin(**x)
+    /// }
+    /// ```
+    /// which becomes expanded to:
+    /// ```
+    /// #[rustc_autodiff]
+    /// #[inline(never)]
+    /// fn sin(x: &Box<f32>) -> f32 {
+    ///     f32::sin(**x)
+    /// }
+    /// #[rustc_autodiff(Reverse, Duplicated, Active)]
+    /// #[inline(never)]
+    /// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
+    ///     unsafe {
+    ///         asm!("NOP");
+    ///     };
+    ///     ::core::hint::black_box(sin(x));
+    ///     ::core::hint::black_box((dx, dret));
+    ///     ::core::hint::black_box(sin(x))
+    /// }
+    /// ```
+    /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
+    /// in CI.
+    pub(crate) fn expand(
+        ecx: &mut ExtCtxt<'_>,
+        expand_span: Span,
+        meta_item: &ast::MetaItem,
+        mut item: Annotatable,
+    ) -> Vec<Annotatable> {
+        let dcx = ecx.sess.dcx();
+        // first get the annotable item:
+        let (sig, is_impl): (FnSig, bool) = match &item {
+            Annotatable::Item(ref iitem) => {
+                let sig = match &iitem.kind {
+                    ItemKind::Fn(box ast::Fn { sig, .. }) => sig,
+                    _ => {
+                        dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
+                        return vec![item];
+                    }
+                };
+                (sig.clone(), false)
+            }
+            Annotatable::AssocItem(ref assoc_item, _) => {
+                let sig = match &assoc_item.kind {
+                    ast::AssocItemKind::Fn(box ast::Fn { sig, .. }) => sig,
+                    _ => {
+                        dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
+                        return vec![item];
+                    }
+                };
+                (sig.clone(), true)
+            }
+            _ => {
+                dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
+                return vec![item];
+            }
+        };
+
+        let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
+            ast::MetaItemKind::List(ref vec) => vec.clone(),
+            _ => {
+                dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
+                return vec![item];
+            }
+        };
+
+        let has_ret = has_ret(&sig.decl.output);
+        let sig_span = ecx.with_call_site_ctxt(sig.span);
+
+        let (vis, primal) = match &item {
+            Annotatable::Item(ref iitem) => (iitem.vis.clone(), iitem.ident.clone()),
+            Annotatable::AssocItem(ref assoc_item, _) => {
+                (assoc_item.vis.clone(), assoc_item.ident.clone())
+            }
+            _ => {
+                dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
+                return vec![item];
+            }
+        };
+
+        // create TokenStream from vec elemtents:
+        // meta_item doesn't have a .tokens field
+        let comma: Token = Token::new(TokenKind::Comma, Span::default());
+        let mut ts: Vec<TokenTree> = vec![];
+        if meta_item_vec.len() < 2 {
+            // At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
+            // input and output args.
+            dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
+            return vec![item];
+        } else {
+            for t in meta_item_vec.clone()[1..].iter() {
+                let val = first_ident(t);
+                let t = Token::from_ast_ident(val);
+                ts.push(TokenTree::Token(t, Spacing::Joint));
+                ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
+            }
+        }
+        if !has_ret {
+            // We don't want users to provide a return activity if the function doesn't return anything.
+            // For simplicity, we just add a dummy token to the end of the list.
+            let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
+            ts.push(TokenTree::Token(t, Spacing::Joint));
+        }
+        let ts: TokenStream = TokenStream::from_iter(ts);
+
+        let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
+        if !x.is_active() {
+            // We encountered an error, so we return the original item.
+            // This allows us to potentially parse other attributes.
+            return vec![item];
+        }
+        let span = ecx.with_def_site_ctxt(expand_span);
+
+        let n_active: u32 = x
+            .input_activity
+            .iter()
+            .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
+            .count() as u32;
+        let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
+        let new_decl_span = d_sig.span;
+        let d_body = gen_enzyme_body(
+            ecx,
+            &x,
+            n_active,
+            &sig,
+            &d_sig,
+            primal,
+            &new_args,
+            span,
+            sig_span,
+            new_decl_span,
+            idents,
+            errored,
+        );
+        let d_ident = first_ident(&meta_item_vec[0]);
+
+        // The first element of it is the name of the function to be generated
+        let asdf = Box::new(ast::Fn {
+            defaultness: ast::Defaultness::Final,
+            sig: d_sig,
+            generics: Generics::default(),
+            body: Some(d_body),
+        });
+        let mut rustc_ad_attr =
+            P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
+
+        let ts2: Vec<TokenTree> = vec![TokenTree::Token(
+            Token::new(TokenKind::Ident(sym::never, false.into()), span),
+            Spacing::Joint,
+        )];
+        let never_arg = ast::DelimArgs {
+            dspan: ast::tokenstream::DelimSpan::from_single(span),
+            delim: ast::token::Delimiter::Parenthesis,
+            tokens: ast::tokenstream::TokenStream::from_iter(ts2),
+        };
+        let inline_item = ast::AttrItem {
+            unsafety: ast::Safety::Default,
+            path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
+            args: ast::AttrArgs::Delimited(never_arg),
+            tokens: None,
+        };
+        let inline_never_attr = P(ast::NormalAttr { item: inline_item, tokens: None });
+        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
+        let attr: ast::Attribute = ast::Attribute {
+            kind: ast::AttrKind::Normal(rustc_ad_attr.clone()),
+            id: new_id,
+            style: ast::AttrStyle::Outer,
+            span,
+        };
+        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
+        let inline_never: ast::Attribute = ast::Attribute {
+            kind: ast::AttrKind::Normal(inline_never_attr),
+            id: new_id,
+            style: ast::AttrStyle::Outer,
+            span,
+        };
+
+        // Don't add it multiple times:
+        let orig_annotatable: Annotatable = match item {
+            Annotatable::Item(ref mut iitem) => {
+                if !iitem.attrs.iter().any(|a| a.id == attr.id) {
+                    iitem.attrs.push(attr.clone());
+                }
+                if !iitem.attrs.iter().any(|a| a.id == inline_never.id) {
+                    iitem.attrs.push(inline_never.clone());
+                }
+                Annotatable::Item(iitem.clone())
+            }
+            Annotatable::AssocItem(ref mut assoc_item, i @ Impl) => {
+                if !assoc_item.attrs.iter().any(|a| a.id == attr.id) {
+                    assoc_item.attrs.push(attr.clone());
+                }
+                if !assoc_item.attrs.iter().any(|a| a.id == inline_never.id) {
+                    assoc_item.attrs.push(inline_never.clone());
+                }
+                Annotatable::AssocItem(assoc_item.clone(), i)
+            }
+            _ => {
+                unreachable!("annotatable kind checked previously")
+            }
+        };
+        // Now update for d_fn
+        rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
+            dspan: DelimSpan::dummy(),
+            delim: rustc_ast::token::Delimiter::Parenthesis,
+            tokens: ts,
+        });
+        let d_attr: ast::Attribute = ast::Attribute {
+            kind: ast::AttrKind::Normal(rustc_ad_attr.clone()),
+            id: new_id,
+            style: ast::AttrStyle::Outer,
+            span,
+        };
+
+        let d_annotatable = if is_impl {
+            let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
+            let d_fn = P(ast::AssocItem {
+                attrs: thin_vec![d_attr.clone(), inline_never],
+                id: ast::DUMMY_NODE_ID,
+                span,
+                vis,
+                ident: d_ident,
+                kind: assoc_item,
+                tokens: None,
+            });
+            Annotatable::AssocItem(d_fn, Impl)
+        } else {
+            let mut d_fn = ecx.item(
+                span,
+                d_ident,
+                thin_vec![d_attr.clone(), inline_never],
+                ItemKind::Fn(asdf),
+            );
+            d_fn.vis = vis;
+            Annotatable::Item(d_fn)
+        };
+
+        return vec![orig_annotatable, d_annotatable];
+    }
+
+    // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
+    // mutable references or ptrs, because Enzyme will write into them.
+    fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
+        let mut ty = ty.clone();
+        match ty.kind {
+            TyKind::Ptr(ref mut mut_ty) => {
+                mut_ty.mutbl = ast::Mutability::Mut;
+            }
+            TyKind::Ref(_, ref mut mut_ty) => {
+                mut_ty.mutbl = ast::Mutability::Mut;
+            }
+            _ => {
+                panic!("unsupported type: {:?}", ty);
+            }
+        }
+        ty
+    }
+
+    /// We only want this function to type-check, since we will replace the body
+    /// later on llvm level. Using `loop {}` does not cover all return types anymore,
+    /// so instead we build something that should pass. We also add a inline_asm
+    /// line, as one more barrier for rustc to prevent inlining of this function.
+    /// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see
+    /// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient.
+    /// It also triggers an Enzyme crash if we due to a bug ever try to differentiate
+    /// this function (which should never happen, since it is only a placeholder).
+    /// Finally, we also add back_box usages of all input arguments, to prevent rustc
+    /// from optimizing any arguments away.
+    fn gen_enzyme_body(
+        ecx: &ExtCtxt<'_>,
+        x: &AutoDiffAttrs,
+        n_active: u32,
+        sig: &ast::FnSig,
+        d_sig: &ast::FnSig,
+        primal: Ident,
+        new_names: &[String],
+        span: Span,
+        sig_span: Span,
+        new_decl_span: Span,
+        idents: Vec<Ident>,
+        errored: bool,
+    ) -> P<ast::Block> {
+        let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
+        let noop = ast::InlineAsm {
+            asm_macro: ast::AsmMacro::Asm,
+            template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())],
+            template_strs: Box::new([]),
+            operands: vec![],
+            clobber_abis: vec![],
+            options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM,
+            line_spans: vec![],
+        };
+        let noop_expr = ecx.expr_asm(span, P(noop));
+        let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated);
+        let unsf_block = ast::Block {
+            stmts: thin_vec![ecx.stmt_semi(noop_expr)],
+            id: ast::DUMMY_NODE_ID,
+            tokens: None,
+            rules: unsf,
+            span,
+            could_be_bare_literal: false,
+        };
+        let unsf_expr = ecx.expr_block(P(unsf_block));
+        let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
+        let primal_call = gen_primal_call(ecx, span, primal, idents);
+        let black_box_primal_call =
+            ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![
+                primal_call.clone()
+            ]);
+        let tup_args = new_names
+            .iter()
+            .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
+            .collect();
+
+        let black_box_remaining_args =
+            ecx.expr_call(sig_span, blackbox_call_expr.clone(), thin_vec![
+                ecx.expr_tuple(sig_span, tup_args)
+            ]);
+
+        let mut body = ecx.block(span, ThinVec::new());
+        body.stmts.push(ecx.stmt_semi(unsf_expr));
+
+        // This uses primal args which won't be available if we errored before
+        if !errored {
+            body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone()));
+        }
+        body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
+
+        if !has_ret(&d_sig.decl.output) {
+            // there is no return type that we have to match, () works fine.
+            return body;
+        }
+
+        // having an active-only return means we'll drop the original return type.
+        // So that can be treated identical to not having one in the first place.
+        let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
+
+        if primal_ret && n_active == 0 && x.mode.is_rev() {
+            // We only have the primal ret.
+            body.stmts.push(ecx.stmt_expr(black_box_primal_call.clone()));
+            return body;
+        }
+
+        if !primal_ret && n_active == 1 {
+            // Again no tuple return, so return default float val.
+            let ty = match d_sig.decl.output {
+                FnRetTy::Ty(ref ty) => ty.clone(),
+                FnRetTy::Default(span) => {
+                    panic!("Did not expect Default ret ty: {:?}", span);
+                }
+            };
+            let arg = ty.kind.is_simple_path().unwrap();
+            let sl: Vec<Symbol> = vec![arg, kw::Default];
+            let tmp = ecx.def_site_path(&sl);
+            let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
+            let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
+            body.stmts.push(ecx.stmt_expr(default_call_expr));
+            return body;
+        }
+
+        let mut exprs = ThinVec::<P<ast::Expr>>::new();
+        if primal_ret {
+            // We have both primal ret and active floats.
+            // primal ret is first, by construction.
+            exprs.push(primal_call.clone());
+        }
+
+        // Now construct default placeholder for each active float.
+        // Is there something nicer than f32::default() and f64::default()?
+        let d_ret_ty = match d_sig.decl.output {
+            FnRetTy::Ty(ref ty) => ty.clone(),
+            FnRetTy::Default(span) => {
+                panic!("Did not expect Default ret ty: {:?}", span);
+            }
+        };
+        let mut d_ret_ty = match d_ret_ty.kind.clone() {
+            TyKind::Tup(ref tys) => tys.clone(),
+            TyKind::Path(_, rustc_ast::Path { segments, .. }) => {
+                if let [segment] = &segments[..]
+                    && segment.args.is_none()
+                {
+                    let id = vec![segments[0].ident];
+                    let kind = TyKind::Path(None, ecx.path(span, id));
+                    let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
+                    thin_vec![ty]
+                } else {
+                    panic!("Expected tuple or simple path return type");
+                }
+            }
+            _ => {
+                // We messed up construction of d_sig
+                panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty);
+            }
+        };
+
+        if x.mode.is_fwd() && x.ret_activity == DiffActivity::Dual {
+            assert!(d_ret_ty.len() == 2);
+            // both should be identical, by construction
+            let arg = d_ret_ty[0].kind.is_simple_path().unwrap();
+            let arg2 = d_ret_ty[1].kind.is_simple_path().unwrap();
+            assert!(arg == arg2);
+            let sl: Vec<Symbol> = vec![arg, kw::Default];
+            let tmp = ecx.def_site_path(&sl);
+            let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
+            let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
+            exprs.push(default_call_expr);
+        } else if x.mode.is_rev() {
+            if primal_ret {
+                // We have extra handling above for the primal ret
+                d_ret_ty = d_ret_ty[1..].to_vec().into();
+            }
+
+            for arg in d_ret_ty.iter() {
+                let arg = arg.kind.is_simple_path().unwrap();
+                let sl: Vec<Symbol> = vec![arg, kw::Default];
+                let tmp = ecx.def_site_path(&sl);
+                let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
+                let default_call_expr =
+                    ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
+                exprs.push(default_call_expr);
+            }
+        }
+
+        let ret: P<ast::Expr>;
+        match &exprs[..] {
+            [] => {
+                assert!(!has_ret(&d_sig.decl.output));
+                // We don't have to match the return type.
+                return body;
+            }
+            [arg] => {
+                ret = ecx
+                    .expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![arg.clone()]);
+            }
+            args => {
+                let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, args.into());
+                ret =
+                    ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_tuple]);
+            }
+        }
+        assert!(has_ret(&d_sig.decl.output));
+        body.stmts.push(ecx.stmt_expr(ret));
+
+        body
+    }
+
+    fn gen_primal_call(
+        ecx: &ExtCtxt<'_>,
+        span: Span,
+        primal: Ident,
+        idents: Vec<Ident>,
+    ) -> P<ast::Expr> {
+        let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
+        if has_self {
+            let args: ThinVec<_> =
+                idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
+            let self_expr = ecx.expr_self(span);
+            ecx.expr_method_call(span, self_expr, primal, args.clone())
+        } else {
+            let args: ThinVec<_> =
+                idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
+            let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
+            ecx.expr_call(span, primal_call_expr, args)
+        }
+    }
+
+    // Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
+    // be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
+    // Active arguments must be scalars. Their shadow argument is added to the return type (and will be
+    // zero-initialized by Enzyme).
+    // Each argument of the primal function (and the return type if existing) must be annotated with an
+    // activity.
+    //
+    // Error handling: If the user provides an invalid configuration (incorrect numbers, types, or
+    // both), we emit an error and return the original signature. This allows us to continue parsing.
+    fn gen_enzyme_decl(
+        ecx: &ExtCtxt<'_>,
+        sig: &ast::FnSig,
+        x: &AutoDiffAttrs,
+        span: Span,
+    ) -> (ast::FnSig, Vec<String>, Vec<Ident>, bool) {
+        let dcx = ecx.sess.dcx();
+        let has_ret = has_ret(&sig.decl.output);
+        let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
+        let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };
+        if sig_args != num_activities {
+            dcx.emit_err(errors::AutoDiffInvalidNumberActivities {
+                span,
+                expected: sig_args,
+                found: num_activities,
+            });
+            // This is not the right signature, but we can continue parsing.
+            return (sig.clone(), vec![], vec![], true);
+        }
+        assert!(sig.decl.inputs.len() == x.input_activity.len());
+        assert!(has_ret == x.has_ret_activity());
+        let mut d_decl = sig.decl.clone();
+        let mut d_inputs = Vec::new();
+        let mut new_inputs = Vec::new();
+        let mut idents = Vec::new();
+        let mut act_ret = ThinVec::new();
+
+        // We have two loops, a first one just to check the activities and types and possibly report
+        // multiple errors in one compilation session.
+        let mut errors = false;
+        for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
+            if !valid_input_activity(x.mode, *activity) {
+                dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {
+                    span,
+                    mode: x.mode.to_string(),
+                    act: activity.to_string(),
+                });
+                errors = true;
+            }
+            if !valid_ty_for_activity(&arg.ty, *activity) {
+                dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {
+                    span: arg.ty.span,
+                    act: activity.to_string(),
+                });
+                errors = true;
+            }
+        }
+        if errors {
+            // This is not the right signature, but we can continue parsing.
+            return (sig.clone(), new_inputs, idents, true);
+        }
+        let unsafe_activities = x
+            .input_activity
+            .iter()
+            .any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));
+        for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
+            d_inputs.push(arg.clone());
+            match activity {
+                DiffActivity::Active => {
+                    act_ret.push(arg.ty.clone());
+                }
+                DiffActivity::ActiveOnly => {
+                    // We will add the active scalar to the return type.
+                    // This is handled later.
+                }
+                DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
+                    let mut shadow_arg = arg.clone();
+                    // We += into the shadow in reverse mode.
+                    shadow_arg.ty = P(assure_mut_ref(&arg.ty));
+                    let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
+                        ident.name
+                    } else {
+                        debug!("{:#?}", &shadow_arg.pat);
+                        panic!("not an ident?");
+                    };
+                    let name: String = format!("d{}", old_name);
+                    new_inputs.push(name.clone());
+                    let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
+                    shadow_arg.pat = P(ast::Pat {
+                        id: ast::DUMMY_NODE_ID,
+                        kind: PatKind::Ident(BindingMode::NONE, ident, None),
+                        span: shadow_arg.pat.span,
+                        tokens: shadow_arg.pat.tokens.clone(),
+                    });
+                    d_inputs.push(shadow_arg);
+                }
+                DiffActivity::Dual | DiffActivity::DualOnly => {
+                    let mut shadow_arg = arg.clone();
+                    let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
+                        ident.name
+                    } else {
+                        debug!("{:#?}", &shadow_arg.pat);
+                        panic!("not an ident?");
+                    };
+                    let name: String = format!("b{}", old_name);
+                    new_inputs.push(name.clone());
+                    let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
+                    shadow_arg.pat = P(ast::Pat {
+                        id: ast::DUMMY_NODE_ID,
+                        kind: PatKind::Ident(BindingMode::NONE, ident, None),
+                        span: shadow_arg.pat.span,
+                        tokens: shadow_arg.pat.tokens.clone(),
+                    });
+                    d_inputs.push(shadow_arg);
+                }
+                DiffActivity::Const => {
+                    // Nothing to do here.
+                }
+                DiffActivity::None | DiffActivity::FakeActivitySize => {
+                    panic!("Should not happen");
+                }
+            }
+            if let PatKind::Ident(_, ident, _) = arg.pat.kind {
+                idents.push(ident.clone());
+            } else {
+                panic!("not an ident?");
+            }
+        }
+
+        let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
+        if active_only_ret {
+            assert!(x.mode.is_rev());
+        }
+
+        // If we return a scalar in the primal and the scalar is active,
+        // then add it as last arg to the inputs.
+        if x.mode.is_rev() {
+            match x.ret_activity {
+                DiffActivity::Active | DiffActivity::ActiveOnly => {
+                    let ty = match d_decl.output {
+                        FnRetTy::Ty(ref ty) => ty.clone(),
+                        FnRetTy::Default(span) => {
+                            panic!("Did not expect Default ret ty: {:?}", span);
+                        }
+                    };
+                    let name = "dret".to_string();
+                    let ident = Ident::from_str_and_span(&name, ty.span);
+                    let shadow_arg = ast::Param {
+                        attrs: ThinVec::new(),
+                        ty: ty.clone(),
+                        pat: P(ast::Pat {
+                            id: ast::DUMMY_NODE_ID,
+                            kind: PatKind::Ident(BindingMode::NONE, ident, None),
+                            span: ty.span,
+                            tokens: None,
+                        }),
+                        id: ast::DUMMY_NODE_ID,
+                        span: ty.span,
+                        is_placeholder: false,
+                    };
+                    d_inputs.push(shadow_arg);
+                    new_inputs.push(name);
+                }
+                _ => {}
+            }
+        }
+        d_decl.inputs = d_inputs.into();
+
+        if x.mode.is_fwd() {
+            if let DiffActivity::Dual = x.ret_activity {
+                let ty = match d_decl.output {
+                    FnRetTy::Ty(ref ty) => ty.clone(),
+                    FnRetTy::Default(span) => {
+                        panic!("Did not expect Default ret ty: {:?}", span);
+                    }
+                };
+                // Dual can only be used for f32/f64 ret.
+                // In that case we return now a tuple with two floats.
+                let kind = TyKind::Tup(thin_vec![ty.clone(), ty.clone()]);
+                let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
+                d_decl.output = FnRetTy::Ty(ty);
+            }
+            if let DiffActivity::DualOnly = x.ret_activity {
+                // No need to change the return type,
+                // we will just return the shadow in place
+                // of the primal return.
+            }
+        }
+
+        // If we use ActiveOnly, drop the original return value.
+        d_decl.output =
+            if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };
+
+        trace!("act_ret: {:?}", act_ret);
+
+        // If we have an active input scalar, add it's gradient to the
+        // return type. This might require changing the return type to a
+        // tuple.
+        if act_ret.len() > 0 {
+            let ret_ty = match d_decl.output {
+                FnRetTy::Ty(ref ty) => {
+                    if !active_only_ret {
+                        act_ret.insert(0, ty.clone());
+                    }
+                    let kind = TyKind::Tup(act_ret);
+                    P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
+                }
+                FnRetTy::Default(span) => {
+                    if act_ret.len() == 1 {
+                        act_ret[0].clone()
+                    } else {
+                        let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());
+                        P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })
+                    }
+                }
+            };
+            d_decl.output = FnRetTy::Ty(ret_ty);
+        }
+
+        let mut d_header = sig.header.clone();
+        if unsafe_activities {
+            d_header.safety = rustc_ast::Safety::Unsafe(span);
+        }
+        let d_sig = FnSig { header: d_header, decl: d_decl, span };
+        trace!("Generated signature: {:?}", d_sig);
+        (d_sig, new_inputs, idents, false)
+    }
+}
+
+#[cfg(not(llvm_enzyme))]
+mod ad_fallback {
+    use rustc_ast::ast;
+    use rustc_expand::base::{Annotatable, ExtCtxt};
+    use rustc_span::Span;
+
+    use crate::errors;
+    pub(crate) fn expand(
+        ecx: &mut ExtCtxt<'_>,
+        _expand_span: Span,
+        meta_item: &ast::MetaItem,
+        item: Annotatable,
+    ) -> Vec<Annotatable> {
+        ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
+        return vec![item];
+    }
+}
+
+#[cfg(not(llvm_enzyme))]
+pub(crate) use ad_fallback::expand;
+#[cfg(llvm_enzyme)]
+pub(crate) use llvm_enzyme::expand;
diff --git a/compiler/rustc_builtin_macros/src/errors.rs b/compiler/rustc_builtin_macros/src/errors.rs
index 639c2aa231c..f8e65661e52 100644
--- a/compiler/rustc_builtin_macros/src/errors.rs
+++ b/compiler/rustc_builtin_macros/src/errors.rs
@@ -145,6 +145,78 @@ pub(crate) struct AllocMustStatics {
     pub(crate) span: Span,
 }
 
+#[cfg(llvm_enzyme)]
+pub(crate) use autodiff::*;
+
+#[cfg(llvm_enzyme)]
+mod autodiff {
+    use super::*;
+    #[derive(Diagnostic)]
+    #[diag(builtin_macros_autodiff_missing_config)]
+    pub(crate) struct AutoDiffMissingConfig {
+        #[primary_span]
+        pub(crate) span: Span,
+    }
+    #[derive(Diagnostic)]
+    #[diag(builtin_macros_autodiff_unknown_activity)]
+    pub(crate) struct AutoDiffUnknownActivity {
+        #[primary_span]
+        pub(crate) span: Span,
+        pub(crate) act: String,
+    }
+    #[derive(Diagnostic)]
+    #[diag(builtin_macros_autodiff_ty_activity)]
+    pub(crate) struct AutoDiffInvalidTypeForActivity {
+        #[primary_span]
+        pub(crate) span: Span,
+        pub(crate) act: String,
+    }
+    #[derive(Diagnostic)]
+    #[diag(builtin_macros_autodiff_number_activities)]
+    pub(crate) struct AutoDiffInvalidNumberActivities {
+        #[primary_span]
+        pub(crate) span: Span,
+        pub(crate) expected: usize,
+        pub(crate) found: usize,
+    }
+    #[derive(Diagnostic)]
+    #[diag(builtin_macros_autodiff_mode_activity)]
+    pub(crate) struct AutoDiffInvalidApplicationModeAct {
+        #[primary_span]
+        pub(crate) span: Span,
+        pub(crate) mode: String,
+        pub(crate) act: String,
+    }
+
+    #[derive(Diagnostic)]
+    #[diag(builtin_macros_autodiff_mode)]
+    pub(crate) struct AutoDiffInvalidMode {
+        #[primary_span]
+        pub(crate) span: Span,
+        pub(crate) mode: String,
+    }
+
+    #[derive(Diagnostic)]
+    #[diag(builtin_macros_autodiff)]
+    pub(crate) struct AutoDiffInvalidApplication {
+        #[primary_span]
+        pub(crate) span: Span,
+    }
+}
+
+#[cfg(not(llvm_enzyme))]
+pub(crate) use ad_fallback::*;
+#[cfg(not(llvm_enzyme))]
+mod ad_fallback {
+    use super::*;
+    #[derive(Diagnostic)]
+    #[diag(builtin_macros_autodiff_not_build)]
+    pub(crate) struct AutoDiffSupportNotBuild {
+        #[primary_span]
+        pub(crate) span: Span,
+    }
+}
+
 #[derive(Diagnostic)]
 #[diag(builtin_macros_concat_bytes_invalid)]
 pub(crate) struct ConcatBytesInvalid {
diff --git a/compiler/rustc_builtin_macros/src/lib.rs b/compiler/rustc_builtin_macros/src/lib.rs
index ebe5e2b5442..377d7f542cf 100644
--- a/compiler/rustc_builtin_macros/src/lib.rs
+++ b/compiler/rustc_builtin_macros/src/lib.rs
@@ -5,6 +5,7 @@
 #![allow(internal_features)]
 #![allow(rustc::diagnostic_outside_of_impl)]
 #![allow(rustc::untranslatable_diagnostic)]
+#![cfg_attr(not(bootstrap), feature(autodiff))]
 #![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")]
 #![doc(rust_logo)]
 #![feature(assert_matches)]
@@ -29,6 +30,7 @@ use crate::deriving::*;
 
 mod alloc_error_handler;
 mod assert;
+mod autodiff;
 mod cfg;
 mod cfg_accessible;
 mod cfg_eval;
@@ -106,6 +108,7 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) {
 
     register_attr! {
         alloc_error_handler: alloc_error_handler::expand,
+        autodiff: autodiff::expand,
         bench: test::expand_bench,
         cfg_accessible: cfg_accessible::Expander,
         cfg_eval: cfg_eval::expand,
diff --git a/compiler/rustc_expand/src/build.rs b/compiler/rustc_expand/src/build.rs
index b5945759d43..743a9854f79 100644
--- a/compiler/rustc_expand/src/build.rs
+++ b/compiler/rustc_expand/src/build.rs
@@ -220,6 +220,10 @@ impl<'a> ExtCtxt<'a> {
         self.stmt_local(local, span)
     }
 
+    pub fn stmt_semi(&self, expr: P<ast::Expr>) -> ast::Stmt {
+        ast::Stmt { id: ast::DUMMY_NODE_ID, span: expr.span, kind: ast::StmtKind::Semi(expr) }
+    }
+
     pub fn stmt_local(&self, local: P<ast::Local>, span: Span) -> ast::Stmt {
         ast::Stmt { id: ast::DUMMY_NODE_ID, kind: ast::StmtKind::Let(local), span }
     }
@@ -287,6 +291,25 @@ impl<'a> ExtCtxt<'a> {
         self.expr(sp, ast::ExprKind::Paren(e))
     }
 
+    pub fn expr_method_call(
+        &self,
+        span: Span,
+        expr: P<ast::Expr>,
+        ident: Ident,
+        args: ThinVec<P<ast::Expr>>,
+    ) -> P<ast::Expr> {
+        let seg = ast::PathSegment::from_ident(ident);
+        self.expr(
+            span,
+            ast::ExprKind::MethodCall(Box::new(ast::MethodCall {
+                seg,
+                receiver: expr,
+                args,
+                span,
+            })),
+        )
+    }
+
     pub fn expr_call(
         &self,
         span: Span,
@@ -295,6 +318,12 @@ impl<'a> ExtCtxt<'a> {
     ) -> P<ast::Expr> {
         self.expr(span, ast::ExprKind::Call(expr, args))
     }
+    pub fn expr_loop(&self, sp: Span, block: P<ast::Block>) -> P<ast::Expr> {
+        self.expr(sp, ast::ExprKind::Loop(block, None, sp))
+    }
+    pub fn expr_asm(&self, sp: Span, expr: P<ast::InlineAsm>) -> P<ast::Expr> {
+        self.expr(sp, ast::ExprKind::InlineAsm(expr))
+    }
     pub fn expr_call_ident(
         &self,
         span: Span,
diff --git a/compiler/rustc_feature/src/builtin_attrs.rs b/compiler/rustc_feature/src/builtin_attrs.rs
index 17827b4e43b..477760a4597 100644
--- a/compiler/rustc_feature/src/builtin_attrs.rs
+++ b/compiler/rustc_feature/src/builtin_attrs.rs
@@ -752,6 +752,11 @@ pub const BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[
         template!(NameValueStr: "transparent|semitransparent|opaque"), ErrorFollowing,
         EncodeCrossCrate::Yes, "used internally for testing macro hygiene",
     ),
+    rustc_attr!(
+        rustc_autodiff, Normal,
+        template!(Word, List: r#""...""#), DuplicatesOk,
+        EncodeCrossCrate::No, INTERNAL_UNSTABLE
+    ),
 
     // ==========================================================================
     // Internal attributes, Diagnostics related:
diff --git a/compiler/rustc_passes/messages.ftl b/compiler/rustc_passes/messages.ftl
index be76d6cef2b..d60d7e00d4b 100644
--- a/compiler/rustc_passes/messages.ftl
+++ b/compiler/rustc_passes/messages.ftl
@@ -49,6 +49,10 @@ passes_attr_crate_level =
 passes_attr_only_in_functions =
     `{$attr}` attribute can only be used on functions
 
+passes_autodiff_attr =
+    `#[autodiff]` should be applied to a function
+    .label = not a function
+
 passes_both_ffi_const_and_pure =
     `#[ffi_const]` function cannot be `#[ffi_pure]`
 
diff --git a/compiler/rustc_passes/src/check_attr.rs b/compiler/rustc_passes/src/check_attr.rs
index 44a62383e6e..7ce29260e36 100644
--- a/compiler/rustc_passes/src/check_attr.rs
+++ b/compiler/rustc_passes/src/check_attr.rs
@@ -243,6 +243,9 @@ impl<'tcx> CheckAttrVisitor<'tcx> {
                     self.check_generic_attr(hir_id, attr, target, Target::Fn);
                     self.check_proc_macro(hir_id, target, ProcMacroKind::Derive)
                 }
+                [sym::autodiff, ..] => {
+                    self.check_autodiff(hir_id, attr, span, target)
+                }
                 [sym::coroutine, ..] => {
                     self.check_coroutine(attr, target);
                 }
@@ -2345,6 +2348,18 @@ impl<'tcx> CheckAttrVisitor<'tcx> {
             self.dcx().emit_err(errors::RustcPubTransparent { span, attr_span });
         }
     }
+
+    /// Checks if `#[autodiff]` is applied to an item other than a function item.
+    fn check_autodiff(&self, _hir_id: HirId, _attr: &Attribute, span: Span, target: Target) {
+        debug!("check_autodiff");
+        match target {
+            Target::Fn => {}
+            _ => {
+                self.dcx().emit_err(errors::AutoDiffAttr { attr_span: span });
+                self.abort.set(true);
+            }
+        }
+    }
 }
 
 impl<'tcx> Visitor<'tcx> for CheckAttrVisitor<'tcx> {
diff --git a/compiler/rustc_passes/src/errors.rs b/compiler/rustc_passes/src/errors.rs
index 4f00c90fa3b..70f525ec0e8 100644
--- a/compiler/rustc_passes/src/errors.rs
+++ b/compiler/rustc_passes/src/errors.rs
@@ -20,6 +20,14 @@ use crate::lang_items::Duplicate;
 #[diag(passes_incorrect_do_not_recommend_location)]
 pub(crate) struct IncorrectDoNotRecommendLocation;
 
+#[derive(Diagnostic)]
+#[diag(passes_autodiff_attr)]
+pub(crate) struct AutoDiffAttr {
+    #[primary_span]
+    #[label]
+    pub attr_span: Span,
+}
+
 #[derive(LintDiagnostic)]
 #[diag(passes_outer_crate_level_attr)]
 pub(crate) struct OuterCrateLevelAttr;
diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs
index 8e0009695db..92946ef5262 100644
--- a/compiler/rustc_span/src/symbol.rs
+++ b/compiler/rustc_span/src/symbol.rs
@@ -481,6 +481,8 @@ symbols! {
         audit_that,
         augmented_assignments,
         auto_traits,
+        autodiff,
+        autodiff_fallback,
         automatically_derived,
         avx,
         avx512_target_feature,
@@ -544,6 +546,7 @@ symbols! {
         cfg_accessible,
         cfg_attr,
         cfg_attr_multi,
+        cfg_autodiff_fallback,
         cfg_boolean_literals,
         cfg_doctest,
         cfg_eval,
@@ -998,6 +1001,7 @@ symbols! {
         hashset_iter_ty,
         hexagon_target_feature,
         hidden,
+        hint,
         homogeneous_aggregate,
         host,
         html_favicon_url,
@@ -1650,6 +1654,7 @@ symbols! {
         rustc_allow_incoherent_impl,
         rustc_allowed_through_unstable_modules,
         rustc_attrs,
+        rustc_autodiff,
         rustc_box,
         rustc_builtin_macro,
         rustc_capture_analysis,
diff --git a/library/core/src/lib.rs b/library/core/src/lib.rs
index 96ab5755328..f69a33bca84 100644
--- a/library/core/src/lib.rs
+++ b/library/core/src/lib.rs
@@ -278,6 +278,15 @@ pub mod assert_matches {
     pub use crate::macros::{assert_matches, debug_assert_matches};
 }
 
+// We don't export this through #[macro_export] for now, to avoid breakage.
+#[cfg(not(bootstrap))]
+#[unstable(feature = "autodiff", issue = "124509")]
+/// Unstable module containing the unstable `autodiff` macro.
+pub mod autodiff {
+    #[unstable(feature = "autodiff", issue = "124509")]
+    pub use crate::macros::builtin::autodiff;
+}
+
 #[unstable(feature = "cfg_match", issue = "115585")]
 pub use crate::macros::cfg_match;
 
diff --git a/library/core/src/macros/mod.rs b/library/core/src/macros/mod.rs
index aa0646846e4..b5e5b58f705 100644
--- a/library/core/src/macros/mod.rs
+++ b/library/core/src/macros/mod.rs
@@ -1539,6 +1539,24 @@ pub(crate) mod builtin {
         ($file:expr $(,)?) => {{ /* compiler built-in */ }};
     }
 
+    /// Automatic Differentiation macro which allows generating a new function to compute
+    /// the derivative of a given function. It may only be applied to a function.
+    /// The expected usage syntax is
+    /// `#[autodiff(NAME, MODE, INPUT_ACTIVITIES, OUTPUT_ACTIVITY)]`
+    /// where:
+    /// NAME is a string that represents a valid function name.
+    /// MODE is any of Forward, Reverse, ForwardFirst, ReverseFirst.
+    /// INPUT_ACTIVITIES consists of one valid activity for each input parameter.
+    /// OUTPUT_ACTIVITY must not be set if we implicitely return nothing (or explicitely return
+    /// `-> ()`. Otherwise it must be set to one of the allowed activities.
+    #[unstable(feature = "autodiff", issue = "124509")]
+    #[allow_internal_unstable(rustc_attrs)]
+    #[rustc_builtin_macro]
+    #[cfg(not(bootstrap))]
+    pub macro autodiff($item:item) {
+        /* compiler built-in */
+    }
+
     /// Asserts that a boolean expression is `true` at runtime.
     ///
     /// This will invoke the [`panic!`] macro if the provided expression cannot be
diff --git a/library/std/src/lib.rs b/library/std/src/lib.rs
index 65a9aa66c7c..35ed761759b 100644
--- a/library/std/src/lib.rs
+++ b/library/std/src/lib.rs
@@ -267,6 +267,7 @@
 #![allow(unused_features)]
 //
 // Features:
+#![cfg_attr(not(bootstrap), feature(autodiff))]
 #![cfg_attr(test, feature(internal_output_capture, print_internals, update_panic_count, rt))]
 #![cfg_attr(
     all(target_vendor = "fortanix", target_env = "sgx"),
@@ -627,7 +628,13 @@ pub mod simd {
     #[doc(inline)]
     pub use crate::std_float::StdFloat;
 }
-
+#[cfg(not(bootstrap))]
+#[unstable(feature = "autodiff", issue = "124509")]
+/// This module provides support for automatic differentiation.
+pub mod autodiff {
+    /// This macro handles automatic differentiation.
+    pub use core::autodiff::autodiff;
+}
 #[stable(feature = "futures_api", since = "1.36.0")]
 pub mod task {
     //! Types and Traits for working with asynchronous tasks.