about summary refs log tree commit diff
diff options
context:
space:
mode:
authorunexge <unexge@gmail.com>2022-10-29 23:45:13 +0100
committerunexge <unexge@gmail.com>2022-10-29 23:45:13 +0100
commit48efc9d30354eeb98b2992c002f753565e2fe07d (patch)
treec83c9ec766c36b1b3865d610228d2733fc3f45fb
parent319611b7382fc4c84170519dade68f4f558a44b1 (diff)
downloadrust-48efc9d30354eeb98b2992c002f753565e2fe07d.tar.gz
rust-48efc9d30354eeb98b2992c002f753565e2fe07d.zip
Add `Convert match to let-else` assist
-rw-r--r--crates/ide-assists/src/handlers/convert_match_to_let_else.rs400
-rw-r--r--crates/ide-assists/src/lib.rs2
2 files changed, 402 insertions, 0 deletions
diff --git a/crates/ide-assists/src/handlers/convert_match_to_let_else.rs b/crates/ide-assists/src/handlers/convert_match_to_let_else.rs
new file mode 100644
index 00000000000..928016daab9
--- /dev/null
+++ b/crates/ide-assists/src/handlers/convert_match_to_let_else.rs
@@ -0,0 +1,400 @@
+use ide_db::defs::{Definition, NameRefClass};
+use syntax::{
+    ast::{self, HasName},
+    ted, AstNode, SyntaxNode,
+};
+
+use crate::{
+    assist_context::{AssistContext, Assists},
+    AssistId, AssistKind,
+};
+
+// Assist: convert_match_to_let_else
+//
+// Converts let statement with match initializer to let-else statement.
+//
+// ```
+// fn foo(opt: Option<()>) {
+//     let val = $0match opt {
+//         Some(it) => it,
+//         None => return,
+//     };
+// }
+// ```
+// ->
+// ```
+// fn foo(opt: Option<()>) {
+//     let Some(val) = opt else { return };
+// }
+// ```
+pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
+    let let_stmt: ast::LetStmt = ctx.find_node_at_offset()?;
+    let binding = find_binding(let_stmt.pat()?)?;
+
+    let initializer = match let_stmt.initializer() {
+        Some(ast::Expr::MatchExpr(it)) => it,
+        _ => return None,
+    };
+    let initializer_expr = initializer.expr()?;
+
+    let (extracting_arm, diverging_arm) = match find_arms(ctx, &initializer) {
+        Some(it) => it,
+        None => return None,
+    };
+    if extracting_arm.guard().is_some() {
+        cov_mark::hit!(extracting_arm_has_guard);
+        return None;
+    }
+
+    let diverging_arm_expr = diverging_arm.expr()?;
+    let extracting_arm_pat = extracting_arm.pat()?;
+    let extracted_variable = find_extracted_variable(ctx, &extracting_arm)?;
+
+    acc.add(
+        AssistId("convert_match_to_let_else", AssistKind::RefactorRewrite),
+        "Convert match to let-else",
+        let_stmt.syntax().text_range(),
+        |builder| {
+            let extracting_arm_pat = rename_variable(&extracting_arm_pat, extracted_variable, binding);
+            builder.replace(
+                let_stmt.syntax().text_range(),
+                format!("let {extracting_arm_pat} = {initializer_expr} else {{ {diverging_arm_expr} }};")
+            )
+        },
+    )
+}
+
+// Given a pattern, find the name introduced to the surrounding scope.
+fn find_binding(pat: ast::Pat) -> Option<ast::IdentPat> {
+    if let ast::Pat::IdentPat(ident) = pat {
+        Some(ident)
+    } else {
+        None
+    }
+}
+
+// Given a match expression, find extracting and diverging arms.
+fn find_arms(
+    ctx: &AssistContext<'_>,
+    match_expr: &ast::MatchExpr,
+) -> Option<(ast::MatchArm, ast::MatchArm)> {
+    let arms = match_expr.match_arm_list()?.arms().collect::<Vec<_>>();
+    if arms.len() != 2 {
+        return None;
+    }
+
+    let mut extracting = None;
+    let mut diverging = None;
+    for arm in arms {
+        if ctx.sema.is_diverging_match_arm(&arm)? {
+            diverging = Some(arm);
+        } else {
+            extracting = Some(arm);
+        }
+    }
+
+    match (extracting, diverging) {
+        (Some(extracting), Some(diverging)) => Some((extracting, diverging)),
+        _ => {
+            cov_mark::hit!(non_diverging_match);
+            None
+        }
+    }
+}
+
+// Given an extracting arm, find the extracted variable.
+fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option<ast::Name> {
+    match arm.expr()? {
+        ast::Expr::PathExpr(path) => {
+            let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?;
+            match NameRefClass::classify(&ctx.sema, &name_ref)? {
+                NameRefClass::Definition(Definition::Local(local)) => {
+                    let source = local.source(ctx.db()).value.left()?;
+                    Some(source.name()?)
+                }
+                _ => None,
+            }
+        }
+        _ => {
+            cov_mark::hit!(extracting_arm_is_not_an_identity_expr);
+            return None;
+        }
+    }
+}
+
+// Rename `extracted` with `binding` in `pat`.
+fn rename_variable(pat: &ast::Pat, extracted: ast::Name, binding: ast::IdentPat) -> SyntaxNode {
+    let syntax = pat.syntax().clone_for_update();
+    let extracted_syntax = syntax.covering_element(extracted.syntax().text_range());
+
+    // If `extracted` variable is a record field, we should rename it to `binding`,
+    // otherwise we just need to replace `extracted` with `binding`.
+
+    if let Some(record_pat_field) = extracted_syntax.ancestors().find_map(ast::RecordPatField::cast)
+    {
+        if let Some(name_ref) = record_pat_field.field_name() {
+            ted::replace(
+                record_pat_field.syntax(),
+                ast::make::record_pat_field(ast::make::name_ref(&name_ref.text()), binding.into())
+                    .syntax()
+                    .clone_for_update(),
+            );
+        }
+    } else {
+        ted::replace(extracted_syntax, binding.syntax().clone_for_update());
+    }
+
+    syntax
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::tests::{check_assist, check_assist_not_applicable};
+
+    use super::*;
+
+    #[test]
+    fn should_not_be_applicable_for_non_diverging_match() {
+        cov_mark::check!(non_diverging_match);
+        check_assist_not_applicable(
+            convert_match_to_let_else,
+            r#"
+fn foo(opt: Option<()>) {
+    let val = $0match opt {
+        Some(it) => it,
+        None => (),
+    };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr() {
+        cov_mark::check_count!(extracting_arm_is_not_an_identity_expr, 2);
+        check_assist_not_applicable(
+            convert_match_to_let_else,
+            r#"
+fn foo(opt: Option<()>) {
+    let val = $0match opt {
+        Some(it) => it + 1,
+        None => return,
+    };
+}
+"#,
+        );
+
+        check_assist_not_applicable(
+            convert_match_to_let_else,
+            r#"
+fn foo(opt: Option<()>) {
+    let val = $0match opt {
+        Some(it) => {
+            let _ = 1 + 1;
+            it
+        },
+        None => return,
+    };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn should_not_be_applicable_if_extracting_arm_has_guard() {
+        cov_mark::check!(extracting_arm_has_guard);
+        check_assist_not_applicable(
+            convert_match_to_let_else,
+            r#"
+fn foo(opt: Option<()>) {
+    let val = $0match opt {
+        Some(it) if 2 > 1 => it,
+        None => return,
+    };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn basic_pattern() {
+        check_assist(
+            convert_match_to_let_else,
+            r#"
+fn foo(opt: Option<()>) {
+    let val = $0match opt {
+        Some(it) => it,
+        None => return,
+    };
+}
+    "#,
+            r#"
+fn foo(opt: Option<()>) {
+    let Some(val) = opt else { return };
+}
+    "#,
+        );
+    }
+
+    #[test]
+    fn keeps_modifiers() {
+        check_assist(
+            convert_match_to_let_else,
+            r#"
+fn foo(opt: Option<()>) {
+    let ref mut val = $0match opt {
+        Some(it) => it,
+        None => return,
+    };
+}
+    "#,
+            r#"
+fn foo(opt: Option<()>) {
+    let Some(ref mut val) = opt else { return };
+}
+    "#,
+        );
+    }
+
+    #[test]
+    fn nested_pattern() {
+        check_assist(
+            convert_match_to_let_else,
+            r#"
+fn foo(opt: Option<Result<()>>) {
+    let val = $0match opt {
+        Some(Ok(it)) => it,
+        _ => return,
+    };
+}
+    "#,
+            r#"
+fn foo(opt: Option<Result<()>>) {
+    let Some(Ok(val)) = opt else { return };
+}
+    "#,
+        );
+    }
+
+    #[test]
+    fn works_with_any_diverging_block() {
+        check_assist(
+            convert_match_to_let_else,
+            r#"
+fn foo(opt: Option<()>) {
+    loop {
+        let val = $0match opt {
+            Some(it) => it,
+            None => break,
+        };
+    }
+}
+    "#,
+            r#"
+fn foo(opt: Option<()>) {
+    loop {
+        let Some(val) = opt else { break };
+    }
+}
+    "#,
+        );
+
+        check_assist(
+            convert_match_to_let_else,
+            r#"
+fn foo(opt: Option<()>) {
+    loop {
+        let val = $0match opt {
+            Some(it) => it,
+            None => continue,
+        };
+    }
+}
+    "#,
+            r#"
+fn foo(opt: Option<()>) {
+    loop {
+        let Some(val) = opt else { continue };
+    }
+}
+    "#,
+        );
+
+        check_assist(
+            convert_match_to_let_else,
+            r#"
+fn panic() -> ! {}
+
+fn foo(opt: Option<()>) {
+    loop {
+        let val = $0match opt {
+            Some(it) => it,
+            None => panic(),
+        };
+    }
+}
+    "#,
+            r#"
+fn panic() -> ! {}
+
+fn foo(opt: Option<()>) {
+    loop {
+        let Some(val) = opt else { panic() };
+    }
+}
+    "#,
+        );
+    }
+
+    #[test]
+    fn struct_pattern() {
+        check_assist(
+            convert_match_to_let_else,
+            r#"
+struct Point {
+    x: i32,
+    y: i32,
+}
+
+fn foo(opt: Option<Point>) {
+    let val = $0match opt {
+        Some(Point { x: 0, y }) => y,
+        _ => return,
+    };
+}
+    "#,
+            r#"
+struct Point {
+    x: i32,
+    y: i32,
+}
+
+fn foo(opt: Option<Point>) {
+    let Some(Point { x: 0, y: val }) = opt else { return };
+}
+    "#,
+        );
+    }
+
+    #[test]
+    fn renames_whole_binding() {
+        check_assist(
+            convert_match_to_let_else,
+            r#"
+fn foo(opt: Option<i32>) -> Option<i32> {
+    let val = $0match opt {
+        it @ Some(42) => it,
+        _ => return None,
+    };
+    val
+}
+    "#,
+            r#"
+fn foo(opt: Option<i32>) -> Option<i32> {
+    let val @ Some(42) = opt else { return None };
+    val
+}
+    "#,
+        );
+    }
+}
diff --git a/crates/ide-assists/src/lib.rs b/crates/ide-assists/src/lib.rs
index a07318cefad..387cc631428 100644
--- a/crates/ide-assists/src/lib.rs
+++ b/crates/ide-assists/src/lib.rs
@@ -120,6 +120,7 @@ mod handlers {
     mod convert_into_to_from;
     mod convert_iter_for_each_to_for;
     mod convert_let_else_to_match;
+    mod convert_match_to_let_else;
     mod convert_tuple_struct_to_named_struct;
     mod convert_named_struct_to_tuple_struct;
     mod convert_to_guarded_return;
@@ -220,6 +221,7 @@ mod handlers {
             convert_iter_for_each_to_for::convert_for_loop_with_for_each,
             convert_let_else_to_match::convert_let_else_to_match,
             convert_named_struct_to_tuple_struct::convert_named_struct_to_tuple_struct,
+            convert_match_to_let_else::convert_match_to_let_else,
             convert_to_guarded_return::convert_to_guarded_return,
             convert_tuple_struct_to_named_struct::convert_tuple_struct_to_named_struct,
             convert_two_arm_bool_match_to_matches_macro::convert_two_arm_bool_match_to_matches_macro,