about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--clippy_lints/src/ptr.rs56
-rw-r--r--clippy_utils/src/visitors.rs34
-rw-r--r--tests/ui/mut_from_ref.rs20
-rw-r--r--tests/ui/mut_from_ref.stderr14
4 files changed, 89 insertions, 35 deletions
diff --git a/clippy_lints/src/ptr.rs b/clippy_lints/src/ptr.rs
index ba1997e70e1..4ecde8f4958 100644
--- a/clippy_lints/src/ptr.rs
+++ b/clippy_lints/src/ptr.rs
@@ -3,6 +3,7 @@
 use clippy_utils::diagnostics::{span_lint, span_lint_and_sugg, span_lint_and_then};
 use clippy_utils::source::snippet_opt;
 use clippy_utils::ty::expr_sig;
+use clippy_utils::visitors::contains_unsafe_block;
 use clippy_utils::{get_expr_use_or_unification_node, is_lint_allowed, path_def_id, path_to_local, paths};
 use if_chain::if_chain;
 use rustc_errors::Applicability;
@@ -10,9 +11,9 @@ use rustc_hir::def_id::DefId;
 use rustc_hir::hir_id::HirIdMap;
 use rustc_hir::intravisit::{walk_expr, Visitor};
 use rustc_hir::{
-    self as hir, AnonConst, BinOpKind, BindingAnnotation, Body, Expr, ExprKind, FnDecl, FnRetTy, GenericArg,
+    self as hir, AnonConst, BinOpKind, BindingAnnotation, Body, Expr, ExprKind, FnRetTy, FnSig, GenericArg,
     ImplItemKind, ItemKind, Lifetime, LifetimeName, Mutability, Node, Param, ParamName, PatKind, QPath, TraitFn,
-    TraitItem, TraitItemKind, TyKind,
+    TraitItem, TraitItemKind, TyKind, Unsafety,
 };
 use rustc_lint::{LateContext, LateLintPass};
 use rustc_middle::hir::nested_filter;
@@ -145,7 +146,7 @@ impl<'tcx> LateLintPass<'tcx> for Ptr {
                 return;
             }
 
-            check_mut_from_ref(cx, sig.decl);
+            check_mut_from_ref(cx, sig, None);
             for arg in check_fn_args(
                 cx,
                 cx.tcx.fn_sig(item.def_id).skip_binder().inputs(),
@@ -170,10 +171,10 @@ impl<'tcx> LateLintPass<'tcx> for Ptr {
     fn check_body(&mut self, cx: &LateContext<'tcx>, body: &'tcx Body<'_>) {
         let hir = cx.tcx.hir();
         let mut parents = hir.parent_iter(body.value.hir_id);
-        let (item_id, decl, is_trait_item) = match parents.next() {
+        let (item_id, sig, is_trait_item) = match parents.next() {
             Some((_, Node::Item(i))) => {
                 if let ItemKind::Fn(sig, ..) = &i.kind {
-                    (i.def_id, sig.decl, false)
+                    (i.def_id, sig, false)
                 } else {
                     return;
                 }
@@ -185,14 +186,14 @@ impl<'tcx> LateLintPass<'tcx> for Ptr {
                     return;
                 }
                 if let ImplItemKind::Fn(sig, _) = &i.kind {
-                    (i.def_id, sig.decl, false)
+                    (i.def_id, sig, false)
                 } else {
                     return;
                 }
             },
             Some((_, Node::TraitItem(i))) => {
                 if let TraitItemKind::Fn(sig, _) = &i.kind {
-                    (i.def_id, sig.decl, true)
+                    (i.def_id, sig, true)
                 } else {
                     return;
                 }
@@ -200,7 +201,8 @@ impl<'tcx> LateLintPass<'tcx> for Ptr {
             _ => return,
         };
 
-        check_mut_from_ref(cx, decl);
+        check_mut_from_ref(cx, sig, Some(body));
+        let decl = sig.decl;
         let sig = cx.tcx.fn_sig(item_id).skip_binder();
         let lint_args: Vec<_> = check_fn_args(cx, sig.inputs(), decl.inputs, body.params)
             .filter(|arg| !is_trait_item || arg.mutability() == Mutability::Not)
@@ -478,31 +480,31 @@ fn check_fn_args<'cx, 'tcx: 'cx>(
         })
 }
 
-fn check_mut_from_ref(cx: &LateContext<'_>, decl: &FnDecl<'_>) {
-    if let FnRetTy::Return(ty) = decl.output {
-        if let Some((out, Mutability::Mut, _)) = get_rptr_lm(ty) {
-            let mut immutables = vec![];
-            for (_, mutbl, argspan) in decl
-                .inputs
-                .iter()
-                .filter_map(get_rptr_lm)
-                .filter(|&(lt, _, _)| lt.name == out.name)
-            {
-                if mutbl == Mutability::Mut {
-                    return;
-                }
-                immutables.push(argspan);
-            }
-            if immutables.is_empty() {
-                return;
-            }
+fn check_mut_from_ref<'tcx>(cx: &LateContext<'tcx>, sig: &FnSig<'_>, body: Option<&'tcx Body<'_>>) {
+    if let FnRetTy::Return(ty) = sig.decl.output
+        && let Some((out, Mutability::Mut, _)) = get_rptr_lm(ty)
+    {
+        let args: Option<Vec<_>> = sig
+            .decl
+            .inputs
+            .iter()
+            .filter_map(get_rptr_lm)
+            .filter(|&(lt, _, _)| lt.name == out.name)
+            .map(|(_, mutability, span)| (mutability == Mutability::Not).then(|| span))
+            .collect();
+        if let Some(args) = args
+            && !args.is_empty()
+            && body.map_or(true, |body| {
+                sig.header.unsafety == Unsafety::Unsafe || contains_unsafe_block(cx, &body.value)
+            })
+        {
             span_lint_and_then(
                 cx,
                 MUT_FROM_REF,
                 ty.span,
                 "mutable borrow from immutable input(s)",
                 |diag| {
-                    let ms = MultiSpan::from_spans(immutables);
+                    let ms = MultiSpan::from_spans(args);
                     diag.span_note(ms, "immutable borrow here");
                 },
             );
diff --git a/clippy_utils/src/visitors.rs b/clippy_utils/src/visitors.rs
index 40451b17a9c..3db64b25353 100644
--- a/clippy_utils/src/visitors.rs
+++ b/clippy_utils/src/visitors.rs
@@ -3,7 +3,8 @@ use rustc_hir as hir;
 use rustc_hir::def::{DefKind, Res};
 use rustc_hir::intravisit::{self, walk_block, walk_expr, Visitor};
 use rustc_hir::{
-    Arm, Block, BlockCheckMode, Body, BodyId, Expr, ExprKind, HirId, ItemId, ItemKind, Stmt, UnOp, Unsafety,
+    Arm, Block, BlockCheckMode, Body, BodyId, Expr, ExprKind, HirId, ItemId, ItemKind, Stmt, UnOp, UnsafeSource,
+    Unsafety,
 };
 use rustc_lint::LateContext;
 use rustc_middle::hir::map::Map;
@@ -370,3 +371,34 @@ pub fn is_expr_unsafe<'tcx>(cx: &LateContext<'tcx>, e: &'tcx Expr<'_>) -> bool {
     v.visit_expr(e);
     v.is_unsafe
 }
+
+/// Checks if the given expression contains an unsafe block
+pub fn contains_unsafe_block<'tcx>(cx: &LateContext<'tcx>, e: &'tcx Expr<'tcx>) -> bool {
+    struct V<'cx, 'tcx> {
+        cx: &'cx LateContext<'tcx>,
+        found_unsafe: bool,
+    }
+    impl<'tcx> Visitor<'tcx> for V<'_, 'tcx> {
+        type NestedFilter = nested_filter::OnlyBodies;
+        fn nested_visit_map(&mut self) -> Self::Map {
+            self.cx.tcx.hir()
+        }
+
+        fn visit_block(&mut self, b: &'tcx Block<'_>) {
+            if self.found_unsafe {
+                return;
+            }
+            if b.rules == BlockCheckMode::UnsafeBlock(UnsafeSource::UserProvided) {
+                self.found_unsafe = true;
+                return;
+            }
+            walk_block(self, b);
+        }
+    }
+    let mut v = V {
+        cx,
+        found_unsafe: false,
+    };
+    v.visit_expr(e);
+    v.found_unsafe
+}
diff --git a/tests/ui/mut_from_ref.rs b/tests/ui/mut_from_ref.rs
index a9a04c8f56b..370dbd58821 100644
--- a/tests/ui/mut_from_ref.rs
+++ b/tests/ui/mut_from_ref.rs
@@ -5,7 +5,7 @@ struct Foo;
 
 impl Foo {
     fn this_wont_hurt_a_bit(&self) -> &mut Foo {
-        unimplemented!()
+        unsafe { unimplemented!() }
     }
 }
 
@@ -15,29 +15,37 @@ trait Ouch {
 
 impl Ouch for Foo {
     fn ouch(x: &Foo) -> &mut Foo {
-        unimplemented!()
+        unsafe { unimplemented!() }
     }
 }
 
 fn fail(x: &u32) -> &mut u16 {
-    unimplemented!()
+    unsafe { unimplemented!() }
 }
 
 fn fail_lifetime<'a>(x: &'a u32, y: &mut u32) -> &'a mut u32 {
-    unimplemented!()
+    unsafe { unimplemented!() }
 }
 
 fn fail_double<'a, 'b>(x: &'a u32, y: &'a u32, z: &'b mut u32) -> &'a mut u32 {
-    unimplemented!()
+    unsafe { unimplemented!() }
 }
 
 // this is OK, because the result borrows y
 fn works<'a>(x: &u32, y: &'a mut u32) -> &'a mut u32 {
-    unimplemented!()
+    unsafe { unimplemented!() }
 }
 
 // this is also OK, because the result could borrow y
 fn also_works<'a>(x: &'a u32, y: &'a mut u32) -> &'a mut u32 {
+    unsafe { unimplemented!() }
+}
+
+unsafe fn also_broken(x: &u32) -> &mut u32 {
+    unimplemented!()
+}
+
+fn without_unsafe(x: &u32) -> &mut u32 {
     unimplemented!()
 }
 
diff --git a/tests/ui/mut_from_ref.stderr b/tests/ui/mut_from_ref.stderr
index 4787999920b..b76d6a13ffb 100644
--- a/tests/ui/mut_from_ref.stderr
+++ b/tests/ui/mut_from_ref.stderr
@@ -59,5 +59,17 @@ note: immutable borrow here
 LL | fn fail_double<'a, 'b>(x: &'a u32, y: &'a u32, z: &'b mut u32) -> &'a mut u32 {
    |                           ^^^^^^^     ^^^^^^^
 
-error: aborting due to 5 previous errors
+error: mutable borrow from immutable input(s)
+  --> $DIR/mut_from_ref.rs:44:35
+   |
+LL | unsafe fn also_broken(x: &u32) -> &mut u32 {
+   |                                   ^^^^^^^^
+   |
+note: immutable borrow here
+  --> $DIR/mut_from_ref.rs:44:26
+   |
+LL | unsafe fn also_broken(x: &u32) -> &mut u32 {
+   |                          ^^^^
+
+error: aborting due to 6 previous errors