about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2024-04-01 12:15:23 +0000
committerbors <bors@rust-lang.org>2024-04-01 12:15:23 +0000
commitc82d168b79fbe1e0711eb199d6decdd847c92fba (patch)
treef9a2e3d7e5318f4c8a1785e9eb983f536b258685
parenta6ddf5fb804f8f8cdde8682ab61d889387118004 (diff)
parent7e8f2d8fd3ceece8051beed5add23191696646a9 (diff)
downloadrust-c82d168b79fbe1e0711eb199d6decdd847c92fba.tar.gz
rust-c82d168b79fbe1e0711eb199d6decdd847c92fba.zip
Auto merge of #16979 - Nadrieril:contiguous-enum-id, r=Veykril
pattern analysis: Use contiguous indices for enum variants

The main blocker to using the in-tree version of the `pattern_analysis` crate is that rustc requires enum indices to be contiguous because it uses `IndexVec`/`BitSet` for performance. Currently we swap these out for `FxHashMap`/`FxHashSet` when the `rustc` feature is off, but we can't do that if we use the in-tree crate.

This PR solves the problem by using contiguous indices on the r-a side too.
-rw-r--r--crates/hir-ty/src/diagnostics/match_check/pat_analysis.rs62
1 files changed, 47 insertions, 15 deletions
diff --git a/crates/hir-ty/src/diagnostics/match_check/pat_analysis.rs b/crates/hir-ty/src/diagnostics/match_check/pat_analysis.rs
index f45beb4c92b..0b24fe72276 100644
--- a/crates/hir-ty/src/diagnostics/match_check/pat_analysis.rs
+++ b/crates/hir-ty/src/diagnostics/match_check/pat_analysis.rs
@@ -3,7 +3,7 @@
 use std::fmt;
 use tracing::debug;
 
-use hir_def::{DefWithBodyId, EnumVariantId, HasModule, LocalFieldId, ModuleId, VariantId};
+use hir_def::{DefWithBodyId, EnumId, EnumVariantId, HasModule, LocalFieldId, ModuleId, VariantId};
 use rustc_hash::FxHashMap;
 use rustc_pattern_analysis::{
     constructor::{Constructor, ConstructorSet, VariantVisibility},
@@ -36,6 +36,24 @@ pub(crate) type WitnessPat<'p> = rustc_pattern_analysis::pat::WitnessPat<MatchCh
 #[derive(Copy, Clone, Debug, PartialEq, Eq)]
 pub(crate) enum Void {}
 
+/// An index type for enum variants. This ranges from 0 to `variants.len()`, whereas `EnumVariantId`
+/// can take arbitrary large values (and hence mustn't be used with `IndexVec`/`BitSet`).
+#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
+pub(crate) struct EnumVariantContiguousIndex(usize);
+
+impl EnumVariantContiguousIndex {
+    fn from_enum_variant_id(db: &dyn HirDatabase, target_evid: EnumVariantId) -> Self {
+        // Find the index of this variant in the list of variants.
+        use hir_def::Lookup;
+        let i = target_evid.lookup(db.upcast()).index as usize;
+        EnumVariantContiguousIndex(i)
+    }
+
+    fn to_enum_variant_id(self, db: &dyn HirDatabase, eid: EnumId) -> EnumVariantId {
+        db.enum_data(eid).variants[self.0].0
+    }
+}
+
 #[derive(Clone)]
 pub(crate) struct MatchCheckCtx<'p> {
     module: ModuleId,
@@ -89,9 +107,18 @@ impl<'p> MatchCheckCtx<'p> {
         }
     }
 
-    fn variant_id_for_adt(ctor: &Constructor<Self>, adt: hir_def::AdtId) -> Option<VariantId> {
+    fn variant_id_for_adt(
+        db: &'p dyn HirDatabase,
+        ctor: &Constructor<Self>,
+        adt: hir_def::AdtId,
+    ) -> Option<VariantId> {
         match ctor {
-            &Variant(id) => Some(id.into()),
+            Variant(id) => {
+                let hir_def::AdtId::EnumId(eid) = adt else {
+                    panic!("bad constructor {ctor:?} for adt {adt:?}")
+                };
+                Some(id.to_enum_variant_id(db, eid).into())
+            }
             Struct | UnionField => match adt {
                 hir_def::AdtId::EnumId(_) => None,
                 hir_def::AdtId::StructId(id) => Some(id.into()),
@@ -175,19 +202,24 @@ impl<'p> MatchCheckCtx<'p> {
                         ctor = Struct;
                         arity = 1;
                     }
-                    &TyKind::Adt(adt, _) => {
+                    &TyKind::Adt(AdtId(adt), _) => {
                         ctor = match pat.kind.as_ref() {
-                            PatKind::Leaf { .. } if matches!(adt.0, hir_def::AdtId::UnionId(_)) => {
+                            PatKind::Leaf { .. } if matches!(adt, hir_def::AdtId::UnionId(_)) => {
                                 UnionField
                             }
                             PatKind::Leaf { .. } => Struct,
-                            PatKind::Variant { enum_variant, .. } => Variant(*enum_variant),
+                            PatKind::Variant { enum_variant, .. } => {
+                                Variant(EnumVariantContiguousIndex::from_enum_variant_id(
+                                    self.db,
+                                    *enum_variant,
+                                ))
+                            }
                             _ => {
                                 never!();
                                 Wildcard
                             }
                         };
-                        let variant = Self::variant_id_for_adt(&ctor, adt.0).unwrap();
+                        let variant = Self::variant_id_for_adt(self.db, &ctor, adt).unwrap();
                         arity = variant.variant_data(self.db.upcast()).fields().len();
                     }
                     _ => {
@@ -239,7 +271,7 @@ impl<'p> MatchCheckCtx<'p> {
                     PatKind::Deref { subpattern: subpatterns.next().unwrap() }
                 }
                 TyKind::Adt(adt, substs) => {
-                    let variant = Self::variant_id_for_adt(pat.ctor(), adt.0).unwrap();
+                    let variant = Self::variant_id_for_adt(self.db, pat.ctor(), adt.0).unwrap();
                     let subpatterns = self
                         .list_variant_fields(pat.ty(), variant)
                         .zip(subpatterns)
@@ -277,7 +309,7 @@ impl<'p> MatchCheckCtx<'p> {
 impl<'p> PatCx for MatchCheckCtx<'p> {
     type Error = ();
     type Ty = Ty;
-    type VariantIdx = EnumVariantId;
+    type VariantIdx = EnumVariantContiguousIndex;
     type StrLit = Void;
     type ArmData = ();
     type PatData = PatData<'p>;
@@ -303,7 +335,7 @@ impl<'p> PatCx for MatchCheckCtx<'p> {
                         // patterns. If we're here we can assume this is a box pattern.
                         1
                     } else {
-                        let variant = Self::variant_id_for_adt(ctor, adt).unwrap();
+                        let variant = Self::variant_id_for_adt(self.db, ctor, adt).unwrap();
                         variant.variant_data(self.db.upcast()).fields().len()
                     }
                 }
@@ -343,7 +375,7 @@ impl<'p> PatCx for MatchCheckCtx<'p> {
                         let subst_ty = substs.at(Interner, 0).assert_ty_ref(Interner).clone();
                         single(subst_ty)
                     } else {
-                        let variant = Self::variant_id_for_adt(ctor, adt).unwrap();
+                        let variant = Self::variant_id_for_adt(self.db, ctor, adt).unwrap();
                         let (adt, _) = ty.as_adt().unwrap();
 
                         let adt_is_local =
@@ -421,7 +453,7 @@ impl<'p> PatCx for MatchCheckCtx<'p> {
                     ConstructorSet::NoConstructors
                 } else {
                     let mut variants = FxHashMap::default();
-                    for &(variant, _) in enum_data.variants.iter() {
+                    for (i, &(variant, _)) in enum_data.variants.iter().enumerate() {
                         let is_uninhabited =
                             is_enum_variant_uninhabited_from(variant, subst, cx.module, cx.db);
                         let visibility = if is_uninhabited {
@@ -429,7 +461,7 @@ impl<'p> PatCx for MatchCheckCtx<'p> {
                         } else {
                             VariantVisibility::Visible
                         };
-                        variants.insert(variant, visibility);
+                        variants.insert(EnumVariantContiguousIndex(i), visibility);
                     }
 
                     ConstructorSet::Variants {
@@ -453,10 +485,10 @@ impl<'p> PatCx for MatchCheckCtx<'p> {
         f: &mut fmt::Formatter<'_>,
         pat: &rustc_pattern_analysis::pat::DeconstructedPat<Self>,
     ) -> fmt::Result {
+        let db = pat.data().db;
         let variant =
-            pat.ty().as_adt().and_then(|(adt, _)| Self::variant_id_for_adt(pat.ctor(), adt));
+            pat.ty().as_adt().and_then(|(adt, _)| Self::variant_id_for_adt(db, pat.ctor(), adt));
 
-        let db = pat.data().db;
         if let Some(variant) = variant {
             match variant {
                 VariantId::EnumVariantId(v) => {