about summary refs log tree commit diff
path: root/compiler/rustc_mir_build/src/builder/matches/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_build/src/builder/matches/mod.rs')
-rw-r--r--compiler/rustc_mir_build/src/builder/matches/mod.rs144
1 files changed, 135 insertions, 9 deletions
diff --git a/compiler/rustc_mir_build/src/builder/matches/mod.rs b/compiler/rustc_mir_build/src/builder/matches/mod.rs
index 977d4f3e931..270a7d4b154 100644
--- a/compiler/rustc_mir_build/src/builder/matches/mod.rs
+++ b/compiler/rustc_mir_build/src/builder/matches/mod.rs
@@ -18,7 +18,9 @@ use rustc_middle::bug;
 use rustc_middle::middle::region;
 use rustc_middle::mir::{self, *};
 use rustc_middle::thir::{self, *};
-use rustc_middle::ty::{self, CanonicalUserTypeAnnotation, Ty};
+use rustc_middle::ty::{self, CanonicalUserTypeAnnotation, Ty, ValTree, ValTreeKind};
+use rustc_pattern_analysis::constructor::RangeEnd;
+use rustc_pattern_analysis::rustc::{DeconstructedPat, RustcPatCtxt};
 use rustc_span::{BytePos, Pos, Span, Symbol, sym};
 use tracing::{debug, instrument};
 
@@ -426,7 +428,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
     /// (by [Builder::lower_match_tree]).
     ///
     /// `outer_source_info` is the SourceInfo for the whole match.
-    fn lower_match_arms(
+    pub(crate) fn lower_match_arms(
         &mut self,
         destination: Place<'tcx>,
         scrutinee_place_builder: PlaceBuilder<'tcx>,
@@ -1395,7 +1397,7 @@ pub(crate) struct ArmHasGuard(pub(crate) bool);
 /// A sub-branch in the output of match lowering. Match lowering has generated MIR code that will
 /// branch to `success_block` when the matched value matches the corresponding pattern. If there is
 /// a guard, its failure must continue to `otherwise_block`, which will resume testing patterns.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct MatchTreeSubBranch<'tcx> {
     span: Span,
     /// The block that is branched to if the corresponding subpattern matches.
@@ -1411,7 +1413,7 @@ struct MatchTreeSubBranch<'tcx> {
 }
 
 /// A branch in the output of match lowering.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct MatchTreeBranch<'tcx> {
     sub_branches: Vec<MatchTreeSubBranch<'tcx>>,
 }
@@ -1430,8 +1432,8 @@ struct MatchTreeBranch<'tcx> {
 /// Here the first arm gives the first `MatchTreeBranch`, which has two sub-branches, one for each
 /// alternative of the or-pattern. They are kept separate because each needs to bind `x` to a
 /// different place.
-#[derive(Debug)]
-struct BuiltMatchTree<'tcx> {
+#[derive(Debug, Clone)]
+pub(crate) struct BuiltMatchTree<'tcx> {
     branches: Vec<MatchTreeBranch<'tcx>>,
     otherwise_block: BasicBlock,
     /// If any of the branches had a guard, we collect here the places and locals to fakely borrow
@@ -1489,7 +1491,7 @@ impl<'tcx> MatchTreeBranch<'tcx> {
 }
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
-enum HasMatchGuard {
+pub(crate) enum HasMatchGuard {
     Yes,
     No,
 }
@@ -1504,7 +1506,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
     /// `refutable` indicates whether the candidate list is refutable (for `if let` and `let else`)
     /// or not (for `let` and `match`). In the refutable case we return the block to which we branch
     /// on failure.
-    fn lower_match_tree(
+    pub(crate) fn lower_match_tree(
         &mut self,
         block: BasicBlock,
         scrutinee_span: Span,
@@ -1890,7 +1892,6 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
         debug!("expanding or-pattern: candidate={:#?}\npats={:#?}", candidate, pats);
         candidate.or_span = Some(match_pair.pattern_span);
         candidate.subcandidates = pats
-            .into_vec()
             .into_iter()
             .map(|flat_pat| Candidate::from_flat_pat(flat_pat, candidate.has_guard))
             .collect();
@@ -2864,4 +2865,129 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
 
         true
     }
+
+    /// Attempt to statically pick the `BasicBlock` that a value would resolve to at runtime.
+    pub(crate) fn static_pattern_match(
+        &self,
+        cx: &RustcPatCtxt<'_, 'tcx>,
+        valtree: ValTree<'tcx>,
+        arms: &[ArmId],
+        built_match_tree: &BuiltMatchTree<'tcx>,
+    ) -> Option<BasicBlock> {
+        let it = arms.iter().zip(built_match_tree.branches.iter());
+        for (&arm_id, branch) in it {
+            let pat = cx.lower_pat(&*self.thir.arms[arm_id].pattern);
+
+            // Peel off or-patterns if they exist.
+            if let rustc_pattern_analysis::rustc::Constructor::Or = pat.ctor() {
+                for pat in pat.iter_fields() {
+                    // For top-level or-patterns (the only ones we accept right now), when the
+                    // bindings are the same (e.g. there are none), the sub_branch is stored just
+                    // once.
+                    let sub_branch = branch
+                        .sub_branches
+                        .get(pat.idx)
+                        .or_else(|| branch.sub_branches.last())
+                        .unwrap();
+
+                    match self.static_pattern_match_inner(valtree, &pat.pat) {
+                        true => return Some(sub_branch.success_block),
+                        false => continue,
+                    }
+                }
+            } else if self.static_pattern_match_inner(valtree, &pat) {
+                return Some(branch.sub_branches[0].success_block);
+            }
+        }
+
+        None
+    }
+
+    /// Helper for [`Self::static_pattern_match`], checking whether the value represented by the
+    /// `ValTree` matches the given pattern. This function does not recurse, meaning that it does
+    /// not handle or-patterns, or patterns for types with fields.
+    fn static_pattern_match_inner(
+        &self,
+        valtree: ty::ValTree<'tcx>,
+        pat: &DeconstructedPat<'_, 'tcx>,
+    ) -> bool {
+        use rustc_pattern_analysis::constructor::{IntRange, MaybeInfiniteInt};
+        use rustc_pattern_analysis::rustc::Constructor;
+
+        match pat.ctor() {
+            Constructor::Variant(variant_index) => {
+                let ValTreeKind::Branch(box [actual_variant_idx]) = *valtree else {
+                    bug!("malformed valtree for an enum")
+                };
+
+                let ValTreeKind::Leaf(actual_variant_idx) = ***actual_variant_idx else {
+                    bug!("malformed valtree for an enum")
+                };
+
+                *variant_index == VariantIdx::from_u32(actual_variant_idx.to_u32())
+            }
+            Constructor::IntRange(int_range) => {
+                let size = pat.ty().primitive_size(self.tcx);
+                let actual_int = valtree.unwrap_leaf().to_bits(size);
+                let actual_int = if pat.ty().is_signed() {
+                    MaybeInfiniteInt::new_finite_int(actual_int, size.bits())
+                } else {
+                    MaybeInfiniteInt::new_finite_uint(actual_int)
+                };
+                IntRange::from_singleton(actual_int).is_subrange(int_range)
+            }
+            Constructor::Bool(pattern_value) => match valtree.unwrap_leaf().try_to_bool() {
+                Ok(actual_value) => *pattern_value == actual_value,
+                Err(()) => bug!("bool value with invalid bits"),
+            },
+            Constructor::F16Range(l, h, end) => {
+                let actual = valtree.unwrap_leaf().to_f16();
+                match end {
+                    RangeEnd::Included => (*l..=*h).contains(&actual),
+                    RangeEnd::Excluded => (*l..*h).contains(&actual),
+                }
+            }
+            Constructor::F32Range(l, h, end) => {
+                let actual = valtree.unwrap_leaf().to_f32();
+                match end {
+                    RangeEnd::Included => (*l..=*h).contains(&actual),
+                    RangeEnd::Excluded => (*l..*h).contains(&actual),
+                }
+            }
+            Constructor::F64Range(l, h, end) => {
+                let actual = valtree.unwrap_leaf().to_f64();
+                match end {
+                    RangeEnd::Included => (*l..=*h).contains(&actual),
+                    RangeEnd::Excluded => (*l..*h).contains(&actual),
+                }
+            }
+            Constructor::F128Range(l, h, end) => {
+                let actual = valtree.unwrap_leaf().to_f128();
+                match end {
+                    RangeEnd::Included => (*l..=*h).contains(&actual),
+                    RangeEnd::Excluded => (*l..*h).contains(&actual),
+                }
+            }
+            Constructor::Wildcard => true,
+
+            // These we may eventually support:
+            Constructor::Struct
+            | Constructor::Ref
+            | Constructor::DerefPattern(_)
+            | Constructor::Slice(_)
+            | Constructor::UnionField
+            | Constructor::Or
+            | Constructor::Str(_) => bug!("unsupported pattern constructor {:?}", pat.ctor()),
+
+            // These should never occur here:
+            Constructor::Opaque(_)
+            | Constructor::Never
+            | Constructor::NonExhaustive
+            | Constructor::Hidden
+            | Constructor::Missing
+            | Constructor::PrivateUninhabited => {
+                bug!("unsupported pattern constructor {:?}", pat.ctor())
+            }
+        }
+    }
 }