about summary refs log tree commit diff
diff options
context:
space:
mode:
authorXFFXFF <1247714429@qq.com>2023-03-11 10:35:55 +0800
committerXFFXFF <1247714429@qq.com>2023-03-15 09:07:11 +0800
commitbf0322cd0c841a31b1ce242333c2d457ab85ee5a (patch)
treebb282efb7def55a22b6203cabd79dd36d9fb3f6a
parent82780d8caf7bd034d4ef0e3a62789f61a047eb81 (diff)
downloadrust-bf0322cd0c841a31b1ce242333c2d457ab85ee5a.tar.gz
rust-bf0322cd0c841a31b1ce242333c2d457ab85ee5a.zip
pick the best ancestor expr of unsafe expr to add unsafe block. Thanks! @Veykril
-rw-r--r--crates/ide-diagnostics/src/handlers/missing_unsafe.rs266
1 files changed, 226 insertions, 40 deletions
diff --git a/crates/ide-diagnostics/src/handlers/missing_unsafe.rs b/crates/ide-diagnostics/src/handlers/missing_unsafe.rs
index 60086ed4a4e..c709c8c482b 100644
--- a/crates/ide-diagnostics/src/handlers/missing_unsafe.rs
+++ b/crates/ide-diagnostics/src/handlers/missing_unsafe.rs
@@ -1,6 +1,5 @@
 use hir::db::AstDatabase;
 use ide_db::{assists::Assist, source_change::SourceChange};
-use syntax::ast::{ExprStmt, LetStmt};
 use syntax::AstNode;
 use syntax::{ast, SyntaxNode};
 use text_edit::TextEdit;
@@ -23,7 +22,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option<Vec<Ass
     let root = ctx.sema.db.parse_or_expand(d.expr.file_id)?;
     let expr = d.expr.value.to_node(&root);
 
-    let node_to_add_unsafe_block = pick_best_node_to_add_unsafe_block(ctx, &expr);
+    let node_to_add_unsafe_block = pick_best_node_to_add_unsafe_block(&expr);
 
     let replacement = format!("unsafe {{ {} }}", node_to_add_unsafe_block.text());
     let edit = TextEdit::replace(node_to_add_unsafe_block.text_range(), replacement);
@@ -32,39 +31,78 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option<Vec<Ass
     Some(vec![fix("add_unsafe", "Add unsafe block", source_change, expr.syntax().text_range())])
 }
 
-// Find the let statement or expression statement closest to the `expr` in the
-// ancestor chain.
-//
-// Why don't we just add an unsafe block around the `expr`?
-//
-// Consider this example:
-// ```
-// STATIC_MUT += 1;
-// ```
-// We can't add an unsafe block to the left-hand side of an assignment.
-// ```
-// unsafe { STATIC_MUT } += 1;
-// ```
-//
-// Or this example:
-// ```
-// let z = STATIC_MUT.a;
-// ```
-// We can't add an unsafe block like this:
-// ```
-// let z = unsafe { STATIC_MUT } .a;
-// ```
-fn pick_best_node_to_add_unsafe_block(
-    ctx: &DiagnosticsContext<'_>,
-    expr: &ast::Expr,
-) -> SyntaxNode {
-    let Some(let_or_expr_stmt) = ctx.sema.ancestors_with_macros(expr.syntax().clone()).find(|node| {
-        LetStmt::can_cast(node.kind()) || ExprStmt::can_cast(node.kind())
-    }) else {
-        // Is this reachable?
-        return expr.syntax().clone();
-    };
-    let_or_expr_stmt
+// Pick the first ancestor expression of the unsafe `expr` that is not a
+// receiver of a method call, a field access, the left-hand side of an
+// assignment, or a reference. As all of those cases would incur a forced move
+// if wrapped which might not be wanted. That is:
+// - `unsafe_expr.foo` -> `unsafe { unsafe_expr.foo }`
+// - `unsafe_expr.foo.bar` -> `unsafe { unsafe_expr.foo.bar }`
+// - `unsafe_expr.foo()` -> `unsafe { unsafe_expr.foo() }`
+// - `unsafe_expr.foo.bar()` -> `unsafe { unsafe_expr.foo.bar() }`
+// - `unsafe_expr += 1` -> `unsafe { unsafe_expr += 1 }`
+// - `&unsafe_expr` -> `unsafe { &unsafe_expr }`
+// - `&&unsafe_expr` -> `unsafe { &&unsafe_expr }`
+fn pick_best_node_to_add_unsafe_block(unsafe_expr: &ast::Expr) -> SyntaxNode {
+    // The `unsafe_expr` might be:
+    // - `ast::CallExpr`: call an unsafe function
+    // - `ast::MethodCallExpr`: call an unsafe method
+    // - `ast::PrefixExpr`: dereference a raw pointer
+    // - `ast::PathExpr`: access a static mut variable
+    for node in unsafe_expr.syntax().ancestors() {
+        let Some(parent) = node.parent() else {
+            return node;
+        };
+        match parent.kind() {
+            syntax::SyntaxKind::METHOD_CALL_EXPR => {
+                // Check if the `node` is the receiver of the method call
+                let method_call_expr = ast::MethodCallExpr::cast(parent.clone()).unwrap();
+                if method_call_expr
+                    .receiver()
+                    .map(|receiver| {
+                        receiver.syntax().text_range().contains_range(node.text_range())
+                    })
+                    .unwrap_or(false)
+                {
+                    // Actually, I think it's not necessary to check whether the
+                    // text range of the `node` (which is the ancestor of the
+                    // `unsafe_expr`) is contained in the text range of the
+                    // receiver. The `node` could potentially be the receiver, the
+                    // method name, or the argument list. Since the `node` is the
+                    // ancestor of the unsafe_expr, it cannot be the method name.
+                    // Additionally, if the `node` is the argument list, the loop
+                    // would break at least when `parent` reaches the argument list.
+                    //
+                    // Dispite this, I still check the text range because I think it
+                    // makes the code easier to understand.
+                    continue;
+                }
+                return node;
+            }
+            syntax::SyntaxKind::FIELD_EXPR | syntax::SyntaxKind::REF_EXPR => continue,
+            syntax::SyntaxKind::BIN_EXPR => {
+                // Check if the `node` is the left-hand side of an assignment
+                let is_left_hand_side_of_assignment = {
+                    let bin_expr = ast::BinExpr::cast(parent.clone()).unwrap();
+                    if let Some(ast::BinaryOp::Assignment { .. }) = bin_expr.op_kind() {
+                        let is_left_hand_side = bin_expr
+                            .lhs()
+                            .map(|lhs| lhs.syntax().text_range().contains_range(node.text_range()))
+                            .unwrap_or(false);
+                        is_left_hand_side
+                    } else {
+                        false
+                    }
+                };
+                if !is_left_hand_side_of_assignment {
+                    return node;
+                }
+            }
+            _ => {
+                return node;
+            }
+        }
+    }
+    unsafe_expr.syntax().clone()
 }
 
 #[cfg(test)]
@@ -168,7 +206,7 @@ fn main() {
             r#"
 fn main() {
     let x = &5 as *const usize;
-    unsafe { let z = *x; }
+    let z = unsafe { *x };
 }
 "#,
         );
@@ -192,7 +230,7 @@ unsafe fn func() {
     let z = *x;
 }
 fn main() {
-    unsafe { func(); }
+    unsafe { func() };
 }
 "#,
         )
@@ -224,7 +262,7 @@ impl S {
 }
 fn main() {
     let s = S(5);
-    unsafe { s.func(); }
+    unsafe { s.func() };
 }
 "#,
         )
@@ -252,7 +290,7 @@ struct Ty {
 static mut STATIC_MUT: Ty = Ty { a: 0 };
 
 fn main() {
-    unsafe { let x = STATIC_MUT.a; }
+    let x = unsafe { STATIC_MUT.a };
 }
 "#,
         )
@@ -276,7 +314,155 @@ extern "rust-intrinsic" {
 }
 
 fn main() {
-    unsafe { let _ = floorf32(12.0); }
+    let _ = unsafe { floorf32(12.0) };
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn unsafe_expr_as_a_receiver_of_a_method_call() {
+        check_fix(
+            r#"
+unsafe fn foo() -> String {
+    "string".to_string()
+}
+
+fn main() {
+    foo$0().len();
+}
+"#,
+            r#"
+unsafe fn foo() -> String {
+    "string".to_string()
+}
+
+fn main() {
+    unsafe { foo().len() };
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn unsafe_expr_as_an_argument_of_a_method_call() {
+        check_fix(
+            r#"
+static mut STATIC_MUT: u8 = 0;
+
+fn main() {
+    let mut v = vec![];
+    v.push(STATIC_MUT$0);
+}
+"#,
+            r#"
+static mut STATIC_MUT: u8 = 0;
+
+fn main() {
+    let mut v = vec![];
+    v.push(unsafe { STATIC_MUT });
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn unsafe_expr_as_left_hand_side_of_assignment() {
+        check_fix(
+            r#"
+static mut STATIC_MUT: u8 = 0;
+
+fn main() {
+    STATIC_MUT$0 = 1;
+}
+"#,
+            r#"
+static mut STATIC_MUT: u8 = 0;
+
+fn main() {
+    unsafe { STATIC_MUT = 1 };
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn unsafe_expr_as_right_hand_side_of_assignment() {
+        check_fix(
+            r#"
+static mut STATIC_MUT: u8 = 0;
+
+fn main() {
+    let x;
+    x = STATIC_MUT$0;
+}
+"#,
+            r#"
+static mut STATIC_MUT: u8 = 0;
+
+fn main() {
+    let x;
+    x = unsafe { STATIC_MUT };
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn unsafe_expr_in_binary_plus() {
+        check_fix(
+            r#"
+static mut STATIC_MUT: u8 = 0;
+
+fn main() {
+    let x = STATIC_MUT$0 + 1;
+}
+"#,
+            r#"
+static mut STATIC_MUT: u8 = 0;
+
+fn main() {
+    let x = unsafe { STATIC_MUT } + 1;
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn ref_to_unsafe_expr() {
+        check_fix(
+            r#"
+static mut STATIC_MUT: u8 = 0;
+
+fn main() {
+    let x = &STATIC_MUT$0;
+}
+"#,
+            r#"
+static mut STATIC_MUT: u8 = 0;
+
+fn main() {
+    let x = unsafe { &STATIC_MUT };
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn ref_ref_to_unsafe_expr() {
+        check_fix(
+            r#"
+static mut STATIC_MUT: u8 = 0;
+
+fn main() {
+    let x = &&STATIC_MUT$0;
+}
+"#,
+            r#"
+static mut STATIC_MUT: u8 = 0;
+
+fn main() {
+    let x = unsafe { &&STATIC_MUT };
 }
 "#,
         )