about summary refs log tree commit diff
diff options
context:
space:
mode:
authorDawer <7803845+iDawer@users.noreply.github.com>2021-04-17 01:09:09 +0500
committerDawer <7803845+iDawer@users.noreply.github.com>2021-04-17 01:09:09 +0500
commitedbb1797fb16523f3d49b048b8d8ee8fe1c48c99 (patch)
treeb3f82be2a6ab671710b07a6ede308fd00c1cb52d
parente53919a425bf062056a23e825fb30a51a639385c (diff)
downloadrust-edbb1797fb16523f3d49b048b8d8ee8fe1c48c99.tar.gz
rust-edbb1797fb16523f3d49b048b8d8ee8fe1c48c99.zip
Fill partial match arms for a tuple of enums
-rw-r--r--crates/ide_assists/src/handlers/fill_match_arms.rs54
1 files changed, 37 insertions, 17 deletions
diff --git a/crates/ide_assists/src/handlers/fill_match_arms.rs b/crates/ide_assists/src/handlers/fill_match_arms.rs
index 80bd1b7e8a4..e4794f17c95 100644
--- a/crates/ide_assists/src/handlers/fill_match_arms.rs
+++ b/crates/ide_assists/src/handlers/fill_match_arms.rs
@@ -1,5 +1,6 @@
 use std::iter;
 
+use either::Either;
 use hir::{Adt, HasSource, ModuleDef, Semantics};
 use ide_db::helpers::{mod_path_to_ast, FamousDefs};
 use ide_db::RootDatabase;
@@ -48,6 +49,16 @@ pub(crate) fn fill_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option<
         }
     }
 
+    let top_lvl_pats: Vec<_> = arms
+        .iter()
+        .filter_map(ast::MatchArm::pat)
+        .flat_map(|pat| match pat {
+            // Special casee OrPat as separate top-level pats
+            Pat::OrPat(or_pat) => Either::Left(or_pat.pats()),
+            _ => Either::Right(iter::once(pat)),
+        })
+        .collect();
+
     let module = ctx.sema.scope(expr.syntax()).module()?;
 
     let missing_arms: Vec<MatchArm> = if let Some(enum_def) = resolve_enum_def(&ctx.sema, &expr) {
@@ -56,7 +67,7 @@ pub(crate) fn fill_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option<
         let mut variants = variants
             .into_iter()
             .filter_map(|variant| build_pat(ctx.db(), module, variant))
-            .filter(|variant_pat| is_variant_missing(&mut arms, variant_pat))
+            .filter(|variant_pat| is_variant_missing(&top_lvl_pats, variant_pat))
             .map(|pat| make::match_arm(iter::once(pat), make::expr_empty_block()))
             .collect::<Vec<_>>();
         if Some(enum_def) == FamousDefs(&ctx.sema, Some(module.krate())).core_option_Option() {
@@ -66,11 +77,6 @@ pub(crate) fn fill_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option<
         }
         variants
     } else if let Some(enum_defs) = resolve_tuple_of_enum_def(&ctx.sema, &expr) {
-        // Partial fill not currently supported for tuple of enums.
-        if !arms.is_empty() {
-            return None;
-        }
-
         // When calculating the match arms for a tuple of enums, we want
         // to create a match arm for each possible combination of enum
         // values. The `multi_cartesian_product` method transforms
@@ -85,7 +91,7 @@ pub(crate) fn fill_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option<
                     variants.into_iter().filter_map(|variant| build_pat(ctx.db(), module, variant));
                 ast::Pat::from(make::tuple_pat(patterns))
             })
-            .filter(|variant_pat| is_variant_missing(&mut arms, variant_pat))
+            .filter(|variant_pat| is_variant_missing(&top_lvl_pats, variant_pat))
             .map(|pat| make::match_arm(iter::once(pat), make::expr_empty_block()))
             .collect()
     } else {
@@ -128,15 +134,14 @@ pub(crate) fn fill_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option<
     )
 }
 
-fn is_variant_missing(existing_arms: &mut Vec<MatchArm>, var: &Pat) -> bool {
-    existing_arms.iter().filter_map(|arm| arm.pat()).all(|pat| {
-        // Special casee OrPat as separate top-level pats
-        let top_level_pats: Vec<Pat> = match pat {
-            Pat::OrPat(pats) => pats.pats().collect::<Vec<_>>(),
-            _ => vec![pat],
-        };
-
-        !top_level_pats.iter().any(|pat| does_pat_match_variant(pat, var))
+fn is_variant_missing(existing_pats: &[Pat], var: &Pat) -> bool {
+    !existing_pats.iter().any(|pat| match (pat, var) {
+        (Pat::TuplePat(tpat), Pat::TuplePat(tvar)) => {
+            // `does_pat_match_variant` gives false positives for tuple patterns
+            // Fixme: this is still somewhat limited
+            tpat.fields().zip(tvar.fields()).all(|(p, v)| does_pat_match_variant(&p, &v))
+        }
+        _ => does_pat_match_variant(pat, var),
     })
 }
 
@@ -467,7 +472,7 @@ fn main() {
 
     #[test]
     fn fill_match_arms_tuple_of_enum_partial() {
-        check_assist_not_applicable(
+        check_assist(
             fill_match_arms,
             r#"
             enum A { One, Two }
@@ -481,6 +486,21 @@ fn main() {
                 }
             }
             "#,
+            r#"
+            enum A { One, Two }
+            enum B { One, Two }
+
+            fn main() {
+                let a = A::One;
+                let b = B::One;
+                match (a, b) {
+                    (A::Two, B::One) => {}
+                    $0(A::One, B::One) => {}
+                    (A::One, B::Two) => {}
+                    (A::Two, B::Two) => {}
+                }
+            }
+            "#,
         );
     }