about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/hir/src/lib.rs19
-rw-r--r--crates/ide-assists/src/handlers/add_missing_match_arms.rs140
-rw-r--r--crates/syntax/src/ast/make.rs9
3 files changed, 157 insertions, 11 deletions
diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs
index 54342f1b7c4..012812cea24 100644
--- a/crates/hir/src/lib.rs
+++ b/crates/hir/src/lib.rs
@@ -50,6 +50,7 @@ use hir_def::{
     per_ns::PerNs,
     resolver::{HasResolver, Resolver},
     src::HasSource as _,
+    type_ref::ConstScalar,
     AdtId, AssocItemId, AssocItemLoc, AttrDefId, ConstId, ConstParamId, DefWithBodyId, EnumId,
     EnumVariantId, FunctionId, GenericDefId, HasModule, ImplId, ItemContainerId, LifetimeParamId,
     LocalEnumVariantId, LocalFieldId, Lookup, MacroExpander, MacroId, ModuleId, StaticId, StructId,
@@ -65,8 +66,9 @@ use hir_ty::{
     primitive::UintTy,
     traits::FnTrait,
     AliasTy, CallableDefId, CallableSig, Canonical, CanonicalVarKinds, Cast, ClosureId,
-    GenericArgData, Interner, ParamKind, QuantifiedWhereClause, Scalar, Substitution,
-    TraitEnvironment, TraitRefExt, Ty, TyBuilder, TyDefId, TyExt, TyKind, WhereClause,
+    ConcreteConst, ConstValue, GenericArgData, Interner, ParamKind, QuantifiedWhereClause, Scalar,
+    Substitution, TraitEnvironment, TraitRefExt, Ty, TyBuilder, TyDefId, TyExt, TyKind,
+    WhereClause,
 };
 use itertools::Itertools;
 use nameres::diagnostics::DefDiagnosticKind;
@@ -3232,6 +3234,19 @@ impl Type {
         }
     }
 
+    pub fn as_array(&self, _db: &dyn HirDatabase) -> Option<(Type, usize)> {
+        if let TyKind::Array(ty, len) = &self.ty.kind(Interner) {
+            match len.data(Interner).value {
+                ConstValue::Concrete(ConcreteConst { interned: ConstScalar::UInt(len) }) => {
+                    Some((self.derived(ty.clone()), len as usize))
+                }
+                _ => None,
+            }
+        } else {
+            None
+        }
+    }
+
     pub fn autoderef<'a>(&'a self, db: &'a dyn HirDatabase) -> impl Iterator<Item = Type> + 'a {
         self.autoderef_(db).map(move |ty| self.derived(ty))
     }
diff --git a/crates/ide-assists/src/handlers/add_missing_match_arms.rs b/crates/ide-assists/src/handlers/add_missing_match_arms.rs
index 0461cc790eb..5d81e8cfeac 100644
--- a/crates/ide-assists/src/handlers/add_missing_match_arms.rs
+++ b/crates/ide-assists/src/handlers/add_missing_match_arms.rs
@@ -140,6 +140,31 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>)
             })
             .filter(|(variant_pat, _)| is_variant_missing(&top_lvl_pats, variant_pat));
         ((Box::new(missing_pats) as Box<dyn Iterator<Item = _>>).peekable(), is_non_exhaustive)
+    } else if let Some((enum_def, len)) = resolve_array_of_enum_def(&ctx.sema, &expr) {
+        let is_non_exhaustive = enum_def.is_non_exhaustive(ctx.db(), module.krate());
+        let variants = enum_def.variants(ctx.db());
+
+        if len.pow(variants.len() as u32) > 256 {
+            return None;
+        }
+
+        let variants_of_enums = vec![variants.clone(); len];
+
+        let missing_pats = variants_of_enums
+            .into_iter()
+            .multi_cartesian_product()
+            .inspect(|_| cov_mark::hit!(add_missing_match_arms_lazy_computation))
+            .map(|variants| {
+                let is_hidden = variants
+                    .iter()
+                    .any(|variant| variant.should_be_hidden(ctx.db(), module.krate()));
+                let patterns = variants.into_iter().filter_map(|variant| {
+                    build_pat(ctx.db(), module, variant.clone(), ctx.config.prefer_no_std)
+                });
+                (ast::Pat::from(make::slice_pat(patterns)), is_hidden)
+            })
+            .filter(|(variant_pat, _)| is_variant_missing(&top_lvl_pats, variant_pat));
+        ((Box::new(missing_pats) as Box<dyn Iterator<Item = _>>).peekable(), is_non_exhaustive)
     } else {
         return None;
     };
@@ -266,6 +291,9 @@ fn is_variant_missing(existing_pats: &[Pat], var: &Pat) -> bool {
 fn does_pat_match_variant(pat: &Pat, var: &Pat) -> bool {
     match (pat, var) {
         (Pat::WildcardPat(_), _) => true,
+        (Pat::SlicePat(spat), Pat::SlicePat(svar)) => {
+            spat.pats().zip(svar.pats()).all(|(p, v)| does_pat_match_variant(&p, &v))
+        }
         (Pat::TuplePat(tpat), Pat::TuplePat(tvar)) => {
             tpat.fields().zip(tvar.fields()).all(|(p, v)| does_pat_match_variant(&p, &v))
         }
@@ -280,7 +308,7 @@ enum ExtendedEnum {
     Enum(hir::Enum),
 }
 
-#[derive(Eq, PartialEq, Clone, Copy)]
+#[derive(Eq, PartialEq, Clone, Copy, Debug)]
 enum ExtendedVariant {
     True,
     False,
@@ -340,15 +368,30 @@ fn resolve_tuple_of_enum_def(
         .tuple_fields(sema.db)
         .iter()
         .map(|ty| {
-            ty.autoderef(sema.db).find_map(|ty| match ty.as_adt() {
-                Some(Adt::Enum(e)) => Some(lift_enum(e)),
-                // For now we only handle expansion for a tuple of enums. Here
-                // we map non-enum items to None and rely on `collect` to
-                // convert Vec<Option<hir::Enum>> into Option<Vec<hir::Enum>>.
-                _ => ty.is_bool().then_some(ExtendedEnum::Bool),
+            ty.autoderef(sema.db).find_map(|ty| {
+                match ty.as_adt() {
+                    Some(Adt::Enum(e)) => Some(lift_enum(e)),
+                    // For now we only handle expansion for a tuple of enums. Here
+                    // we map non-enum items to None and rely on `collect` to
+                    // convert Vec<Option<hir::Enum>> into Option<Vec<hir::Enum>>.
+                    _ => ty.is_bool().then_some(ExtendedEnum::Bool),
+                }
             })
         })
-        .collect()
+        .collect::<Option<Vec<ExtendedEnum>>>()
+        .and_then(|list| if list.is_empty() { None } else { Some(list) })
+}
+
+fn resolve_array_of_enum_def(
+    sema: &Semantics<'_, RootDatabase>,
+    expr: &ast::Expr,
+) -> Option<(ExtendedEnum, usize)> {
+    sema.type_of_expr(expr)?.adjusted().as_array(sema.db).and_then(|(ty, len)| {
+        ty.autoderef(sema.db).find_map(|ty| match ty.as_adt() {
+            Some(Adt::Enum(e)) => Some((lift_enum(e), len)),
+            _ => ty.is_bool().then_some((ExtendedEnum::Bool, len)),
+        })
+    })
 }
 
 fn build_pat(
@@ -377,7 +420,6 @@ fn build_pat(
                 }
                 ast::StructKind::Unit => make::path_pat(path),
             };
-
             Some(pat)
         }
         ExtendedVariant::True => Some(ast::Pat::from(make::literal_pat("true"))),
@@ -574,6 +616,86 @@ fn foo(a: bool) {
     }
 
     #[test]
+    fn fill_boolean_array() {
+        check_assist(
+            add_missing_match_arms,
+            r#"
+fn foo(a: bool) {
+    match [a]$0 {
+    }
+}
+"#,
+            r#"
+fn foo(a: bool) {
+    match [a] {
+        $0[true] => todo!(),
+        [false] => todo!(),
+    }
+}
+"#,
+        );
+
+        check_assist(
+            add_missing_match_arms,
+            r#"
+fn foo(a: bool) {
+    match [a,]$0 {
+    }
+}
+"#,
+            r#"
+fn foo(a: bool) {
+    match [a,] {
+        $0[true] => todo!(),
+        [false] => todo!(),
+    }
+}
+"#,
+        );
+
+        check_assist(
+            add_missing_match_arms,
+            r#"
+fn foo(a: bool) {
+    match [a, a]$0 {
+        [true, true] => todo!(),
+    }
+}
+"#,
+            r#"
+fn foo(a: bool) {
+    match [a, a] {
+        [true, true] => todo!(),
+        $0[true, false] => todo!(),
+        [false, true] => todo!(),
+        [false, false] => todo!(),
+    }
+}
+"#,
+        );
+
+        check_assist(
+            add_missing_match_arms,
+            r#"
+fn foo(a: bool) {
+    match [a, a]$0 {
+    }
+}
+"#,
+            r#"
+fn foo(a: bool) {
+    match [a, a] {
+        $0[true, true] => todo!(),
+        [true, false] => todo!(),
+        [false, true] => todo!(),
+        [false, false] => todo!(),
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
     fn partial_fill_boolean_tuple() {
         check_assist(
             add_missing_match_arms,
diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs
index d5b3296980c..a35983435c7 100644
--- a/crates/syntax/src/ast/make.rs
+++ b/crates/syntax/src/ast/make.rs
@@ -520,6 +520,15 @@ pub fn literal_pat(lit: &str) -> ast::LiteralPat {
     }
 }
 
+pub fn slice_pat(pats: impl IntoIterator<Item = ast::Pat>) -> ast::SlicePat {
+    let pats_str = pats.into_iter().join(", ");
+    return from_text(&format!("[{pats_str}]"));
+
+    fn from_text(text: &str) -> ast::SlicePat {
+        ast_from_text(&format!("fn f() {{ match () {{{text} => ()}} }}"))
+    }
+}
+
 /// Creates a tuple of patterns from an iterator of patterns.
 ///
 /// Invariant: `pats` must be length > 0