diff options
| author | XFFXFF <1247714429@qq.com> | 2023-03-11 10:35:55 +0800 |
|---|---|---|
| committer | XFFXFF <1247714429@qq.com> | 2023-03-15 09:07:11 +0800 |
| commit | bf0322cd0c841a31b1ce242333c2d457ab85ee5a (patch) | |
| tree | bb282efb7def55a22b6203cabd79dd36d9fb3f6a | |
| parent | 82780d8caf7bd034d4ef0e3a62789f61a047eb81 (diff) | |
| download | rust-bf0322cd0c841a31b1ce242333c2d457ab85ee5a.tar.gz rust-bf0322cd0c841a31b1ce242333c2d457ab85ee5a.zip | |
pick the best ancestor expr of unsafe expr to add unsafe block. Thanks! @Veykril
| -rw-r--r-- | crates/ide-diagnostics/src/handlers/missing_unsafe.rs | 266 |
1 files changed, 226 insertions, 40 deletions
diff --git a/crates/ide-diagnostics/src/handlers/missing_unsafe.rs b/crates/ide-diagnostics/src/handlers/missing_unsafe.rs index 60086ed4a4e..c709c8c482b 100644 --- a/crates/ide-diagnostics/src/handlers/missing_unsafe.rs +++ b/crates/ide-diagnostics/src/handlers/missing_unsafe.rs @@ -1,6 +1,5 @@ use hir::db::AstDatabase; use ide_db::{assists::Assist, source_change::SourceChange}; -use syntax::ast::{ExprStmt, LetStmt}; use syntax::AstNode; use syntax::{ast, SyntaxNode}; use text_edit::TextEdit; @@ -23,7 +22,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option<Vec<Ass let root = ctx.sema.db.parse_or_expand(d.expr.file_id)?; let expr = d.expr.value.to_node(&root); - let node_to_add_unsafe_block = pick_best_node_to_add_unsafe_block(ctx, &expr); + let node_to_add_unsafe_block = pick_best_node_to_add_unsafe_block(&expr); let replacement = format!("unsafe {{ {} }}", node_to_add_unsafe_block.text()); let edit = TextEdit::replace(node_to_add_unsafe_block.text_range(), replacement); @@ -32,39 +31,78 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option<Vec<Ass Some(vec![fix("add_unsafe", "Add unsafe block", source_change, expr.syntax().text_range())]) } -// Find the let statement or expression statement closest to the `expr` in the -// ancestor chain. -// -// Why don't we just add an unsafe block around the `expr`? -// -// Consider this example: -// ``` -// STATIC_MUT += 1; -// ``` -// We can't add an unsafe block to the left-hand side of an assignment. -// ``` -// unsafe { STATIC_MUT } += 1; -// ``` -// -// Or this example: -// ``` -// let z = STATIC_MUT.a; -// ``` -// We can't add an unsafe block like this: -// ``` -// let z = unsafe { STATIC_MUT } .a; -// ``` -fn pick_best_node_to_add_unsafe_block( - ctx: &DiagnosticsContext<'_>, - expr: &ast::Expr, -) -> SyntaxNode { - let Some(let_or_expr_stmt) = ctx.sema.ancestors_with_macros(expr.syntax().clone()).find(|node| { - LetStmt::can_cast(node.kind()) || ExprStmt::can_cast(node.kind()) - }) else { - // Is this reachable? - return expr.syntax().clone(); - }; - let_or_expr_stmt +// Pick the first ancestor expression of the unsafe `expr` that is not a +// receiver of a method call, a field access, the left-hand side of an +// assignment, or a reference. As all of those cases would incur a forced move +// if wrapped which might not be wanted. That is: +// - `unsafe_expr.foo` -> `unsafe { unsafe_expr.foo }` +// - `unsafe_expr.foo.bar` -> `unsafe { unsafe_expr.foo.bar }` +// - `unsafe_expr.foo()` -> `unsafe { unsafe_expr.foo() }` +// - `unsafe_expr.foo.bar()` -> `unsafe { unsafe_expr.foo.bar() }` +// - `unsafe_expr += 1` -> `unsafe { unsafe_expr += 1 }` +// - `&unsafe_expr` -> `unsafe { &unsafe_expr }` +// - `&&unsafe_expr` -> `unsafe { &&unsafe_expr }` +fn pick_best_node_to_add_unsafe_block(unsafe_expr: &ast::Expr) -> SyntaxNode { + // The `unsafe_expr` might be: + // - `ast::CallExpr`: call an unsafe function + // - `ast::MethodCallExpr`: call an unsafe method + // - `ast::PrefixExpr`: dereference a raw pointer + // - `ast::PathExpr`: access a static mut variable + for node in unsafe_expr.syntax().ancestors() { + let Some(parent) = node.parent() else { + return node; + }; + match parent.kind() { + syntax::SyntaxKind::METHOD_CALL_EXPR => { + // Check if the `node` is the receiver of the method call + let method_call_expr = ast::MethodCallExpr::cast(parent.clone()).unwrap(); + if method_call_expr + .receiver() + .map(|receiver| { + receiver.syntax().text_range().contains_range(node.text_range()) + }) + .unwrap_or(false) + { + // Actually, I think it's not necessary to check whether the + // text range of the `node` (which is the ancestor of the + // `unsafe_expr`) is contained in the text range of the + // receiver. The `node` could potentially be the receiver, the + // method name, or the argument list. Since the `node` is the + // ancestor of the unsafe_expr, it cannot be the method name. + // Additionally, if the `node` is the argument list, the loop + // would break at least when `parent` reaches the argument list. + // + // Dispite this, I still check the text range because I think it + // makes the code easier to understand. + continue; + } + return node; + } + syntax::SyntaxKind::FIELD_EXPR | syntax::SyntaxKind::REF_EXPR => continue, + syntax::SyntaxKind::BIN_EXPR => { + // Check if the `node` is the left-hand side of an assignment + let is_left_hand_side_of_assignment = { + let bin_expr = ast::BinExpr::cast(parent.clone()).unwrap(); + if let Some(ast::BinaryOp::Assignment { .. }) = bin_expr.op_kind() { + let is_left_hand_side = bin_expr + .lhs() + .map(|lhs| lhs.syntax().text_range().contains_range(node.text_range())) + .unwrap_or(false); + is_left_hand_side + } else { + false + } + }; + if !is_left_hand_side_of_assignment { + return node; + } + } + _ => { + return node; + } + } + } + unsafe_expr.syntax().clone() } #[cfg(test)] @@ -168,7 +206,7 @@ fn main() { r#" fn main() { let x = &5 as *const usize; - unsafe { let z = *x; } + let z = unsafe { *x }; } "#, ); @@ -192,7 +230,7 @@ unsafe fn func() { let z = *x; } fn main() { - unsafe { func(); } + unsafe { func() }; } "#, ) @@ -224,7 +262,7 @@ impl S { } fn main() { let s = S(5); - unsafe { s.func(); } + unsafe { s.func() }; } "#, ) @@ -252,7 +290,7 @@ struct Ty { static mut STATIC_MUT: Ty = Ty { a: 0 }; fn main() { - unsafe { let x = STATIC_MUT.a; } + let x = unsafe { STATIC_MUT.a }; } "#, ) @@ -276,7 +314,155 @@ extern "rust-intrinsic" { } fn main() { - unsafe { let _ = floorf32(12.0); } + let _ = unsafe { floorf32(12.0) }; +} +"#, + ) + } + + #[test] + fn unsafe_expr_as_a_receiver_of_a_method_call() { + check_fix( + r#" +unsafe fn foo() -> String { + "string".to_string() +} + +fn main() { + foo$0().len(); +} +"#, + r#" +unsafe fn foo() -> String { + "string".to_string() +} + +fn main() { + unsafe { foo().len() }; +} +"#, + ) + } + + #[test] + fn unsafe_expr_as_an_argument_of_a_method_call() { + check_fix( + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let mut v = vec![]; + v.push(STATIC_MUT$0); +} +"#, + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let mut v = vec![]; + v.push(unsafe { STATIC_MUT }); +} +"#, + ) + } + + #[test] + fn unsafe_expr_as_left_hand_side_of_assignment() { + check_fix( + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + STATIC_MUT$0 = 1; +} +"#, + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + unsafe { STATIC_MUT = 1 }; +} +"#, + ) + } + + #[test] + fn unsafe_expr_as_right_hand_side_of_assignment() { + check_fix( + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let x; + x = STATIC_MUT$0; +} +"#, + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let x; + x = unsafe { STATIC_MUT }; +} +"#, + ) + } + + #[test] + fn unsafe_expr_in_binary_plus() { + check_fix( + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let x = STATIC_MUT$0 + 1; +} +"#, + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let x = unsafe { STATIC_MUT } + 1; +} +"#, + ) + } + + #[test] + fn ref_to_unsafe_expr() { + check_fix( + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let x = &STATIC_MUT$0; +} +"#, + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let x = unsafe { &STATIC_MUT }; +} +"#, + ) + } + + #[test] + fn ref_ref_to_unsafe_expr() { + check_fix( + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let x = &&STATIC_MUT$0; +} +"#, + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let x = unsafe { &&STATIC_MUT }; } "#, ) |
