about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMaybe Waffle <waffle.lapkin@gmail.com>2024-04-14 18:07:40 +0000
committerWaffle Lapkin <waffle.lapkin@gmail.com>2024-05-02 03:47:32 +0200
commitaa0a916c81936ba725b7efb68804a4217b09b43a (patch)
treea16931d75e9692681bddd0fd2154ecd2c32ef449
parentff0bfea45f8e2f9f3081bd416b4610511b9a6384 (diff)
downloadrust-aa0a916c81936ba725b7efb68804a4217b09b43a.tar.gz
rust-aa0a916c81936ba725b7efb68804a4217b09b43a.zip
Add a lint against never type fallback affecting unsafe code
-rw-r--r--compiler/rustc_hir_typeck/messages.ftl4
-rw-r--r--compiler/rustc_hir_typeck/src/errors.rs6
-rw-r--r--compiler/rustc_hir_typeck/src/fallback.rs135
-rw-r--r--compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs4
-rw-r--r--compiler/rustc_lint_defs/src/builtin.rs44
-rw-r--r--tests/ui/never_type/lint-never-type-fallback-flowing-into-unsafe.rs35
-rw-r--r--tests/ui/never_type/lint-never-type-fallback-flowing-into-unsafe.stderr23
7 files changed, 241 insertions, 10 deletions
diff --git a/compiler/rustc_hir_typeck/messages.ftl b/compiler/rustc_hir_typeck/messages.ftl
index 07b4948872d..0caebf44a19 100644
--- a/compiler/rustc_hir_typeck/messages.ftl
+++ b/compiler/rustc_hir_typeck/messages.ftl
@@ -99,6 +99,10 @@ hir_typeck_lossy_provenance_ptr2int =
 
 hir_typeck_missing_parentheses_in_range = can't call method `{$method_name}` on type `{$ty_str}`
 
+hir_typeck_never_type_fallback_flowing_into_unsafe =
+    never type fallback affects this call to an `unsafe` function
+    .help = specify the type explicitly
+
 hir_typeck_no_associated_item = no {$item_kind} named `{$item_name}` found for {$ty_prefix} `{$ty_str}`{$trait_missing_method ->
     [true] {""}
     *[other] {" "}in the current scope
diff --git a/compiler/rustc_hir_typeck/src/errors.rs b/compiler/rustc_hir_typeck/src/errors.rs
index 1c4d5657b17..fcad88f829e 100644
--- a/compiler/rustc_hir_typeck/src/errors.rs
+++ b/compiler/rustc_hir_typeck/src/errors.rs
@@ -164,6 +164,11 @@ pub struct MissingParenthesesInRange {
     pub add_missing_parentheses: Option<AddMissingParenthesesInRange>,
 }
 
+#[derive(LintDiagnostic)]
+#[diag(hir_typeck_never_type_fallback_flowing_into_unsafe)]
+#[help]
+pub struct NeverTypeFallbackFlowingIntoUnsafe {}
+
 #[derive(Subdiagnostic)]
 #[multipart_suggestion(
     hir_typeck_add_missing_parentheses_in_range,
@@ -632,7 +637,6 @@ pub enum SuggestBoxingForReturnImplTrait {
         ends: Vec<Span>,
     },
 }
-
 #[derive(LintDiagnostic)]
 #[diag(hir_typeck_dereferencing_mut_binding)]
 pub struct DereferencingMutBinding {
diff --git a/compiler/rustc_hir_typeck/src/fallback.rs b/compiler/rustc_hir_typeck/src/fallback.rs
index 3b00c7353e5..86a75aa4d78 100644
--- a/compiler/rustc_hir_typeck/src/fallback.rs
+++ b/compiler/rustc_hir_typeck/src/fallback.rs
@@ -1,10 +1,15 @@
-use crate::FnCtxt;
+use std::cell::OnceCell;
+
+use crate::{errors, FnCtxt};
 use rustc_data_structures::{
     graph::{self, iterate::DepthFirstSearch, vec_graph::VecGraph},
     unord::{UnordBag, UnordMap, UnordSet},
 };
+use rustc_hir::HirId;
 use rustc_infer::infer::{DefineOpaqueTypes, InferOk};
-use rustc_middle::ty::{self, Ty};
+use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitable};
+use rustc_session::lint;
+use rustc_span::Span;
 use rustc_span::DUMMY_SP;
 
 #[derive(Copy, Clone)]
@@ -335,6 +340,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
         // reach a member of N. If so, it falls back to `()`. Else
         // `!`.
         let mut diverging_fallback = UnordMap::with_capacity(diverging_vids.len());
+        let unsafe_infer_vars = OnceCell::new();
         for &diverging_vid in &diverging_vids {
             let diverging_ty = Ty::new_var(self.tcx, diverging_vid);
             let root_vid = self.root_var(diverging_vid);
@@ -354,11 +360,35 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
                 output: infer_var_infos.items().any(|info| info.output),
             };
 
+            let mut fallback_to = |ty| {
+                let unsafe_infer_vars = unsafe_infer_vars.get_or_init(|| {
+                    let unsafe_infer_vars = compute_unsafe_infer_vars(self.root_ctxt, self.body_id);
+                    debug!(?unsafe_infer_vars);
+                    unsafe_infer_vars
+                });
+
+                let affected_unsafe_infer_vars =
+                    graph::depth_first_search_as_undirected(&coercion_graph, root_vid)
+                        .filter_map(|x| unsafe_infer_vars.get(&x).copied())
+                        .collect::<Vec<_>>();
+
+                for (hir_id, span) in affected_unsafe_infer_vars {
+                    self.tcx.emit_node_span_lint(
+                        lint::builtin::NEVER_TYPE_FALLBACK_FLOWING_INTO_UNSAFE,
+                        hir_id,
+                        span,
+                        errors::NeverTypeFallbackFlowingIntoUnsafe {},
+                    );
+                }
+
+                diverging_fallback.insert(diverging_ty, ty);
+            };
+
             use DivergingFallbackBehavior::*;
             match behavior {
                 FallbackToUnit => {
                     debug!("fallback to () - legacy: {:?}", diverging_vid);
-                    diverging_fallback.insert(diverging_ty, self.tcx.types.unit);
+                    fallback_to(self.tcx.types.unit);
                 }
                 FallbackToNiko => {
                     if found_infer_var_info.self_in_trait && found_infer_var_info.output {
@@ -387,13 +417,13 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
                         // set, see the relationship finding module in
                         // compiler/rustc_trait_selection/src/traits/relationships.rs.
                         debug!("fallback to () - found trait and projection: {:?}", diverging_vid);
-                        diverging_fallback.insert(diverging_ty, self.tcx.types.unit);
+                        fallback_to(self.tcx.types.unit);
                     } else if can_reach_non_diverging {
                         debug!("fallback to () - reached non-diverging: {:?}", diverging_vid);
-                        diverging_fallback.insert(diverging_ty, self.tcx.types.unit);
+                        fallback_to(self.tcx.types.unit);
                     } else {
                         debug!("fallback to ! - all diverging: {:?}", diverging_vid);
-                        diverging_fallback.insert(diverging_ty, self.tcx.types.never);
+                        fallback_to(self.tcx.types.never);
                     }
                 }
                 FallbackToNever => {
@@ -401,7 +431,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
                         "fallback to ! - `rustc_never_type_mode = \"fallback_to_never\")`: {:?}",
                         diverging_vid
                     );
-                    diverging_fallback.insert(diverging_ty, self.tcx.types.never);
+                    fallback_to(self.tcx.types.never);
                 }
                 NoFallback => {
                     debug!(
@@ -417,7 +447,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
 
     /// Returns a graph whose nodes are (unresolved) inference variables and where
     /// an edge `?A -> ?B` indicates that the variable `?A` is coerced to `?B`.
-    fn create_coercion_graph(&self) -> VecGraph<ty::TyVid> {
+    fn create_coercion_graph(&self) -> VecGraph<ty::TyVid, true> {
         let pending_obligations = self.fulfillment_cx.borrow_mut().pending_obligations();
         debug!("create_coercion_graph: pending_obligations={:?}", pending_obligations);
         let coercion_edges: Vec<(ty::TyVid, ty::TyVid)> = pending_obligations
@@ -451,6 +481,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
             .collect();
         debug!("create_coercion_graph: coercion_edges={:?}", coercion_edges);
         let num_ty_vars = self.num_ty_vars();
+
         VecGraph::new(num_ty_vars, coercion_edges)
     }
 
@@ -459,3 +490,91 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
         Some(self.root_var(self.shallow_resolve(ty).ty_vid()?))
     }
 }
+
+/// Finds all type variables which are passed to an `unsafe` function.
+///
+/// For example, for this function `f`:
+/// ```ignore (demonstrative)
+/// fn f() {
+///     unsafe {
+///         let x /* ?X */ = core::mem::zeroed();
+///         //               ^^^^^^^^^^^^^^^^^^^ -- hir_id, span
+///
+///         let y = core::mem::zeroed::<Option<_ /* ?Y */>>();
+///         //               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -- hir_id, span
+///     }
+/// }
+/// ```
+///
+/// Will return `{ id(?X) -> (hir_id, span) }`
+fn compute_unsafe_infer_vars<'a, 'tcx>(
+    root_ctxt: &'a crate::TypeckRootCtxt<'tcx>,
+    body_id: rustc_span::def_id::LocalDefId,
+) -> UnordMap<ty::TyVid, (HirId, Span)> {
+    use rustc_hir as hir;
+
+    let tcx = root_ctxt.infcx.tcx;
+    let body_id = tcx.hir().maybe_body_owned_by(body_id).unwrap();
+    let body = tcx.hir().body(body_id);
+    let mut res = <_>::default();
+
+    struct UnsafeInferVarsVisitor<'a, 'tcx, 'r> {
+        root_ctxt: &'a crate::TypeckRootCtxt<'tcx>,
+        res: &'r mut UnordMap<ty::TyVid, (HirId, Span)>,
+    }
+
+    use hir::intravisit::Visitor;
+    impl hir::intravisit::Visitor<'_> for UnsafeInferVarsVisitor<'_, '_, '_> {
+        fn visit_expr(&mut self, ex: &'_ hir::Expr<'_>) {
+            // FIXME: method calls
+            if let hir::ExprKind::Call(func, ..) = ex.kind {
+                let typeck_results = self.root_ctxt.typeck_results.borrow();
+
+                let func_ty = typeck_results.expr_ty(func);
+
+                // `is_fn` is required to ignore closures (which can't be unsafe)
+                if func_ty.is_fn()
+                    && let sig = func_ty.fn_sig(self.root_ctxt.infcx.tcx)
+                    && let hir::Unsafety::Unsafe = sig.unsafety()
+                {
+                    let mut collector =
+                        InferVarCollector { hir_id: ex.hir_id, call_span: ex.span, res: self.res };
+
+                    // Collect generic arguments of the function which are inference variables
+                    typeck_results
+                        .node_args(ex.hir_id)
+                        .types()
+                        .for_each(|t| t.visit_with(&mut collector));
+
+                    // Also check the return type, for cases like `(unsafe_fn::<_> as unsafe fn() -> _)()`
+                    sig.output().visit_with(&mut collector);
+                }
+            }
+
+            hir::intravisit::walk_expr(self, ex);
+        }
+    }
+
+    struct InferVarCollector<'r> {
+        hir_id: HirId,
+        call_span: Span,
+        res: &'r mut UnordMap<ty::TyVid, (HirId, Span)>,
+    }
+
+    impl<'tcx> ty::TypeVisitor<TyCtxt<'tcx>> for InferVarCollector<'_> {
+        fn visit_ty(&mut self, t: Ty<'tcx>) {
+            if let Some(vid) = t.ty_vid() {
+                self.res.insert(vid, (self.hir_id, self.call_span));
+            } else {
+                use ty::TypeSuperVisitable as _;
+                t.super_visit_with(self)
+            }
+        }
+    }
+
+    UnsafeInferVarsVisitor { root_ctxt, res: &mut res }.visit_expr(&body.value);
+
+    debug!(?res, "collected the following unsafe vars for {body_id:?}");
+
+    res
+}
diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs
index 2f96cf9e373..794b854ca5f 100644
--- a/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs
+++ b/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs
@@ -5,12 +5,14 @@ mod checks;
 mod inspect_obligations;
 mod suggestions;
 
+use rustc_errors::ErrorGuaranteed;
+
 use crate::coercion::DynamicCoerceMany;
 use crate::fallback::DivergingFallbackBehavior;
 use crate::fn_ctxt::checks::DivergingBlockBehavior;
 use crate::{CoroutineTypes, Diverges, EnclosingBreakables, TypeckRootCtxt};
 use hir::def_id::CRATE_DEF_ID;
-use rustc_errors::{DiagCtxt, ErrorGuaranteed};
+use rustc_errors::DiagCtxt;
 use rustc_hir as hir;
 use rustc_hir::def_id::{DefId, LocalDefId};
 use rustc_hir_analysis::hir_ty_lowering::HirTyLowerer;
diff --git a/compiler/rustc_lint_defs/src/builtin.rs b/compiler/rustc_lint_defs/src/builtin.rs
index 86a0f33a8d1..664c63da0fc 100644
--- a/compiler/rustc_lint_defs/src/builtin.rs
+++ b/compiler/rustc_lint_defs/src/builtin.rs
@@ -69,6 +69,7 @@ declare_lint_pass! {
         MISSING_FRAGMENT_SPECIFIER,
         MUST_NOT_SUSPEND,
         NAMED_ARGUMENTS_USED_POSITIONALLY,
+        NEVER_TYPE_FALLBACK_FLOWING_INTO_UNSAFE,
         NON_CONTIGUOUS_RANGE_ENDPOINTS,
         NON_EXHAUSTIVE_OMITTED_PATTERNS,
         ORDER_DEPENDENT_TRAIT_OBJECTS,
@@ -4246,6 +4247,49 @@ declare_lint! {
 }
 
 declare_lint! {
+    /// The `never_type_fallback_flowing_into_unsafe` lint detects cases where never type fallback
+    /// affects unsafe function calls.
+    ///
+    /// ### Example
+    ///
+    /// ```rust,compile_fail
+    /// #![deny(never_type_fallback_flowing_into_unsafe)]
+    /// fn main() {
+    ///     if true {
+    ///         // return has type `!` (never) which, is some cases, causes never type fallback
+    ///         return
+    ///     } else {
+    ///         // `zeroed` is an unsafe function, which returns an unbounded type
+    ///         unsafe { std::mem::zeroed() }
+    ///     };
+    ///     // depending on the fallback, `zeroed` may create `()` (which is completely sound),
+    ///     // or `!` (which is instant undefined behavior)
+    /// }
+    /// ```
+    ///
+    /// {{produces}}
+    ///
+    /// ### Explanation
+    ///
+    /// Due to historic reasons never type fallback were `()`, meaning that `!` got spontaneously
+    /// coerced to `()`. There are plans to change that, but they may make the code such as above
+    /// unsound. Instead of depending on the fallback, you should specify the type explicitly:
+    /// ```
+    /// if true {
+    ///     return
+    /// } else {
+    ///     // type is explicitly specified, fallback can't hurt us no more
+    ///     unsafe { std::mem::zeroed::<()>() }
+    /// };
+    /// ```
+    ///
+    /// See [Tracking Issue for making `!` fall back to `!`](https://github.com/rust-lang/rust/issues/123748).
+    pub NEVER_TYPE_FALLBACK_FLOWING_INTO_UNSAFE,
+    Warn,
+    "never type fallback affecting unsafe function calls"
+}
+
+declare_lint! {
     /// The `byte_slice_in_packed_struct_with_derive` lint detects cases where a byte slice field
     /// (`[u8]`) or string slice field (`str`) is used in a `packed` struct that derives one or
     /// more built-in traits.
diff --git a/tests/ui/never_type/lint-never-type-fallback-flowing-into-unsafe.rs b/tests/ui/never_type/lint-never-type-fallback-flowing-into-unsafe.rs
new file mode 100644
index 00000000000..f13e20cc0f2
--- /dev/null
+++ b/tests/ui/never_type/lint-never-type-fallback-flowing-into-unsafe.rs
@@ -0,0 +1,35 @@
+//@ check-pass
+use std::mem;
+
+fn main() {
+    if false {
+        unsafe { mem::zeroed() }
+        //~^ warn: never type fallback affects this call to an `unsafe` function
+    } else {
+        return;
+    };
+
+    // no ; -> type is inferred without fallback
+    if true { unsafe { mem::zeroed() } } else { return }
+}
+
+// Minimization of the famous `objc` crate issue
+fn _objc() {
+    pub unsafe fn send_message<R>() -> Result<R, ()> {
+        Ok(unsafe { core::mem::zeroed() })
+    }
+
+    macro_rules! msg_send {
+        () => {
+            match send_message::<_ /* ?0 */>() {
+                //~^ warn: never type fallback affects this call to an `unsafe` function
+                Ok(x) => x,
+                Err(_) => loop {},
+            }
+        };
+    }
+
+    unsafe {
+        msg_send!();
+    }
+}
diff --git a/tests/ui/never_type/lint-never-type-fallback-flowing-into-unsafe.stderr b/tests/ui/never_type/lint-never-type-fallback-flowing-into-unsafe.stderr
new file mode 100644
index 00000000000..1610804c29b
--- /dev/null
+++ b/tests/ui/never_type/lint-never-type-fallback-flowing-into-unsafe.stderr
@@ -0,0 +1,23 @@
+warning: never type fallback affects this call to an `unsafe` function
+  --> $DIR/lint-never-type-fallback-flowing-into-unsafe.rs:6:18
+   |
+LL |         unsafe { mem::zeroed() }
+   |                  ^^^^^^^^^^^^^
+   |
+   = help: specify the type explicitly
+   = note: `#[warn(never_type_fallback_flowing_into_unsafe)]` on by default
+
+warning: never type fallback affects this call to an `unsafe` function
+  --> $DIR/lint-never-type-fallback-flowing-into-unsafe.rs:24:19
+   |
+LL |             match send_message::<_ /* ?0 */>() {
+   |                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+...
+LL |         msg_send!();
+   |         ----------- in this macro invocation
+   |
+   = help: specify the type explicitly
+   = note: this warning originates in the macro `msg_send` (in Nightly builds, run with -Z macro-backtrace for more info)
+
+warning: 2 warnings emitted
+