about summary refs log tree commit diff
diff options
context:
space:
mode:
authorLukas Wirth <lukastw97@gmail.com>2025-03-09 10:58:02 +0000
committerGitHub <noreply@github.com>2025-03-09 10:58:02 +0000
commit60da021da3574203f1e4ad37d8f01a0439f9b9e3 (patch)
tree2911f4dbd4a78e40510e47ac92dd521f74a95429
parentd11c5b8d75244a52f3578244aa5503ffa5893989 (diff)
parent965a0c016677dad9ee8141e7a446ce376b7177c1 (diff)
downloadrust-60da021da3574203f1e4ad37d8f01a0439f9b9e3.tar.gz
rust-60da021da3574203f1e4ad37d8f01a0439f9b9e3.zip
Merge pull request #19324 from ShoyuVanilla/migrate-inline-var
fix: Prevent wrong invocations of `needs_parens_in` with non-ancestral "parent"s
-rw-r--r--src/tools/rust-analyzer/crates/ide-assists/src/handlers/apply_demorgan.rs14
-rw-r--r--src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_local_variable.rs118
-rw-r--r--src/tools/rust-analyzer/crates/syntax/src/ast/prec.rs63
-rw-r--r--src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs21
-rw-r--r--src/tools/rust-analyzer/docs/book/src/assists_generated.md6
5 files changed, 174 insertions, 48 deletions
diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/apply_demorgan.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/apply_demorgan.rs
index 77562c588e2..67bf8eed23d 100644
--- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/apply_demorgan.rs
+++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/apply_demorgan.rs
@@ -128,7 +128,9 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti
                     let parent = neg_expr.syntax().parent();
                     editor = builder.make_editor(neg_expr.syntax());
 
-                    if parent.is_some_and(|parent| demorganed.needs_parens_in(&parent)) {
+                    if parent.is_some_and(|parent| {
+                        demorganed.needs_parens_in_place_of(&parent, neg_expr.syntax())
+                    }) {
                         cov_mark::hit!(demorgan_keep_parens_for_op_precedence2);
                         editor.replace(neg_expr.syntax(), make.expr_paren(demorganed).syntax());
                     } else {
@@ -392,15 +394,19 @@ fn f() { !(S <= S || S < S) }
 
     #[test]
     fn demorgan_keep_pars_for_op_precedence3() {
-        check_assist(apply_demorgan, "fn f() { (a || !(b &&$0 c); }", "fn f() { (a || !b || !c; }");
+        check_assist(
+            apply_demorgan,
+            "fn f() { (a || !(b &&$0 c); }",
+            "fn f() { (a || (!b || !c); }",
+        );
     }
 
     #[test]
-    fn demorgan_removes_pars_in_eq_precedence() {
+    fn demorgan_keeps_pars_in_eq_precedence() {
         check_assist(
             apply_demorgan,
             "fn() { let x = a && !(!b |$0| !c); }",
-            "fn() { let x = a && b && c; }",
+            "fn() { let x = a && (b && c); }",
         )
     }
 
diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_local_variable.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_local_variable.rs
index cc7bea5152b..36eed290dc8 100644
--- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_local_variable.rs
+++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_local_variable.rs
@@ -5,7 +5,7 @@ use ide_db::{
     EditionedFileId, RootDatabase,
 };
 use syntax::{
-    ast::{self, AstNode, AstToken, HasName},
+    ast::{self, syntax_factory::SyntaxFactory, AstNode, AstToken, HasName},
     SyntaxElement, TextRange,
 };
 
@@ -43,22 +43,6 @@ pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext<'_>)
         }?;
     let initializer_expr = let_stmt.initializer()?;
 
-    let delete_range = delete_let.then(|| {
-        if let Some(whitespace) = let_stmt
-            .syntax()
-            .next_sibling_or_token()
-            .and_then(SyntaxElement::into_token)
-            .and_then(ast::Whitespace::cast)
-        {
-            TextRange::new(
-                let_stmt.syntax().text_range().start(),
-                whitespace.syntax().text_range().end(),
-            )
-        } else {
-            let_stmt.syntax().text_range()
-        }
-    });
-
     let wrap_in_parens = references
         .into_iter()
         .filter_map(|FileReference { range, name, .. }| match name {
@@ -73,40 +57,60 @@ pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext<'_>)
             }
             let usage_node =
                 name_ref.syntax().ancestors().find(|it| ast::PathExpr::can_cast(it.kind()));
-            let usage_parent_option = usage_node.and_then(|it| it.parent());
+            let usage_parent_option = usage_node.as_ref().and_then(|it| it.parent());
             let usage_parent = match usage_parent_option {
                 Some(u) => u,
-                None => return Some((range, name_ref, false)),
+                None => return Some((name_ref, false)),
             };
-            Some((range, name_ref, initializer_expr.needs_parens_in(&usage_parent)))
+            let should_wrap = initializer_expr
+                .needs_parens_in_place_of(&usage_parent, usage_node.as_ref().unwrap());
+            Some((name_ref, should_wrap))
         })
         .collect::<Option<Vec<_>>>()?;
 
-    let init_str = initializer_expr.syntax().text().to_string();
-    let init_in_paren = format!("({init_str})");
-
     let target = match target {
-        ast::NameOrNameRef::Name(it) => it.syntax().text_range(),
-        ast::NameOrNameRef::NameRef(it) => it.syntax().text_range(),
+        ast::NameOrNameRef::Name(it) => it.syntax().clone(),
+        ast::NameOrNameRef::NameRef(it) => it.syntax().clone(),
     };
 
     acc.add(
         AssistId("inline_local_variable", AssistKind::RefactorInline),
         "Inline variable",
-        target,
+        target.text_range(),
         move |builder| {
-            if let Some(range) = delete_range {
-                builder.delete(range);
+            let mut editor = builder.make_editor(&target);
+            if delete_let {
+                editor.delete(let_stmt.syntax());
+                if let Some(whitespace) = let_stmt
+                    .syntax()
+                    .next_sibling_or_token()
+                    .and_then(SyntaxElement::into_token)
+                    .and_then(ast::Whitespace::cast)
+                {
+                    editor.delete(whitespace.syntax());
+                }
             }
-            for (range, name, should_wrap) in wrap_in_parens {
-                let replacement = if should_wrap { &init_in_paren } else { &init_str };
-                if ast::RecordExprField::for_field_name(&name).is_some() {
+
+            let make = SyntaxFactory::new();
+
+            for (name, should_wrap) in wrap_in_parens {
+                let replacement = if should_wrap {
+                    make.expr_paren(initializer_expr.clone()).into()
+                } else {
+                    initializer_expr.clone()
+                };
+
+                if let Some(record_field) = ast::RecordExprField::for_field_name(&name) {
                     cov_mark::hit!(inline_field_shorthand);
-                    builder.insert(range.end(), format!(": {replacement}"));
+                    let replacement = make.record_expr_field(name, Some(replacement));
+                    editor.replace(record_field.syntax(), replacement.syntax());
                 } else {
-                    builder.replace(range, replacement.clone())
+                    editor.replace(name.syntax(), replacement.syntax());
                 }
             }
+
+            editor.add_mappings(make.finish_with_mappings());
+            builder.add_file_edits(ctx.file_id(), editor);
         },
     )
 }
@@ -942,4 +946,52 @@ fn main() {
 "#,
         );
     }
+
+    #[test]
+    fn test_wrap_in_parens() {
+        check_assist(
+            inline_local_variable,
+            r#"
+fn main() {
+    let $0a = 123 < 456;
+    let b = !a;
+}
+"#,
+            r#"
+fn main() {
+    let b = !(123 < 456);
+}
+"#,
+        );
+        check_assist(
+            inline_local_variable,
+            r#"
+trait Foo {
+    fn foo(&self);
+}
+
+impl Foo for bool {
+    fn foo(&self) {}
+}
+
+fn main() {
+    let $0a = 123 < 456;
+    let b = a.foo();
+}
+"#,
+            r#"
+trait Foo {
+    fn foo(&self);
+}
+
+impl Foo for bool {
+    fn foo(&self) {}
+}
+
+fn main() {
+    let b = (123 < 456).foo();
+}
+"#,
+        );
+    }
 }
diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/prec.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/prec.rs
index 0c4da762992..4f0e2cad174 100644
--- a/src/tools/rust-analyzer/crates/syntax/src/ast/prec.rs
+++ b/src/tools/rust-analyzer/crates/syntax/src/ast/prec.rs
@@ -1,5 +1,7 @@
 //! Precedence representation.
 
+use stdx::always;
+
 use crate::{
     ast::{self, BinaryOp, Expr, HasArgList, RangeItem},
     match_ast, AstNode, SyntaxNode,
@@ -140,6 +142,22 @@ pub fn precedence(expr: &ast::Expr) -> ExprPrecedence {
     }
 }
 
+fn check_ancestry(ancestor: &SyntaxNode, descendent: &SyntaxNode) -> bool {
+    let bail = || always!(false, "{} is not an ancestor of {}", ancestor, descendent);
+
+    if !ancestor.text_range().contains_range(descendent.text_range()) {
+        return bail();
+    }
+
+    for anc in descendent.ancestors() {
+        if anc == *ancestor {
+            return true;
+        }
+    }
+
+    bail()
+}
+
 impl Expr {
     pub fn precedence(&self) -> ExprPrecedence {
         precedence(self)
@@ -153,9 +171,19 @@ impl Expr {
 
     /// Returns `true` if `self` would need to be wrapped in parentheses given that its parent is `parent`.
     pub fn needs_parens_in(&self, parent: &SyntaxNode) -> bool {
+        self.needs_parens_in_place_of(parent, self.syntax())
+    }
+
+    /// Returns `true` if `self` would need to be wrapped in parentheses if it replaces `place_of`
+    /// given that `place_of`'s parent is `parent`.
+    pub fn needs_parens_in_place_of(&self, parent: &SyntaxNode, place_of: &SyntaxNode) -> bool {
+        if !check_ancestry(parent, place_of) {
+            return false;
+        }
+
         match_ast! {
             match parent {
-                ast::Expr(e) => self.needs_parens_in_expr(&e),
+                ast::Expr(e) => self.needs_parens_in_expr(&e, place_of),
                 ast::Stmt(e) => self.needs_parens_in_stmt(Some(&e)),
                 ast::StmtList(_) => self.needs_parens_in_stmt(None),
                 ast::ArgList(_) => false,
@@ -165,7 +193,7 @@ impl Expr {
         }
     }
 
-    fn needs_parens_in_expr(&self, parent: &Expr) -> bool {
+    fn needs_parens_in_expr(&self, parent: &Expr, place_of: &SyntaxNode) -> bool {
         // Parentheses are necessary when calling a function-like pointer that is a member of a struct or union
         // (e.g. `(a.f)()`).
         let is_parent_call_expr = matches!(parent, ast::Expr::CallExpr(_));
@@ -199,13 +227,17 @@ impl Expr {
 
         if self.is_paren_like()
             || parent.is_paren_like()
-            || self.is_prefix() && (parent.is_prefix() || !self.is_ordered_before(parent))
-            || self.is_postfix() && (parent.is_postfix() || self.is_ordered_before(parent))
+            || self.is_prefix()
+                && (parent.is_prefix()
+                    || !self.is_ordered_before_parent_in_place_of(parent, place_of))
+            || self.is_postfix()
+                && (parent.is_postfix()
+                    || self.is_ordered_before_parent_in_place_of(parent, place_of))
         {
             return false;
         }
 
-        let (left, right, inv) = match self.is_ordered_before(parent) {
+        let (left, right, inv) = match self.is_ordered_before_parent_in_place_of(parent, place_of) {
             true => (self, parent, false),
             false => (parent, self, true),
         };
@@ -413,13 +445,28 @@ impl Expr {
         }
     }
 
-    fn is_ordered_before(&self, other: &Expr) -> bool {
+    fn is_ordered_before_parent_in_place_of(&self, parent: &Expr, place_of: &SyntaxNode) -> bool {
+        use rowan::TextSize;
         use Expr::*;
 
-        return order(self) < order(other);
+        let self_range = self.syntax().text_range();
+        let place_of_range = place_of.text_range();
+
+        let self_order_adjusted = order(self) - self_range.start() + place_of_range.start();
+
+        let parent_order = order(parent);
+        let parent_order_adjusted = if parent_order <= place_of_range.start() {
+            parent_order
+        } else if parent_order >= place_of_range.end() {
+            parent_order - place_of_range.len() + self_range.len()
+        } else {
+            return false;
+        };
+
+        return self_order_adjusted < parent_order_adjusted;
 
         /// Returns text range that can be used to compare two expression for order (which goes first).
-        fn order(this: &Expr) -> rowan::TextSize {
+        fn order(this: &Expr) -> TextSize {
             // For non-paren-like operators: get the operator itself
             let token = match this {
                 RangeExpr(e) => e.op_token(),
diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs
index 85393ca5b4c..44f13041c24 100644
--- a/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs
+++ b/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs
@@ -783,6 +783,27 @@ impl SyntaxFactory {
         ast
     }
 
+    pub fn record_expr_field(
+        &self,
+        name: ast::NameRef,
+        expr: Option<ast::Expr>,
+    ) -> ast::RecordExprField {
+        let ast = make::record_expr_field(name.clone(), expr.clone()).clone_for_update();
+
+        if let Some(mut mapping) = self.mappings() {
+            let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
+
+            builder.map_node(name.syntax().clone(), ast.name_ref().unwrap().syntax().clone());
+            if let Some(expr) = expr {
+                builder.map_node(expr.syntax().clone(), ast.expr().unwrap().syntax().clone());
+            }
+
+            builder.finish(&mut mapping);
+        }
+
+        ast
+    }
+
     pub fn record_field_list(
         &self,
         fields: impl IntoIterator<Item = ast::RecordField>,
diff --git a/src/tools/rust-analyzer/docs/book/src/assists_generated.md b/src/tools/rust-analyzer/docs/book/src/assists_generated.md
index 918ae4a5794..9a801851792 100644
--- a/src/tools/rust-analyzer/docs/book/src/assists_generated.md
+++ b/src/tools/rust-analyzer/docs/book/src/assists_generated.md
@@ -306,7 +306,7 @@ fn main() {
 
 
 ### `apply_demorgan_iterator`
-**Source:**  [apply_demorgan.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/apply_demorgan.rs#L154) 
+**Source:**  [apply_demorgan.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/apply_demorgan.rs#L156) 
 
 Apply [De Morgan's law](https://en.wikipedia.org/wiki/De_Morgan%27s_laws) to
 `Iterator::all` and `Iterator::any`.
@@ -1070,7 +1070,7 @@ pub use foo::{Bar, Baz};
 
 
 ### `expand_record_rest_pattern`
-**Source:**  [expand_rest_pattern.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/expand_rest_pattern.rs#L24) 
+**Source:**  [expand_rest_pattern.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/expand_rest_pattern.rs#L26) 
 
 Fills fields by replacing rest pattern in record patterns.
 
@@ -1094,7 +1094,7 @@ fn foo(bar: Bar) {
 
 
 ### `expand_tuple_struct_rest_pattern`
-**Source:**  [expand_rest_pattern.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/expand_rest_pattern.rs#L80) 
+**Source:**  [expand_rest_pattern.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/expand_rest_pattern.rs#L82) 
 
 Fills fields by replacing rest pattern in tuple struct patterns.