about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_codegen_llvm/src')
-rw-r--r--compiler/rustc_codegen_llvm/src/allocator.rs74
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs4
-rw-r--r--compiler/rustc_codegen_llvm/src/common.rs4
-rw-r--r--compiler/rustc_codegen_llvm/src/context.rs7
4 files changed, 48 insertions, 41 deletions
diff --git a/compiler/rustc_codegen_llvm/src/allocator.rs b/compiler/rustc_codegen_llvm/src/allocator.rs
index 4a78e694979..9dca63cfc8d 100644
--- a/compiler/rustc_codegen_llvm/src/allocator.rs
+++ b/compiler/rustc_codegen_llvm/src/allocator.rs
@@ -57,7 +57,7 @@ pub(crate) unsafe fn codegen(
             let from_name = mangle_internal_symbol(tcx, &global_fn_name(method.name));
             let to_name = mangle_internal_symbol(tcx, &default_fn_name(method.name));
 
-            create_wrapper_function(tcx, &cx, &from_name, &to_name, &args, output, false);
+            create_wrapper_function(tcx, &cx, &from_name, Some(&to_name), &args, output, false);
         }
     }
 
@@ -66,7 +66,7 @@ pub(crate) unsafe fn codegen(
         tcx,
         &cx,
         &mangle_internal_symbol(tcx, "__rust_alloc_error_handler"),
-        &mangle_internal_symbol(tcx, alloc_error_handler_name(alloc_error_handler_kind)),
+        Some(&mangle_internal_symbol(tcx, alloc_error_handler_name(alloc_error_handler_kind))),
         &[usize, usize], // size, align
         None,
         true,
@@ -81,11 +81,16 @@ pub(crate) unsafe fn codegen(
         let llval = llvm::LLVMConstInt(i8, val as u64, False);
         llvm::set_initializer(ll_g, llval);
 
-        let name = mangle_internal_symbol(tcx, NO_ALLOC_SHIM_IS_UNSTABLE);
-        let ll_g = cx.declare_global(&name, i8);
-        llvm::set_visibility(ll_g, llvm::Visibility::from_generic(tcx.sess.default_visibility()));
-        let llval = llvm::LLVMConstInt(i8, 0, False);
-        llvm::set_initializer(ll_g, llval);
+        // __rust_no_alloc_shim_is_unstable_v2
+        create_wrapper_function(
+            tcx,
+            &cx,
+            &mangle_internal_symbol(tcx, NO_ALLOC_SHIM_IS_UNSTABLE),
+            None,
+            &[],
+            None,
+            false,
+        );
     }
 
     if tcx.sess.opts.debuginfo != DebugInfo::None {
@@ -99,7 +104,7 @@ fn create_wrapper_function(
     tcx: TyCtxt<'_>,
     cx: &SimpleCx<'_>,
     from_name: &str,
-    to_name: &str,
+    to_name: Option<&str>,
     args: &[&Type],
     output: Option<&Type>,
     no_return: bool,
@@ -128,33 +133,38 @@ fn create_wrapper_function(
         attributes::apply_to_llfn(llfn, llvm::AttributePlace::Function, &[uwtable]);
     }
 
-    let callee = declare_simple_fn(
-        &cx,
-        to_name,
-        llvm::CallConv::CCallConv,
-        llvm::UnnamedAddr::Global,
-        llvm::Visibility::Hidden,
-        ty,
-    );
-    if let Some(no_return) = no_return {
-        // -> ! DIFlagNoReturn
-        attributes::apply_to_llfn(callee, llvm::AttributePlace::Function, &[no_return]);
-    }
-    llvm::set_visibility(callee, llvm::Visibility::Hidden);
-
     let llbb = unsafe { llvm::LLVMAppendBasicBlockInContext(cx.llcx, llfn, c"entry".as_ptr()) };
-
     let mut bx = SBuilder::build(&cx, llbb);
-    let args = args
-        .iter()
-        .enumerate()
-        .map(|(i, _)| llvm::get_param(llfn, i as c_uint))
-        .collect::<Vec<_>>();
-    let ret = bx.call(ty, callee, &args, None);
-    llvm::LLVMSetTailCall(ret, True);
-    if output.is_some() {
-        bx.ret(ret);
+
+    if let Some(to_name) = to_name {
+        let callee = declare_simple_fn(
+            &cx,
+            to_name,
+            llvm::CallConv::CCallConv,
+            llvm::UnnamedAddr::Global,
+            llvm::Visibility::Hidden,
+            ty,
+        );
+        if let Some(no_return) = no_return {
+            // -> ! DIFlagNoReturn
+            attributes::apply_to_llfn(callee, llvm::AttributePlace::Function, &[no_return]);
+        }
+        llvm::set_visibility(callee, llvm::Visibility::Hidden);
+
+        let args = args
+            .iter()
+            .enumerate()
+            .map(|(i, _)| llvm::get_param(llfn, i as c_uint))
+            .collect::<Vec<_>>();
+        let ret = bx.call(ty, callee, &args, None);
+        llvm::LLVMSetTailCall(ret, True);
+        if output.is_some() {
+            bx.ret(ret);
+        } else {
+            bx.ret_void()
+        }
     } else {
+        assert!(output.is_none());
         bx.ret_void()
     }
 }
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index c5c13ac097a..b07d9a5cfca 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -114,7 +114,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
             let mul = unsafe {
                 llvm::LLVMBuildMul(
                     builder.llbuilder,
-                    cx.get_const_i64(elem_bytes_size),
+                    cx.get_const_int(cx.type_i64(), elem_bytes_size),
                     next_outer_arg,
                     UNNAMED,
                 )
@@ -385,7 +385,7 @@ fn generate_enzyme_call<'ll>(
         if attrs.width > 1 {
             let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap();
             args.push(cx.get_metadata_value(enzyme_width));
-            args.push(cx.get_const_i64(attrs.width as u64));
+            args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
         }
 
         let has_sret = has_sret(outer_fn);
diff --git a/compiler/rustc_codegen_llvm/src/common.rs b/compiler/rustc_codegen_llvm/src/common.rs
index 3cfa96393e9..ae5add59322 100644
--- a/compiler/rustc_codegen_llvm/src/common.rs
+++ b/compiler/rustc_codegen_llvm/src/common.rs
@@ -99,14 +99,14 @@ impl<'ll, CX: Borrow<SCx<'ll>>> BackendTypes for GenericCx<'ll, CX> {
     type DIVariable = &'ll llvm::debuginfo::DIVariable;
 }
 
-impl<'ll> CodegenCx<'ll, '_> {
+impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
     pub(crate) fn const_array(&self, ty: &'ll Type, elts: &[&'ll Value]) -> &'ll Value {
         let len = u64::try_from(elts.len()).expect("LLVMConstArray2 elements len overflow");
         unsafe { llvm::LLVMConstArray2(ty, elts.as_ptr(), len) }
     }
 
     pub(crate) fn const_bytes(&self, bytes: &[u8]) -> &'ll Value {
-        bytes_in_context(self.llcx, bytes)
+        bytes_in_context(self.llcx(), bytes)
     }
 
     pub(crate) fn const_get_elt(&self, v: &'ll Value, idx: u64) -> &'ll Value {
diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs
index bff95ea46fa..0324dff6ff2 100644
--- a/compiler/rustc_codegen_llvm/src/context.rs
+++ b/compiler/rustc_codegen_llvm/src/context.rs
@@ -679,11 +679,8 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
         llvm::LLVMMetadataAsValue(self.llcx(), metadata)
     }
 
-    // FIXME(autodiff): We should split `ConstCodegenMethods` to pull the reusable parts
-    // onto a trait that is also implemented for GenericCx.
-    pub(crate) fn get_const_i64(&self, n: u64) -> &'ll Value {
-        let ty = unsafe { llvm::LLVMInt64TypeInContext(self.llcx()) };
-        unsafe { llvm::LLVMConstInt(ty, n, llvm::False) }
+    pub(crate) fn get_const_int(&self, ty: &'ll Type, val: u64) -> &'ll Value {
+        unsafe { llvm::LLVMConstInt(ty, val, llvm::False) }
     }
 
     pub(crate) fn get_function(&self, name: &str) -> Option<&'ll Value> {