about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_borrowck/src/type_check/mod.rs15
-rw-r--r--compiler/rustc_hir_typeck/src/coercion.rs23
-rw-r--r--compiler/rustc_middle/src/ty/context.rs33
-rw-r--r--compiler/rustc_mir_build/src/check_unsafety.rs11
-rw-r--r--tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.rs14
-rw-r--r--tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.stderr19
-rw-r--r--tests/ui/rfcs/rfc-2396-target_feature-11/return-fn-ptr.rs22
7 files changed, 114 insertions, 23 deletions
diff --git a/compiler/rustc_borrowck/src/type_check/mod.rs b/compiler/rustc_borrowck/src/type_check/mod.rs
index a1979c8b8ab..eca8a688ff4 100644
--- a/compiler/rustc_borrowck/src/type_check/mod.rs
+++ b/compiler/rustc_borrowck/src/type_check/mod.rs
@@ -1654,7 +1654,20 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
                 match *cast_kind {
                     CastKind::PointerCoercion(PointerCoercion::ReifyFnPointer, coercion_source) => {
                         let is_implicit_coercion = coercion_source == CoercionSource::Implicit;
-                        let src_sig = op.ty(body, tcx).fn_sig(tcx);
+                        let src_ty = op.ty(body, tcx);
+                        let mut src_sig = src_ty.fn_sig(tcx);
+                        if let ty::FnDef(def_id, _) = src_ty.kind()
+                            && let ty::FnPtr(_, target_hdr) = *ty.kind()
+                            && tcx.codegen_fn_attrs(def_id).safe_target_features
+                            && target_hdr.safety.is_safe()
+                            && let Some(safe_sig) = tcx.adjust_target_feature_sig(
+                                *def_id,
+                                src_sig,
+                                body.source.def_id(),
+                            )
+                        {
+                            src_sig = safe_sig;
+                        }
 
                         // HACK: This shouldn't be necessary... We can remove this when we actually
                         // get binders with where clauses, then elaborate implied bounds into that
diff --git a/compiler/rustc_hir_typeck/src/coercion.rs b/compiler/rustc_hir_typeck/src/coercion.rs
index ec7c1efa38e..6945dbc3216 100644
--- a/compiler/rustc_hir_typeck/src/coercion.rs
+++ b/compiler/rustc_hir_typeck/src/coercion.rs
@@ -920,7 +920,7 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> {
 
         match b.kind() {
             ty::FnPtr(_, b_hdr) => {
-                let a_sig = a.fn_sig(self.tcx);
+                let mut a_sig = a.fn_sig(self.tcx);
                 if let ty::FnDef(def_id, _) = *a.kind() {
                     // Intrinsics are not coercible to function pointers
                     if self.tcx.intrinsic(def_id).is_some() {
@@ -932,19 +932,20 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> {
                         return Err(TypeError::ForceInlineCast);
                     }
 
-                    let fn_attrs = self.tcx.codegen_fn_attrs(def_id);
-                    if matches!(fn_attrs.inline, InlineAttr::Force { .. }) {
-                        return Err(TypeError::ForceInlineCast);
-                    }
-
-                    // FIXME(target_feature): Safe `#[target_feature]` functions could be cast to safe fn pointers (RFC 2396),
-                    // as you can already write that "cast" in user code by wrapping a target_feature fn call in a closure,
-                    // which is safe. This is sound because you already need to be executing code that is satisfying the target
-                    // feature constraints..
                     if b_hdr.safety.is_safe()
                         && self.tcx.codegen_fn_attrs(def_id).safe_target_features
                     {
-                        return Err(TypeError::TargetFeatureCast(def_id));
+                        // Allow the coercion if the current function has all the features that would be
+                        // needed to call the coercee safely.
+                        if let Some(safe_sig) = self.tcx.adjust_target_feature_sig(
+                            def_id,
+                            a_sig,
+                            self.fcx.body_id.into(),
+                        ) {
+                            a_sig = safe_sig;
+                        } else {
+                            return Err(TypeError::TargetFeatureCast(def_id));
+                        }
                     }
                 }
 
diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs
index fab0047babf..7035e641f39 100644
--- a/compiler/rustc_middle/src/ty/context.rs
+++ b/compiler/rustc_middle/src/ty/context.rs
@@ -60,7 +60,7 @@ use crate::dep_graph::{DepGraph, DepKindStruct};
 use crate::infer::canonical::{CanonicalParamEnvCache, CanonicalVarInfo, CanonicalVarInfos};
 use crate::lint::lint_level;
 use crate::metadata::ModChild;
-use crate::middle::codegen_fn_attrs::CodegenFnAttrs;
+use crate::middle::codegen_fn_attrs::{CodegenFnAttrs, TargetFeature};
 use crate::middle::{resolve_bound_vars, stability};
 use crate::mir::interpret::{self, Allocation, ConstAllocation};
 use crate::mir::{Body, Local, Place, PlaceElem, ProjectionKind, Promoted};
@@ -1776,6 +1776,37 @@ impl<'tcx> TyCtxt<'tcx> {
     pub fn dcx(self) -> DiagCtxtHandle<'tcx> {
         self.sess.dcx()
     }
+
+    pub fn is_target_feature_call_safe(
+        self,
+        callee_features: &[TargetFeature],
+        body_features: &[TargetFeature],
+    ) -> bool {
+        // If the called function has target features the calling function hasn't,
+        // the call requires `unsafe`. Don't check this on wasm
+        // targets, though. For more information on wasm see the
+        // is_like_wasm check in hir_analysis/src/collect.rs
+        self.sess.target.options.is_like_wasm
+            || callee_features
+                .iter()
+                .all(|feature| body_features.iter().any(|f| f.name == feature.name))
+    }
+
+    /// Returns the safe version of the signature of the given function, if calling it
+    /// would be safe in the context of the given caller.
+    pub fn adjust_target_feature_sig(
+        self,
+        fun_def: DefId,
+        fun_sig: ty::Binder<'tcx, ty::FnSig<'tcx>>,
+        caller: DefId,
+    ) -> Option<ty::Binder<'tcx, ty::FnSig<'tcx>>> {
+        let fun_features = &self.codegen_fn_attrs(fun_def).target_features;
+        let callee_features = &self.codegen_fn_attrs(caller).target_features;
+        if self.is_target_feature_call_safe(&fun_features, &callee_features) {
+            return Some(fun_sig.map_bound(|sig| ty::FnSig { safety: hir::Safety::Safe, ..sig }));
+        }
+        None
+    }
 }
 
 impl<'tcx> TyCtxtAt<'tcx> {
diff --git a/compiler/rustc_mir_build/src/check_unsafety.rs b/compiler/rustc_mir_build/src/check_unsafety.rs
index 6279d0f94af..5eed9ef798d 100644
--- a/compiler/rustc_mir_build/src/check_unsafety.rs
+++ b/compiler/rustc_mir_build/src/check_unsafety.rs
@@ -495,14 +495,9 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
                     };
                     self.requires_unsafe(expr.span, CallToUnsafeFunction(func_id));
                 } else if let &ty::FnDef(func_did, _) = fn_ty.kind() {
-                    // If the called function has target features the calling function hasn't,
-                    // the call requires `unsafe`. Don't check this on wasm
-                    // targets, though. For more information on wasm see the
-                    // is_like_wasm check in hir_analysis/src/collect.rs
-                    if !self.tcx.sess.target.options.is_like_wasm
-                        && !callee_features.iter().all(|feature| {
-                            self.body_target_features.iter().any(|f| f.name == feature.name)
-                        })
+                    if !self
+                        .tcx
+                        .is_target_feature_call_safe(callee_features, self.body_target_features)
                     {
                         let missing: Vec<_> = callee_features
                             .iter()
diff --git a/tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.rs b/tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.rs
index 364b4d35812..d7c17299d06 100644
--- a/tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.rs
+++ b/tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.rs
@@ -2,9 +2,23 @@
 
 #![feature(target_feature_11)]
 
+#[target_feature(enable = "avx")]
+fn foo_avx() {}
+
 #[target_feature(enable = "sse2")]
 fn foo() {}
 
+#[target_feature(enable = "sse2")]
+fn bar() {
+    let foo: fn() = foo; // this is OK, as we have the necessary target features.
+    let foo: fn() = foo_avx; //~ ERROR mismatched types
+}
+
 fn main() {
+    if std::is_x86_feature_detected!("sse2") {
+        unsafe {
+            bar();
+        }
+    }
     let foo: fn() = foo; //~ ERROR mismatched types
 }
diff --git a/tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.stderr b/tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.stderr
index a2bda229d10..1228404120a 100644
--- a/tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.stderr
+++ b/tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.stderr
@@ -1,5 +1,20 @@
 error[E0308]: mismatched types
-  --> $DIR/fn-ptr.rs:9:21
+  --> $DIR/fn-ptr.rs:14:21
+   |
+LL | #[target_feature(enable = "avx")]
+   | --------------------------------- `#[target_feature]` added here
+...
+LL |     let foo: fn() = foo_avx;
+   |              ----   ^^^^^^^ cannot coerce functions with `#[target_feature]` to safe function pointers
+   |              |
+   |              expected due to this
+   |
+   = note: expected fn pointer `fn()`
+                 found fn item `#[target_features] fn() {foo_avx}`
+   = note: functions with `#[target_feature]` can only be coerced to `unsafe` function pointers
+
+error[E0308]: mismatched types
+  --> $DIR/fn-ptr.rs:23:21
    |
 LL | #[target_feature(enable = "sse2")]
    | ---------------------------------- `#[target_feature]` added here
@@ -13,6 +28,6 @@ LL |     let foo: fn() = foo;
                  found fn item `#[target_features] fn() {foo}`
    = note: functions with `#[target_feature]` can only be coerced to `unsafe` function pointers
 
-error: aborting due to 1 previous error
+error: aborting due to 2 previous errors
 
 For more information about this error, try `rustc --explain E0308`.
diff --git a/tests/ui/rfcs/rfc-2396-target_feature-11/return-fn-ptr.rs b/tests/ui/rfcs/rfc-2396-target_feature-11/return-fn-ptr.rs
new file mode 100644
index 00000000000..b49493d6609
--- /dev/null
+++ b/tests/ui/rfcs/rfc-2396-target_feature-11/return-fn-ptr.rs
@@ -0,0 +1,22 @@
+//@ only-x86_64
+//@ run-pass
+
+#![feature(target_feature_11)]
+
+#[target_feature(enable = "sse2")]
+fn foo() -> bool {
+    true
+}
+
+#[target_feature(enable = "sse2")]
+fn bar() -> fn() -> bool {
+    foo
+}
+
+fn main() {
+    if !std::is_x86_feature_detected!("sse2") {
+        return;
+    }
+    let f = unsafe { bar() };
+    assert!(f());
+}