about summary refs log tree commit diff
path: root/compiler/rustc_pattern_analysis/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_pattern_analysis/src')
-rw-r--r--compiler/rustc_pattern_analysis/src/constructor.rs40
-rw-r--r--compiler/rustc_pattern_analysis/src/lib.rs49
-rw-r--r--compiler/rustc_pattern_analysis/src/rustc.rs5
-rw-r--r--compiler/rustc_pattern_analysis/src/usefulness.rs49
4 files changed, 77 insertions, 66 deletions
diff --git a/compiler/rustc_pattern_analysis/src/constructor.rs b/compiler/rustc_pattern_analysis/src/constructor.rs
index 95c5556410d..44f09b66bf6 100644
--- a/compiler/rustc_pattern_analysis/src/constructor.rs
+++ b/compiler/rustc_pattern_analysis/src/constructor.rs
@@ -140,6 +140,34 @@
 //! [`ConstructorSet::split`]. The invariants of [`SplitConstructorSet`] are also of interest.
 //!
 //!
+//! ## Unions
+//!
+//! Unions allow us to match a value via several overlapping representations at the same time. For
+//! example, the following is exhaustive because when seeing the value as a boolean we handled all
+//! possible cases (other cases such as `n == 3` would trigger UB).
+//!
+//! ```rust
+//! # fn main() {
+//! union U8AsBool {
+//!     n: u8,
+//!     b: bool,
+//! }
+//! let x = U8AsBool { n: 1 };
+//! unsafe {
+//!     match x {
+//!         U8AsBool { n: 2 } => {}
+//!         U8AsBool { b: true } => {}
+//!         U8AsBool { b: false } => {}
+//!     }
+//! }
+//! # }
+//! ```
+//!
+//! Pattern-matching has no knowledge that e.g. `false as u8 == 0`, so the values we consider in the
+//! algorithm look like `U8AsBool { b: true, n: 2 }`. In other words, for the most part a union is
+//! treated like a struct with the same fields. The difference lies in how we construct witnesses of
+//! non-exhaustiveness.
+//!
 //!
 //! ## Opaque patterns
 //!
@@ -155,13 +183,13 @@ use std::iter::once;
 use smallvec::SmallVec;
 
 use rustc_apfloat::ieee::{DoubleS, IeeeFloat, SingleS};
-use rustc_index::bit_set::GrowableBitSet;
+use rustc_index::bit_set::{BitSet, GrowableBitSet};
+use rustc_index::IndexVec;
 
 use self::Constructor::*;
 use self::MaybeInfiniteInt::*;
 use self::SliceKind::*;
 
-use crate::index;
 use crate::PatCx;
 
 /// Whether we have seen a constructor in the column or not.
@@ -920,10 +948,7 @@ pub enum ConstructorSet<Cx: PatCx> {
     Struct { empty: bool },
     /// This type has the following list of constructors. If `variants` is empty and
     /// `non_exhaustive` is false, don't use this; use `NoConstructors` instead.
-    Variants {
-        variants: index::IdxContainer<Cx::VariantIdx, VariantVisibility>,
-        non_exhaustive: bool,
-    },
+    Variants { variants: IndexVec<Cx::VariantIdx, VariantVisibility>, non_exhaustive: bool },
     /// The type is `&T`.
     Ref,
     /// The type is a union.
@@ -977,7 +1002,6 @@ impl<Cx: PatCx> ConstructorSet<Cx> {
     /// any) are missing; 2/ split constructors to handle non-trivial intersections e.g. on ranges
     /// or slices. This can get subtle; see [`SplitConstructorSet`] for details of this operation
     /// and its invariants.
-    #[instrument(level = "debug", skip(self, ctors), ret)]
     pub fn split<'a>(
         &self,
         ctors: impl Iterator<Item = &'a Constructor<Cx>> + Clone,
@@ -1025,7 +1049,7 @@ impl<Cx: PatCx> ConstructorSet<Cx> {
                 }
             }
             ConstructorSet::Variants { variants, non_exhaustive } => {
-                let mut seen_set = index::IdxSet::new_empty(variants.len());
+                let mut seen_set = BitSet::new_empty(variants.len());
                 for idx in seen.iter().filter_map(|c| c.as_variant()) {
                     seen_set.insert(idx);
                 }
diff --git a/compiler/rustc_pattern_analysis/src/lib.rs b/compiler/rustc_pattern_analysis/src/lib.rs
index 1a1da5c55f6..6e8843d9049 100644
--- a/compiler/rustc_pattern_analysis/src/lib.rs
+++ b/compiler/rustc_pattern_analysis/src/lib.rs
@@ -25,50 +25,9 @@ rustc_fluent_macro::fluent_messages! { "../messages.ftl" }
 
 use std::fmt;
 
-#[cfg(feature = "rustc")]
-pub mod index {
-    // Faster version when the indices of variants are `0..variants.len()`.
-    pub use rustc_index::bit_set::BitSet as IdxSet;
-    pub use rustc_index::Idx;
-    pub use rustc_index::IndexVec as IdxContainer;
-}
-#[cfg(not(feature = "rustc"))]
-pub mod index {
-    // Slower version when the indices of variants are something else.
-    pub trait Idx: Copy + PartialEq + Eq + std::hash::Hash {}
-    impl<T: Copy + PartialEq + Eq + std::hash::Hash> Idx for T {}
-
-    #[derive(Debug)]
-    pub struct IdxContainer<K, V>(pub rustc_hash::FxHashMap<K, V>);
-    impl<K: Idx, V> IdxContainer<K, V> {
-        pub fn len(&self) -> usize {
-            self.0.len()
-        }
-        pub fn iter_enumerated(&self) -> impl Iterator<Item = (K, &V)> {
-            self.0.iter().map(|(k, v)| (*k, v))
-        }
-    }
-
-    impl<V> FromIterator<V> for IdxContainer<usize, V> {
-        fn from_iter<T: IntoIterator<Item = V>>(iter: T) -> Self {
-            Self(iter.into_iter().enumerate().collect())
-        }
-    }
-
-    #[derive(Debug)]
-    pub struct IdxSet<T>(pub rustc_hash::FxHashSet<T>);
-    impl<T: Idx> IdxSet<T> {
-        pub fn new_empty(_len: usize) -> Self {
-            Self(Default::default())
-        }
-        pub fn contains(&self, elem: T) -> bool {
-            self.0.contains(&elem)
-        }
-        pub fn insert(&mut self, elem: T) {
-            self.0.insert(elem);
-        }
-    }
-}
+// Re-exports to avoid rustc_index version issues.
+pub use rustc_index::Idx;
+pub use rustc_index::IndexVec;
 
 #[cfg(feature = "rustc")]
 use rustc_middle::ty::Ty;
@@ -96,7 +55,7 @@ pub trait PatCx: Sized + fmt::Debug {
     /// Errors that can abort analysis.
     type Error: fmt::Debug;
     /// The index of an enum variant.
-    type VariantIdx: Clone + index::Idx + fmt::Debug;
+    type VariantIdx: Clone + Idx + fmt::Debug;
     /// A string literal
     type StrLit: Clone + PartialEq + fmt::Debug;
     /// Extra data to store in a match arm.
diff --git a/compiler/rustc_pattern_analysis/src/rustc.rs b/compiler/rustc_pattern_analysis/src/rustc.rs
index ae3eabe1745..548a7b43005 100644
--- a/compiler/rustc_pattern_analysis/src/rustc.rs
+++ b/compiler/rustc_pattern_analysis/src/rustc.rs
@@ -186,7 +186,6 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
 
     /// Returns the types of the fields for a given constructor. The result must have a length of
     /// `ctor.arity()`.
-    #[instrument(level = "trace", skip(self))]
     pub(crate) fn ctor_sub_tys<'a>(
         &'a self,
         ctor: &'a Constructor<'p, 'tcx>,
@@ -283,7 +282,6 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
     /// Creates a set that represents all the constructors of `ty`.
     ///
     /// See [`crate::constructor`] for considerations of emptiness.
-    #[instrument(level = "debug", skip(self), ret)]
     pub fn ctors_for_ty(
         &self,
         ty: RevealedTy<'tcx>,
@@ -398,9 +396,10 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
             ty::Float(_)
             | ty::Str
             | ty::Foreign(_)
-            | ty::RawPtr(_)
+            | ty::RawPtr(_, _)
             | ty::FnDef(_, _)
             | ty::FnPtr(_)
+            | ty::Pat(_, _)
             | ty::Dynamic(_, _, _)
             | ty::Closure(..)
             | ty::CoroutineClosure(..)
diff --git a/compiler/rustc_pattern_analysis/src/usefulness.rs b/compiler/rustc_pattern_analysis/src/usefulness.rs
index cdc03eaeb37..7246039174a 100644
--- a/compiler/rustc_pattern_analysis/src/usefulness.rs
+++ b/compiler/rustc_pattern_analysis/src/usefulness.rs
@@ -871,12 +871,14 @@ impl<Cx: PatCx> PlaceInfo<Cx> {
     where
         Cx: 'a,
     {
+        debug!(?self.ty);
         if self.private_uninhabited {
             // Skip the whole column
             return Ok((smallvec![Constructor::PrivateUninhabited], vec![]));
         }
 
         let ctors_for_ty = cx.ctors_for_ty(&self.ty)?;
+        debug!(?ctors_for_ty);
 
         // We treat match scrutinees of type `!` or `EmptyEnum` differently.
         let is_toplevel_exception =
@@ -895,6 +897,7 @@ impl<Cx: PatCx> PlaceInfo<Cx> {
 
         // Analyze the constructors present in this column.
         let mut split_set = ctors_for_ty.split(ctors);
+        debug!(?split_set);
         let all_missing = split_set.present.is_empty();
 
         // Build the set of constructors we will specialize with. It must cover the whole type, so
@@ -1254,7 +1257,7 @@ impl<'p, Cx: PatCx> Matrix<'p, Cx> {
 /// + true  + [Second(true)]    +
 /// + false + [_]               +
 /// + _     + [_, _, tail @ ..] +
-/// | ✓     | ?                 | // column validity
+/// | ✓     | ?                 | // validity
 /// ```
 impl<'p, Cx: PatCx> fmt::Debug for Matrix<'p, Cx> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
@@ -1285,7 +1288,7 @@ impl<'p, Cx: PatCx> fmt::Debug for Matrix<'p, Cx> {
                 write!(f, " {sep}")?;
             }
             if is_validity_row {
-                write!(f, " // column validity")?;
+                write!(f, " // validity")?;
             }
             write!(f, "\n")?;
         }
@@ -1381,12 +1384,35 @@ impl<Cx: PatCx> WitnessStack<Cx> {
     /// pats: [(false, "foo"), _, true]
     /// result: [Enum::Variant { a: (false, "foo"), b: _ }, true]
     /// ```
-    fn apply_constructor(&mut self, pcx: &PlaceCtxt<'_, Cx>, ctor: &Constructor<Cx>) {
+    fn apply_constructor(
+        mut self,
+        pcx: &PlaceCtxt<'_, Cx>,
+        ctor: &Constructor<Cx>,
+    ) -> SmallVec<[Self; 1]> {
         let len = self.0.len();
         let arity = pcx.ctor_arity(ctor);
-        let fields = self.0.drain((len - arity)..).rev().collect();
-        let pat = WitnessPat::new(ctor.clone(), fields, pcx.ty.clone());
-        self.0.push(pat);
+        let fields: Vec<_> = self.0.drain((len - arity)..).rev().collect();
+        if matches!(ctor, Constructor::UnionField)
+            && fields.iter().filter(|p| !matches!(p.ctor(), Constructor::Wildcard)).count() >= 2
+        {
+            // Convert a `Union { a: p, b: q }` witness into `Union { a: p }` and `Union { b: q }`.
+            // First add `Union { .. }` to `self`.
+            self.0.push(WitnessPat::wild_from_ctor(pcx.cx, ctor.clone(), pcx.ty.clone()));
+            fields
+                .into_iter()
+                .enumerate()
+                .filter(|(_, p)| !matches!(p.ctor(), Constructor::Wildcard))
+                .map(|(i, p)| {
+                    let mut ret = self.clone();
+                    // Fill the `i`th field of the union with `p`.
+                    ret.0.last_mut().unwrap().fields[i] = p;
+                    ret
+                })
+                .collect()
+        } else {
+            self.0.push(WitnessPat::new(ctor.clone(), fields, pcx.ty.clone()));
+            smallvec![self]
+        }
     }
 }
 
@@ -1459,8 +1485,8 @@ impl<Cx: PatCx> WitnessMatrix<Cx> {
             *self = ret;
         } else {
             // Any other constructor we unspecialize as expected.
-            for witness in self.0.iter_mut() {
-                witness.apply_constructor(pcx, ctor)
+            for witness in std::mem::take(&mut self.0) {
+                self.0.extend(witness.apply_constructor(pcx, ctor));
             }
         }
     }
@@ -1617,7 +1643,6 @@ fn compute_exhaustiveness_and_usefulness<'a, 'p, Cx: PatCx>(
     };
 
     // Analyze the constructors present in this column.
-    debug!("ty: {:?}", place.ty);
     let ctors = matrix.heads().map(|p| p.ctor());
     let (split_ctors, missing_ctors) = place.split_column_ctors(mcx.tycx, ctors)?;
 
@@ -1669,7 +1694,10 @@ fn compute_exhaustiveness_and_usefulness<'a, 'p, Cx: PatCx>(
     for row in matrix.rows() {
         if row.useful {
             if let PatOrWild::Pat(pat) = row.head() {
-                mcx.useful_subpatterns.insert(pat.uid);
+                let newly_useful = mcx.useful_subpatterns.insert(pat.uid);
+                if newly_useful {
+                    debug!("newly useful: {pat:?}");
+                }
             }
         }
     }
@@ -1768,6 +1796,7 @@ pub fn compute_match_usefulness<'p, Cx: PatCx>(
         .map(|arm| {
             debug!(?arm);
             let usefulness = collect_pattern_usefulness(&cx.useful_subpatterns, arm.pat);
+            debug!(?usefulness);
             (arm, usefulness)
         })
         .collect();