about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
authorLuca Versari <veluca93@gmail.com>2025-01-14 22:51:09 +0100
committerLuca Versari <veluca93@gmail.com>2025-01-16 10:54:33 +0100
commit8fee6a77394ffda819e58a648d4b44e1e566f34b (patch)
treebfa6ea3c2116cd35e07e42b45a96377353d5745a /compiler
parent341f60327fa5302732a4be366949c16f91870b6a (diff)
downloadrust-8fee6a77394ffda819e58a648d4b44e1e566f34b.tar.gz
rust-8fee6a77394ffda819e58a648d4b44e1e566f34b.zip
Coerce safe-to-call target_feature functions to fn pointers.
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_borrowck/src/type_check/mod.rs15
-rw-r--r--compiler/rustc_hir_typeck/src/coercion.rs23
-rw-r--r--compiler/rustc_middle/src/ty/context.rs33
-rw-r--r--compiler/rustc_mir_build/src/check_unsafety.rs11
4 files changed, 61 insertions, 21 deletions
diff --git a/compiler/rustc_borrowck/src/type_check/mod.rs b/compiler/rustc_borrowck/src/type_check/mod.rs
index a1979c8b8ab..eca8a688ff4 100644
--- a/compiler/rustc_borrowck/src/type_check/mod.rs
+++ b/compiler/rustc_borrowck/src/type_check/mod.rs
@@ -1654,7 +1654,20 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
                 match *cast_kind {
                     CastKind::PointerCoercion(PointerCoercion::ReifyFnPointer, coercion_source) => {
                         let is_implicit_coercion = coercion_source == CoercionSource::Implicit;
-                        let src_sig = op.ty(body, tcx).fn_sig(tcx);
+                        let src_ty = op.ty(body, tcx);
+                        let mut src_sig = src_ty.fn_sig(tcx);
+                        if let ty::FnDef(def_id, _) = src_ty.kind()
+                            && let ty::FnPtr(_, target_hdr) = *ty.kind()
+                            && tcx.codegen_fn_attrs(def_id).safe_target_features
+                            && target_hdr.safety.is_safe()
+                            && let Some(safe_sig) = tcx.adjust_target_feature_sig(
+                                *def_id,
+                                src_sig,
+                                body.source.def_id(),
+                            )
+                        {
+                            src_sig = safe_sig;
+                        }
 
                         // HACK: This shouldn't be necessary... We can remove this when we actually
                         // get binders with where clauses, then elaborate implied bounds into that
diff --git a/compiler/rustc_hir_typeck/src/coercion.rs b/compiler/rustc_hir_typeck/src/coercion.rs
index ec7c1efa38e..6945dbc3216 100644
--- a/compiler/rustc_hir_typeck/src/coercion.rs
+++ b/compiler/rustc_hir_typeck/src/coercion.rs
@@ -920,7 +920,7 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> {
 
         match b.kind() {
             ty::FnPtr(_, b_hdr) => {
-                let a_sig = a.fn_sig(self.tcx);
+                let mut a_sig = a.fn_sig(self.tcx);
                 if let ty::FnDef(def_id, _) = *a.kind() {
                     // Intrinsics are not coercible to function pointers
                     if self.tcx.intrinsic(def_id).is_some() {
@@ -932,19 +932,20 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> {
                         return Err(TypeError::ForceInlineCast);
                     }
 
-                    let fn_attrs = self.tcx.codegen_fn_attrs(def_id);
-                    if matches!(fn_attrs.inline, InlineAttr::Force { .. }) {
-                        return Err(TypeError::ForceInlineCast);
-                    }
-
-                    // FIXME(target_feature): Safe `#[target_feature]` functions could be cast to safe fn pointers (RFC 2396),
-                    // as you can already write that "cast" in user code by wrapping a target_feature fn call in a closure,
-                    // which is safe. This is sound because you already need to be executing code that is satisfying the target
-                    // feature constraints..
                     if b_hdr.safety.is_safe()
                         && self.tcx.codegen_fn_attrs(def_id).safe_target_features
                     {
-                        return Err(TypeError::TargetFeatureCast(def_id));
+                        // Allow the coercion if the current function has all the features that would be
+                        // needed to call the coercee safely.
+                        if let Some(safe_sig) = self.tcx.adjust_target_feature_sig(
+                            def_id,
+                            a_sig,
+                            self.fcx.body_id.into(),
+                        ) {
+                            a_sig = safe_sig;
+                        } else {
+                            return Err(TypeError::TargetFeatureCast(def_id));
+                        }
                     }
                 }
 
diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs
index fab0047babf..7035e641f39 100644
--- a/compiler/rustc_middle/src/ty/context.rs
+++ b/compiler/rustc_middle/src/ty/context.rs
@@ -60,7 +60,7 @@ use crate::dep_graph::{DepGraph, DepKindStruct};
 use crate::infer::canonical::{CanonicalParamEnvCache, CanonicalVarInfo, CanonicalVarInfos};
 use crate::lint::lint_level;
 use crate::metadata::ModChild;
-use crate::middle::codegen_fn_attrs::CodegenFnAttrs;
+use crate::middle::codegen_fn_attrs::{CodegenFnAttrs, TargetFeature};
 use crate::middle::{resolve_bound_vars, stability};
 use crate::mir::interpret::{self, Allocation, ConstAllocation};
 use crate::mir::{Body, Local, Place, PlaceElem, ProjectionKind, Promoted};
@@ -1776,6 +1776,37 @@ impl<'tcx> TyCtxt<'tcx> {
     pub fn dcx(self) -> DiagCtxtHandle<'tcx> {
         self.sess.dcx()
     }
+
+    pub fn is_target_feature_call_safe(
+        self,
+        callee_features: &[TargetFeature],
+        body_features: &[TargetFeature],
+    ) -> bool {
+        // If the called function has target features the calling function hasn't,
+        // the call requires `unsafe`. Don't check this on wasm
+        // targets, though. For more information on wasm see the
+        // is_like_wasm check in hir_analysis/src/collect.rs
+        self.sess.target.options.is_like_wasm
+            || callee_features
+                .iter()
+                .all(|feature| body_features.iter().any(|f| f.name == feature.name))
+    }
+
+    /// Returns the safe version of the signature of the given function, if calling it
+    /// would be safe in the context of the given caller.
+    pub fn adjust_target_feature_sig(
+        self,
+        fun_def: DefId,
+        fun_sig: ty::Binder<'tcx, ty::FnSig<'tcx>>,
+        caller: DefId,
+    ) -> Option<ty::Binder<'tcx, ty::FnSig<'tcx>>> {
+        let fun_features = &self.codegen_fn_attrs(fun_def).target_features;
+        let callee_features = &self.codegen_fn_attrs(caller).target_features;
+        if self.is_target_feature_call_safe(&fun_features, &callee_features) {
+            return Some(fun_sig.map_bound(|sig| ty::FnSig { safety: hir::Safety::Safe, ..sig }));
+        }
+        None
+    }
 }
 
 impl<'tcx> TyCtxtAt<'tcx> {
diff --git a/compiler/rustc_mir_build/src/check_unsafety.rs b/compiler/rustc_mir_build/src/check_unsafety.rs
index 6279d0f94af..5eed9ef798d 100644
--- a/compiler/rustc_mir_build/src/check_unsafety.rs
+++ b/compiler/rustc_mir_build/src/check_unsafety.rs
@@ -495,14 +495,9 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
                     };
                     self.requires_unsafe(expr.span, CallToUnsafeFunction(func_id));
                 } else if let &ty::FnDef(func_did, _) = fn_ty.kind() {
-                    // If the called function has target features the calling function hasn't,
-                    // the call requires `unsafe`. Don't check this on wasm
-                    // targets, though. For more information on wasm see the
-                    // is_like_wasm check in hir_analysis/src/collect.rs
-                    if !self.tcx.sess.target.options.is_like_wasm
-                        && !callee_features.iter().all(|feature| {
-                            self.body_target_features.iter().any(|f| f.name == feature.name)
-                        })
+                    if !self
+                        .tcx
+                        .is_target_feature_call_safe(callee_features, self.body_target_features)
                     {
                         let missing: Vec<_> = callee_features
                             .iter()