about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src/intrinsic.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_codegen_llvm/src/intrinsic.rs')
-rw-r--r--compiler/rustc_codegen_llvm/src/intrinsic.rs215
1 files changed, 184 insertions, 31 deletions
diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs
index fcc0d378f06..e7f4a357048 100644
--- a/compiler/rustc_codegen_llvm/src/intrinsic.rs
+++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs
@@ -3,24 +3,29 @@ use std::cmp::Ordering;
 
 use rustc_abi::{Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size};
 use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh};
+use rustc_codegen_ssa::codegen_attrs::autodiff_attrs;
 use rustc_codegen_ssa::common::{IntPredicate, TypeKind};
 use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization};
 use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
 use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
 use rustc_codegen_ssa::traits::*;
-use rustc_hir as hir;
+use rustc_hir::def_id::LOCAL_CRATE;
+use rustc_hir::{self as hir};
 use rustc_middle::mir::BinOp;
 use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf};
-use rustc_middle::ty::{self, GenericArgsRef, Ty};
+use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt, TypingEnv};
 use rustc_middle::{bug, span_bug};
 use rustc_span::{Span, Symbol, sym};
-use rustc_symbol_mangling::mangle_internal_symbol;
+use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate};
+use rustc_target::callconv::PassMode;
 use rustc_target::spec::PanicStrategy;
 use tracing::debug;
 
 use crate::abi::FnAbiLlvmExt;
 use crate::builder::Builder;
+use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call};
 use crate::context::CodegenCx;
+use crate::errors::AutoDiffWithoutEnable;
 use crate::llvm::{self, Metadata};
 use crate::type_::Type;
 use crate::type_of::LayoutLlvmExt;
@@ -189,6 +194,10 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
                     &[ptr, args[1].immediate()],
                 )
             }
+            sym::autodiff => {
+                codegen_autodiff(self, tcx, instance, args, result);
+                return Ok(());
+            }
             sym::is_val_statically_known => {
                 if let OperandValue::Immediate(imm) = args[0].val {
                     self.call_intrinsic(
@@ -321,10 +330,16 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
                     _ => bug!(),
                 };
                 let ptr = args[0].immediate();
+                let locality = fn_args.const_at(1).to_value().valtree.unwrap_leaf().to_i32();
                 self.call_intrinsic(
                     "llvm.prefetch",
                     &[self.val_ty(ptr)],
-                    &[ptr, self.const_i32(rw), args[1].immediate(), self.const_i32(cache_type)],
+                    &[
+                        ptr,
+                        self.const_i32(rw),
+                        self.const_i32(locality),
+                        self.const_i32(cache_type),
+                    ],
                 )
             }
             sym::carrying_mul_add => {
@@ -368,7 +383,9 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
             | sym::rotate_left
             | sym::rotate_right
             | sym::saturating_add
-            | sym::saturating_sub => {
+            | sym::saturating_sub
+            | sym::unchecked_funnel_shl
+            | sym::unchecked_funnel_shr => {
                 let ty = args[0].layout.ty;
                 if !ty.is_integral() {
                     tcx.dcx().emit_err(InvalidMonomorphization::BasicIntegerType {
@@ -382,26 +399,16 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
                 let width = size.bits();
                 let llty = self.type_ix(width);
                 match name {
-                    sym::ctlz | sym::cttz => {
-                        let y = self.const_bool(false);
-                        let ret = self.call_intrinsic(
-                            format!("llvm.{name}"),
-                            &[llty],
-                            &[args[0].immediate(), y],
-                        );
-
-                        self.intcast(ret, result.layout.llvm_type(self), false)
-                    }
-                    sym::ctlz_nonzero => {
-                        let y = self.const_bool(true);
-                        let ret =
-                            self.call_intrinsic("llvm.ctlz", &[llty], &[args[0].immediate(), y]);
-                        self.intcast(ret, result.layout.llvm_type(self), false)
-                    }
-                    sym::cttz_nonzero => {
-                        let y = self.const_bool(true);
+                    sym::ctlz | sym::ctlz_nonzero | sym::cttz | sym::cttz_nonzero => {
+                        let y =
+                            self.const_bool(name == sym::ctlz_nonzero || name == sym::cttz_nonzero);
+                        let llvm_name = if name == sym::ctlz || name == sym::ctlz_nonzero {
+                            "llvm.ctlz"
+                        } else {
+                            "llvm.cttz"
+                        };
                         let ret =
-                            self.call_intrinsic("llvm.cttz", &[llty], &[args[0].immediate(), y]);
+                            self.call_intrinsic(llvm_name, &[llty], &[args[0].immediate(), y]);
                         self.intcast(ret, result.layout.llvm_type(self), false)
                     }
                     sym::ctpop => {
@@ -419,18 +426,26 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
                     sym::bitreverse => {
                         self.call_intrinsic("llvm.bitreverse", &[llty], &[args[0].immediate()])
                     }
-                    sym::rotate_left | sym::rotate_right => {
-                        let is_left = name == sym::rotate_left;
-                        let val = args[0].immediate();
-                        let raw_shift = args[1].immediate();
-                        // rotate = funnel shift with first two args the same
+                    sym::rotate_left
+                    | sym::rotate_right
+                    | sym::unchecked_funnel_shl
+                    | sym::unchecked_funnel_shr => {
+                        let is_left = name == sym::rotate_left || name == sym::unchecked_funnel_shl;
+                        let lhs = args[0].immediate();
+                        let (rhs, raw_shift) =
+                            if name == sym::rotate_left || name == sym::rotate_right {
+                                // rotate = funnel shift with first two args the same
+                                (lhs, args[1].immediate())
+                            } else {
+                                (args[1].immediate(), args[2].immediate())
+                            };
                         let llvm_name = format!("llvm.fsh{}", if is_left { 'l' } else { 'r' });
 
                         // llvm expects shift to be the same type as the values, but rust
                         // always uses `u32`.
-                        let raw_shift = self.intcast(raw_shift, self.val_ty(val), false);
+                        let raw_shift = self.intcast(raw_shift, self.val_ty(lhs), false);
 
-                        self.call_intrinsic(llvm_name, &[llty], &[val, val, raw_shift])
+                        self.call_intrinsic(llvm_name, &[llty], &[lhs, rhs, raw_shift])
                     }
                     sym::saturating_add | sym::saturating_sub => {
                         let is_add = name == sym::saturating_add;
@@ -1123,6 +1138,144 @@ fn get_rust_try_fn<'a, 'll, 'tcx>(
     rust_try
 }
 
+fn codegen_autodiff<'ll, 'tcx>(
+    bx: &mut Builder<'_, 'll, 'tcx>,
+    tcx: TyCtxt<'tcx>,
+    instance: ty::Instance<'tcx>,
+    args: &[OperandRef<'tcx, &'ll Value>],
+    result: PlaceRef<'tcx, &'ll Value>,
+) {
+    if !tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) {
+        let _ = tcx.dcx().emit_almost_fatal(AutoDiffWithoutEnable);
+    }
+
+    let fn_args = instance.args;
+    let callee_ty = instance.ty(tcx, bx.typing_env());
+
+    let sig = callee_ty.fn_sig(tcx).skip_binder();
+
+    let ret_ty = sig.output();
+    let llret_ty = bx.layout_of(ret_ty).llvm_type(bx);
+
+    // Get source, diff, and attrs
+    let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() {
+        ty::FnDef(def_id, source_params) => (def_id, source_params),
+        _ => bug!("invalid autodiff intrinsic args"),
+    };
+
+    let fn_source = match Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args) {
+        Ok(Some(instance)) => instance,
+        Ok(None) => bug!(
+            "could not resolve ({:?}, {:?}) to a specific autodiff instance",
+            source_id,
+            source_args
+        ),
+        Err(_) => {
+            // An error has already been emitted
+            return;
+        }
+    };
+
+    let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE);
+    let Some(fn_to_diff) = bx.cx.get_function(&source_symbol) else {
+        bug!("could not find source function")
+    };
+
+    let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() {
+        ty::FnDef(def_id, diff_args) => (def_id, diff_args),
+        _ => bug!("invalid args"),
+    };
+
+    let fn_diff = match Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args) {
+        Ok(Some(instance)) => instance,
+        Ok(None) => bug!(
+            "could not resolve ({:?}, {:?}) to a specific autodiff instance",
+            diff_id,
+            diff_args
+        ),
+        Err(_) => {
+            // An error has already been emitted
+            return;
+        }
+    };
+
+    let val_arr = get_args_from_tuple(bx, args[2], fn_diff);
+    let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE);
+
+    let Some(mut diff_attrs) = autodiff_attrs(tcx, fn_diff.def_id()) else {
+        bug!("could not find autodiff attrs")
+    };
+
+    adjust_activity_to_abi(
+        tcx,
+        fn_source,
+        TypingEnv::fully_monomorphized(),
+        &mut diff_attrs.input_activity,
+    );
+
+    // Build body
+    generate_enzyme_call(
+        bx,
+        bx.cx,
+        fn_to_diff,
+        &diff_symbol,
+        llret_ty,
+        &val_arr,
+        diff_attrs.clone(),
+        result,
+    );
+}
+
+fn get_args_from_tuple<'ll, 'tcx>(
+    bx: &mut Builder<'_, 'll, 'tcx>,
+    tuple_op: OperandRef<'tcx, &'ll Value>,
+    fn_instance: Instance<'tcx>,
+) -> Vec<&'ll Value> {
+    let cx = bx.cx;
+    let fn_abi = cx.fn_abi_of_instance(fn_instance, ty::List::empty());
+
+    match tuple_op.val {
+        OperandValue::Immediate(val) => vec![val],
+        OperandValue::Pair(v1, v2) => vec![v1, v2],
+        OperandValue::Ref(ptr) => {
+            let tuple_place = PlaceRef { val: ptr, layout: tuple_op.layout };
+
+            let mut result = Vec::with_capacity(fn_abi.args.len());
+            let mut tuple_index = 0;
+
+            for arg in &fn_abi.args {
+                match arg.mode {
+                    PassMode::Ignore => {}
+                    PassMode::Direct(_) | PassMode::Cast { .. } => {
+                        let field = tuple_place.project_field(bx, tuple_index);
+                        let llvm_ty = field.layout.llvm_type(bx.cx);
+                        let val = bx.load(llvm_ty, field.val.llval, field.val.align);
+                        result.push(val);
+                        tuple_index += 1;
+                    }
+                    PassMode::Pair(_, _) => {
+                        let field = tuple_place.project_field(bx, tuple_index);
+                        let llvm_ty = field.layout.llvm_type(bx.cx);
+                        let pair_val = bx.load(llvm_ty, field.val.llval, field.val.align);
+                        result.push(bx.extract_value(pair_val, 0));
+                        result.push(bx.extract_value(pair_val, 1));
+                        tuple_index += 1;
+                    }
+                    PassMode::Indirect { .. } => {
+                        let field = tuple_place.project_field(bx, tuple_index);
+                        result.push(field.val.llval);
+                        tuple_index += 1;
+                    }
+                }
+            }
+
+            result
+        }
+
+        OperandValue::ZeroSized => vec![],
+    }
+}
+
 fn generic_simd_intrinsic<'ll, 'tcx>(
     bx: &mut Builder<'_, 'll, 'tcx>,
     name: Symbol,