about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/ide-assists/src/handlers/convert_match_to_let_else.rs85
1 files changed, 63 insertions, 22 deletions
diff --git a/crates/ide-assists/src/handlers/convert_match_to_let_else.rs b/crates/ide-assists/src/handlers/convert_match_to_let_else.rs
index 745a870ab6b..7f2c01772ba 100644
--- a/crates/ide-assists/src/handlers/convert_match_to_let_else.rs
+++ b/crates/ide-assists/src/handlers/convert_match_to_let_else.rs
@@ -1,6 +1,6 @@
 use ide_db::defs::{Definition, NameRefClass};
 use syntax::{
-    ast::{self, HasName},
+    ast::{self, HasName, Name},
     ted, AstNode, SyntaxNode,
 };
 
@@ -48,7 +48,7 @@ pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'
         other => format!("{{ {other} }}"),
     };
     let extracting_arm_pat = extracting_arm.pat()?;
-    let extracted_variable = find_extracted_variable(ctx, &extracting_arm)?;
+    let extracted_variable_positions = find_extracted_variable(ctx, &extracting_arm)?;
 
     acc.add(
         AssistId("convert_match_to_let_else", AssistKind::RefactorRewrite),
@@ -56,7 +56,7 @@ pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'
         let_stmt.syntax().text_range(),
         |builder| {
             let extracting_arm_pat =
-                rename_variable(&extracting_arm_pat, extracted_variable, binding);
+                rename_variable(&extracting_arm_pat, &extracted_variable_positions, binding);
             builder.replace(
                 let_stmt.syntax().text_range(),
                 format!("let {extracting_arm_pat} = {initializer_expr} else {diverging_arm_expr};"),
@@ -95,14 +95,15 @@ fn find_arms(
 }
 
 // Given an extracting arm, find the extracted variable.
-fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option<ast::Name> {
+fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option<Vec<Name>> {
     match arm.expr()? {
         ast::Expr::PathExpr(path) => {
             let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?;
             match NameRefClass::classify(&ctx.sema, &name_ref)? {
                 NameRefClass::Definition(Definition::Local(local)) => {
-                    let source = local.primary_source(ctx.db()).into_ident_pat()?;
-                    Some(source.name()?)
+                    let source =
+                        local.sources(ctx.db()).into_iter().map(|x| x.into_ident_pat()?.name());
+                    source.collect()
                 }
                 _ => None,
             }
@@ -115,27 +116,34 @@ fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Opti
 }
 
 // Rename `extracted` with `binding` in `pat`.
-fn rename_variable(pat: &ast::Pat, extracted: ast::Name, binding: ast::Pat) -> SyntaxNode {
+fn rename_variable(pat: &ast::Pat, extracted: &[Name], binding: ast::Pat) -> SyntaxNode {
     let syntax = pat.syntax().clone_for_update();
-    let extracted_syntax = syntax.covering_element(extracted.syntax().text_range());
-
-    // If `extracted` variable is a record field, we should rename it to `binding`,
-    // otherwise we just need to replace `extracted` with `binding`.
-
-    if let Some(record_pat_field) = extracted_syntax.ancestors().find_map(ast::RecordPatField::cast)
-    {
-        if let Some(name_ref) = record_pat_field.field_name() {
-            ted::replace(
-                record_pat_field.syntax(),
-                ast::make::record_pat_field(ast::make::name_ref(&name_ref.text()), binding)
+    let extracted = extracted
+        .iter()
+        .map(|e| syntax.covering_element(e.syntax().text_range()))
+        .collect::<Vec<_>>();
+    for extracted_syntax in extracted {
+        // If `extracted` variable is a record field, we should rename it to `binding`,
+        // otherwise we just need to replace `extracted` with `binding`.
+
+        if let Some(record_pat_field) =
+            extracted_syntax.ancestors().find_map(ast::RecordPatField::cast)
+        {
+            if let Some(name_ref) = record_pat_field.field_name() {
+                ted::replace(
+                    record_pat_field.syntax(),
+                    ast::make::record_pat_field(
+                        ast::make::name_ref(&name_ref.text()),
+                        binding.clone(),
+                    )
                     .syntax()
                     .clone_for_update(),
-            );
+                );
+            }
+        } else {
+            ted::replace(extracted_syntax, binding.clone().syntax().clone_for_update());
         }
-    } else {
-        ted::replace(extracted_syntax, binding.syntax().clone_for_update());
     }
-
     syntax
 }
 
@@ -163,6 +171,39 @@ fn foo(opt: Option<()>) {
     }
 
     #[test]
+    fn or_pattern_multiple_binding() {
+        check_assist(
+            convert_match_to_let_else,
+            r#"
+//- minicore: option
+enum Foo {
+    A(u32),
+    B(u32),
+    C(String),
+}
+
+fn foo(opt: Option<Foo>) -> Result<u32, ()> {
+    let va$0lue = match opt {
+        Some(Foo::A(it) | Foo::B(it)) => it,
+        _ => return Err(()),
+    };
+}
+    "#,
+            r#"
+enum Foo {
+    A(u32),
+    B(u32),
+    C(String),
+}
+
+fn foo(opt: Option<Foo>) -> Result<u32, ()> {
+    let Some(Foo::A(value) | Foo::B(value)) = opt else { return Err(()) };
+}
+    "#,
+        );
+    }
+
+    #[test]
     fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr() {
         cov_mark::check_count!(extracting_arm_is_not_an_identity_expr, 2);
         check_assist_not_applicable(