about summary refs log tree commit diff
path: root/compiler/rustc_pattern_analysis
diff options
context:
space:
mode:
authorNadrieril <nadrieril+git@gmail.com>2024-02-06 03:12:21 +0100
committerNadrieril <nadrieril+git@gmail.com>2024-02-28 17:47:19 +0100
commit39441e4cdd46c61be6b86e4bfe352d1e7d5af6fb (patch)
treeef846c5a8581bf50d0a3e1089580ec356c9ad03d /compiler/rustc_pattern_analysis
parentea381663900b90c0f78c6a64cd5e0b1876047714 (diff)
downloadrust-39441e4cdd46c61be6b86e4bfe352d1e7d5af6fb.tar.gz
rust-39441e4cdd46c61be6b86e4bfe352d1e7d5af6fb.zip
Simplify
Diffstat (limited to 'compiler/rustc_pattern_analysis')
-rw-r--r--compiler/rustc_pattern_analysis/src/rustc.rs80
1 files changed, 34 insertions, 46 deletions
diff --git a/compiler/rustc_pattern_analysis/src/rustc.rs b/compiler/rustc_pattern_analysis/src/rustc.rs
index 4e90d1a7406..4f5f0383890 100644
--- a/compiler/rustc_pattern_analysis/src/rustc.rs
+++ b/compiler/rustc_pattern_analysis/src/rustc.rs
@@ -10,7 +10,7 @@ use rustc_middle::mir::interpret::Scalar;
 use rustc_middle::mir::{self, Const};
 use rustc_middle::thir::{FieldPat, Pat, PatKind, PatRange, PatRangeBoundary};
 use rustc_middle::ty::layout::IntegerExt;
-use rustc_middle::ty::{self, OpaqueTypeKey, Ty, TyCtxt, TypeVisitableExt, VariantDef};
+use rustc_middle::ty::{self, FieldDef, OpaqueTypeKey, Ty, TyCtxt, TypeVisitableExt, VariantDef};
 use rustc_session::lint;
 use rustc_span::{ErrorGuaranteed, Span, DUMMY_SP};
 use rustc_target::abi::{FieldIdx, Integer, VariantIdx, FIRST_VARIANT};
@@ -158,32 +158,19 @@ impl<'p, 'tcx: 'p> RustcMatchCheckCtxt<'p, 'tcx> {
         }
     }
 
-    // In the cases of either a `#[non_exhaustive]` field list or a non-public field, we hide
-    // uninhabited fields in order not to reveal the uninhabitedness of the whole variant.
-    // This lists the fields we keep along with their types.
-    pub(crate) fn list_variant_nonhidden_fields(
+    pub(crate) fn variant_sub_tys(
         &self,
         ty: RevealedTy<'tcx>,
         variant: &'tcx VariantDef,
-    ) -> impl Iterator<Item = (FieldIdx, RevealedTy<'tcx>, bool)> + Captures<'p> + Captures<'_>
+    ) -> impl Iterator<Item = (&'tcx FieldDef, RevealedTy<'tcx>)> + Captures<'p> + Captures<'_>
     {
-        let cx = self;
-        let ty::Adt(adt, args) = ty.kind() else { bug!() };
-        // Whether we must avoid matching the fields of this variant exhaustively.
-        let is_non_exhaustive = variant.is_field_list_non_exhaustive() && !adt.did().is_local();
-
-        variant.fields.iter().enumerate().map(move |(i, field)| {
-            let ty = field.ty(cx.tcx, args);
+        let ty::Adt(_, args) = ty.kind() else { bug!() };
+        variant.fields.iter().map(move |field| {
+            let ty = field.ty(self.tcx, args);
             // `field.ty()` doesn't normalize after instantiating.
-            let ty = cx.tcx.normalize_erasing_regions(cx.param_env, ty);
-            let is_visible = adt.is_enum() || field.vis.is_accessible_from(cx.module, cx.tcx);
-            let is_uninhabited = (cx.tcx.features().exhaustive_patterns
-                || cx.tcx.features().min_exhaustive_patterns)
-                && cx.is_uninhabited(ty);
-
-            let skip = is_uninhabited && (!is_visible || is_non_exhaustive);
-            let ty = cx.reveal_opaque_ty(ty);
-            (FieldIdx::new(i), ty, skip)
+            let ty = self.tcx.normalize_erasing_regions(self.param_env, ty);
+            let ty = self.reveal_opaque_ty(ty);
+            (field, ty)
         })
     }
 
@@ -230,9 +217,21 @@ impl<'p, 'tcx: 'p> RustcMatchCheckCtxt<'p, 'tcx> {
                     } else {
                         let variant =
                             &adt.variant(RustcMatchCheckCtxt::variant_index_for_adt(&ctor, *adt));
-                        let tys = cx
-                            .list_variant_nonhidden_fields(ty, variant)
-                            .map(|(_, ty, skip)| (ty, SkipField(skip)));
+
+                        // In the cases of either a `#[non_exhaustive]` field list or a non-public
+                        // field, we skip uninhabited fields in order not to reveal the
+                        // uninhabitedness of the whole variant.
+                        let is_non_exhaustive =
+                            variant.is_field_list_non_exhaustive() && !adt.did().is_local();
+                        let tys = cx.variant_sub_tys(ty, variant).map(|(field, ty)| {
+                            let is_visible =
+                                adt.is_enum() || field.vis.is_accessible_from(cx.module, cx.tcx);
+                            let is_uninhabited = (cx.tcx.features().exhaustive_patterns
+                                || cx.tcx.features().min_exhaustive_patterns)
+                                && cx.is_uninhabited(*ty);
+                            let skip = is_uninhabited && (!is_visible || is_non_exhaustive);
+                            (ty, SkipField(skip))
+                        });
                         cx.dropless_arena.alloc_from_iter(tys)
                     }
                 }
@@ -269,9 +268,8 @@ impl<'p, 'tcx: 'p> RustcMatchCheckCtxt<'p, 'tcx> {
                         // patterns. If we're here we can assume this is a box pattern.
                         1
                     } else {
-                        let variant =
-                            &adt.variant(RustcMatchCheckCtxt::variant_index_for_adt(&ctor, *adt));
-                        self.list_variant_nonhidden_fields(ty, variant).count()
+                        let variant_idx = RustcMatchCheckCtxt::variant_index_for_adt(&ctor, *adt);
+                        adt.variant(variant_idx).fields.len()
                     }
                 }
                 _ => bug!("Unexpected type for constructor `{ctor:?}`: {ty:?}"),
@@ -507,20 +505,12 @@ impl<'p, 'tcx: 'p> RustcMatchCheckCtxt<'p, 'tcx> {
                         };
                         let variant =
                             &adt.variant(RustcMatchCheckCtxt::variant_index_for_adt(&ctor, *adt));
-                        // For each field in the variant, we store the relevant index into `self.fields` if any.
-                        let mut field_id_to_id: Vec<Option<usize>> =
-                            (0..variant.fields.len()).map(|_| None).collect();
-                        let tys = cx.list_variant_nonhidden_fields(ty, variant).enumerate().map(
-                            |(i, (field, ty, _))| {
-                                field_id_to_id[field.index()] = Some(i);
-                                ty
-                            },
-                        );
-                        fields = tys.map(|ty| DeconstructedPat::wildcard(ty)).collect();
+                        fields = cx
+                            .variant_sub_tys(ty, variant)
+                            .map(|(_, ty)| DeconstructedPat::wildcard(ty))
+                            .collect();
                         for pat in subpatterns {
-                            if let Some(i) = field_id_to_id[pat.field.index()] {
-                                fields[i] = self.lower_pat(&pat.pattern);
-                            }
+                            fields[pat.field.index()] = self.lower_pat(&pat.pattern);
                         }
                     }
                     _ => bug!("pattern has unexpected type: pat: {:?}, ty: {:?}", pat, ty),
@@ -762,11 +752,9 @@ impl<'p, 'tcx: 'p> RustcMatchCheckCtxt<'p, 'tcx> {
                 ty::Adt(adt_def, args) => {
                     let variant_index =
                         RustcMatchCheckCtxt::variant_index_for_adt(&pat.ctor(), *adt_def);
-                    let variant = &adt_def.variant(variant_index);
-                    let subpatterns = cx
-                        .list_variant_nonhidden_fields(*pat.ty(), variant)
-                        .zip(subpatterns)
-                        .map(|((field, _ty, _), pattern)| FieldPat { field, pattern })
+                    let subpatterns = subpatterns
+                        .enumerate()
+                        .map(|(i, pattern)| FieldPat { field: FieldIdx::new(i), pattern })
                         .collect();
 
                     if adt_def.is_enum() {