about summary refs log tree commit diff
diff options
context:
space:
mode:
authorJesse Bakker <github@jessebakker.com>2021-01-03 14:11:42 +0100
committerJesse Bakker <github@jessebakker.com>2021-01-03 15:46:57 +0100
commitba4c42af02d3fb37869e74d34dfd14a605e15c8e (patch)
treea0f381c9130531ef95df78bbcbb1b37d2c14854f
parentb47c63a4bcefe93be3e5fe97b2a57489f13da493 (diff)
downloadrust-ba4c42af02d3fb37869e74d34dfd14a605e15c8e.tar.gz
rust-ba4c42af02d3fb37869e74d34dfd14a605e15c8e.zip
Support assignment to FieldExpr for extract_assignment assist
-rw-r--r--crates/assists/src/handlers/extract_assignment.rs103
1 files changed, 89 insertions, 14 deletions
diff --git a/crates/assists/src/handlers/extract_assignment.rs b/crates/assists/src/handlers/extract_assignment.rs
index 281cf5d24e2..ae99598c00e 100644
--- a/crates/assists/src/handlers/extract_assignment.rs
+++ b/crates/assists/src/handlers/extract_assignment.rs
@@ -1,4 +1,3 @@
-use hir::AsName;
 use syntax::{
     ast::{self, edit::AstNodeEdit, make},
     AstNode,
@@ -38,15 +37,23 @@ use crate::{
 // }
 // ```
 pub(crate) fn extract_assigment(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
-    let name = ctx.find_node_at_offset::<ast::NameRef>()?.as_name();
+    let assign_expr = ctx.find_node_at_offset::<ast::BinExpr>()?;
+    let name_expr = if assign_expr.op_kind()? == ast::BinOp::Assignment {
+        assign_expr.lhs()?
+    } else {
+        return None;
+    };
 
     let (old_stmt, new_stmt) = if let Some(if_expr) = ctx.find_node_at_offset::<ast::IfExpr>() {
         (
             ast::Expr::cast(if_expr.syntax().to_owned())?,
-            exprify_if(&if_expr, &name)?.indent(if_expr.indent_level()),
+            exprify_if(&if_expr, &ctx.sema, &name_expr)?.indent(if_expr.indent_level()),
         )
     } else if let Some(match_expr) = ctx.find_node_at_offset::<ast::MatchExpr>() {
-        (ast::Expr::cast(match_expr.syntax().to_owned())?, exprify_match(&match_expr, &name)?)
+        (
+            ast::Expr::cast(match_expr.syntax().to_owned())?,
+            exprify_match(&match_expr, &ctx.sema, &name_expr)?,
+        )
     } else {
         return None;
     };
@@ -58,18 +65,22 @@ pub(crate) fn extract_assigment(acc: &mut Assists, ctx: &AssistContext) -> Optio
         "Extract assignment",
         old_stmt.syntax().text_range(),
         move |edit| {
-            edit.replace(old_stmt.syntax().text_range(), format!("{} = {};", name, expr_stmt));
+            edit.replace(old_stmt.syntax().text_range(), format!("{} = {};", name_expr, expr_stmt));
         },
     )
 }
 
-fn exprify_match(match_expr: &ast::MatchExpr, name: &hir::Name) -> Option<ast::Expr> {
+fn exprify_match(
+    match_expr: &ast::MatchExpr,
+    sema: &hir::Semantics<ide_db::RootDatabase>,
+    name: &ast::Expr,
+) -> Option<ast::Expr> {
     let new_arm_list = match_expr
         .match_arm_list()?
         .arms()
         .map(|arm| {
             if let ast::Expr::BlockExpr(block) = arm.expr()? {
-                let new_block = exprify_block(&block, name)?.indent(block.indent_level());
+                let new_block = exprify_block(&block, sema, name)?.indent(block.indent_level());
                 Some(arm.replace_descendant(block, new_block))
             } else {
                 None
@@ -82,21 +93,31 @@ fn exprify_match(match_expr: &ast::MatchExpr, name: &hir::Name) -> Option<ast::E
     Some(make::expr_match(match_expr.expr()?, new_arm_list))
 }
 
-fn exprify_if(statement: &ast::IfExpr, name: &hir::Name) -> Option<ast::Expr> {
-    let then_branch = exprify_block(&statement.then_branch()?, name)?;
+fn exprify_if(
+    statement: &ast::IfExpr,
+    sema: &hir::Semantics<ide_db::RootDatabase>,
+    name: &ast::Expr,
+) -> Option<ast::Expr> {
+    let then_branch = exprify_block(&statement.then_branch()?, sema, name)?;
     let else_branch = match statement.else_branch()? {
-        ast::ElseBranch::Block(ref block) => ast::ElseBranch::Block(exprify_block(block, name)?),
+        ast::ElseBranch::Block(ref block) => {
+            ast::ElseBranch::Block(exprify_block(block, sema, name)?)
+        }
         ast::ElseBranch::IfExpr(expr) => {
             mark::hit!(test_extract_assigment_chained_if);
             ast::ElseBranch::IfExpr(ast::IfExpr::cast(
-                exprify_if(&expr, name)?.syntax().to_owned(),
+                exprify_if(&expr, sema, name)?.syntax().to_owned(),
             )?)
         }
     };
     Some(make::expr_if(statement.condition()?, then_branch, Some(else_branch)))
 }
 
-fn exprify_block(block: &ast::BlockExpr, name: &hir::Name) -> Option<ast::BlockExpr> {
+fn exprify_block(
+    block: &ast::BlockExpr,
+    sema: &hir::Semantics<ide_db::RootDatabase>,
+    name: &ast::Expr,
+) -> Option<ast::BlockExpr> {
     if block.expr().is_some() {
         return None;
     }
@@ -106,8 +127,7 @@ fn exprify_block(block: &ast::BlockExpr, name: &hir::Name) -> Option<ast::BlockE
 
     if let ast::Stmt::ExprStmt(stmt) = stmt {
         if let ast::Expr::BinExpr(expr) = stmt.expr()? {
-            if expr.op_kind()? == ast::BinOp::Assignment
-                && &expr.lhs()?.name_ref()?.as_name() == name
+            if expr.op_kind()? == ast::BinOp::Assignment && is_equivalent(sema, &expr.lhs()?, name)
             {
                 // The last statement in the block is an assignment to the name we want
                 return Some(make::block_expr(stmts, Some(expr.rhs()?)));
@@ -117,6 +137,29 @@ fn exprify_block(block: &ast::BlockExpr, name: &hir::Name) -> Option<ast::BlockE
     None
 }
 
+fn is_equivalent(
+    sema: &hir::Semantics<ide_db::RootDatabase>,
+    expr0: &ast::Expr,
+    expr1: &ast::Expr,
+) -> bool {
+    match (expr0, expr1) {
+        (ast::Expr::FieldExpr(field_expr0), ast::Expr::FieldExpr(field_expr1)) => {
+            mark::hit!(test_extract_assignment_field_assignment);
+            sema.resolve_field(field_expr0) == sema.resolve_field(field_expr1)
+        }
+        (ast::Expr::PathExpr(path0), ast::Expr::PathExpr(path1)) => {
+            let path0 = path0.path();
+            let path1 = path1.path();
+            if let (Some(path0), Some(path1)) = (path0, path1) {
+                sema.resolve_path(&path0) == sema.resolve_path(&path1)
+            } else {
+                false
+            }
+        }
+        _ => false,
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -322,4 +365,36 @@ fn foo() {
 }"#,
         )
     }
+
+    #[test]
+    fn test_extract_assignment_field_assignment() {
+        mark::check!(test_extract_assignment_field_assignment);
+        check_assist(
+            extract_assigment,
+            r#"
+struct A(usize);
+
+fn foo() {
+    let mut a = A(1);
+
+    if true {
+        <|>a.0 = 2;
+    } else {
+        a.0 = 3;
+    }
+}"#,
+            r#"
+struct A(usize);
+
+fn foo() {
+    let mut a = A(1);
+
+    a.0 = if true {
+        2
+    } else {
+        3
+    };
+}"#,
+        )
+    }
 }