about summary refs log tree commit diff
path: root/compiler/rustc_monomorphize/src/partitioning/autodiff.rs
blob: 0c855508434e07f113f3149f6d6b074e64ef7323 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity};
use rustc_hir::def_id::LOCAL_CRATE;
use rustc_middle::bug;
use rustc_middle::mir::mono::MonoItem;
use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
use rustc_symbol_mangling::symbol_name_for_instance_in_crate;
use tracing::{debug, trace};

use crate::partitioning::UsageMap;

/// Rewrites `da`, the per-argument activities of an autodiff function, so that it lines up
/// with the arguments the function has after lowering.
///
/// `fn_ty` must be a `ty::FnDef`; anything else is a compiler bug. For every argument that is
/// a built-in pointer/reference/box to a slice, the slice length becomes an extra argument, so
/// an extra activity is inserted directly after the slice's own activity:
/// `FakeActivitySize` when the slice is `Dual`/`DualOnly`/`Duplicated`/`DuplicatedOnly`, and
/// `Const` when the slice is `Const`. Any other activity on a slice is a compiler bug.
/// Function-pointer arguments are reported as an error, since autodiff does not support them.
///
/// If `da` is empty, no activities are inserted.
fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>) {
    if !matches!(fn_ty.kind(), ty::FnDef(..)) {
        bug!("expected fn def for autodiff, got {:?}", fn_ty);
    }
    let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx);

    // If rustc compiles the unmodified primal, we know that this copy of the function
    // also has correct lifetimes. We know that Enzyme won't free the shadow too early
    // (or actually at all), so let's strip lifetimes when computing the layout.
    let x = tcx.instantiate_bound_regions_with_erased(fnc_binder);

    // `(position in da, activity)` pairs to insert, collected in increasing position order.
    let mut insertions: Vec<(usize, DiffActivity)> = vec![];
    for (i, ty) in x.inputs().iter().enumerate() {
        // This must be checked before `builtin_deref`: a fn pointer is not a built-in
        // pointer, so `builtin_deref` yields `None` for it and a check nested in that
        // branch would never fire.
        if ty.is_fn_ptr() {
            // FIXME(ZuseZ4): add a nicer error, or just figure out how to support them,
            // since Enzyme itself can handle them.
            tcx.dcx().err("function pointers are currently not supported in autodiff");
        }
        let Some(inner_ty) = ty.builtin_deref(true) else {
            continue;
        };
        // We know that the length of a slice will be passed as an extra arg.
        if !inner_ty.is_slice() || da.is_empty() {
            continue;
        }
        // We are looking at a slice. The length of that slice will become an
        // extra integer on llvm level. Integers are always const.
        // However, if the slice gets duplicated, we want to know to later check the
        // size. So we mark the new size argument as FakeActivitySize.
        let activity = match da.get(i) {
            Some(
                DiffActivity::DualOnly
                | DiffActivity::Dual
                | DiffActivity::DuplicatedOnly
                | DiffActivity::Duplicated,
            ) => DiffActivity::FakeActivitySize,
            Some(DiffActivity::Const) => DiffActivity::Const,
            Some(_) => bug!("unexpected activity for ptr/ref"),
            None => bug!("autodiff: no activity given for argument {} of {:?}", i, fn_ty),
        };
        insertions.push((i + 1, activity));
    }

    // Now add the extra activities coming from slices. Insert from the highest position
    // down so that each insertion leaves the pending (lower) positions valid.
    for (pos, activity) in insertions.into_iter().rev() {
        da.insert(pos, activity);
    }
}

pub(crate) fn find_autodiff_source_functions<'tcx>(
    tcx: TyCtxt<'tcx>,
    usage_map: &UsageMap<'tcx>,
    autodiff_mono_items: Vec<(&MonoItem<'tcx>, &Instance<'tcx>)>,
) -> Vec<AutoDiffItem> {
    let mut autodiff_items: Vec<AutoDiffItem> = vec![];
    for (item, instance) in autodiff_mono_items {
        let target_id = instance.def_id();
        let cg_fn_attr = &tcx.codegen_fn_attrs(target_id).autodiff_item;
        let Some(target_attrs) = cg_fn_attr else {
            continue;
        };
        let mut input_activities: Vec<DiffActivity> = target_attrs.input_activity.clone();
        if target_attrs.is_source() {
            trace!("source found: {:?}", target_id);
        }
        if !target_attrs.apply_autodiff() {
            continue;
        }

        let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE);

        let source =
            usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item {
                MonoItem::Fn(ref instance_s) => {
                    let source_id = instance_s.def_id();
                    if let Some(ad) = &tcx.codegen_fn_attrs(source_id).autodiff_item
                        && ad.is_active()
                    {
                        return Some(instance_s);
                    }
                    None
                }
                _ => None,
            });
        let inst = match source {
            Some(source) => source,
            None => continue,
        };

        debug!("source_id: {:?}", inst.def_id());
        let fn_ty = inst.ty(tcx, ty::TypingEnv::fully_monomorphized());
        assert!(fn_ty.is_fn());
        adjust_activity_to_abi(tcx, fn_ty, &mut input_activities);
        let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE);

        let mut new_target_attrs = target_attrs.clone();
        new_target_attrs.input_activity = input_activities;
        let itm = new_target_attrs.into_item(symb, target_symbol);
        autodiff_items.push(itm);
    }

    if !autodiff_items.is_empty() {
        trace!("AUTODIFF ITEMS EXIST");
        for item in &mut *autodiff_items {
            trace!("{}", &item);
        }
    }

    autodiff_items
}