about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_hir/src/hir.rs14
-rw-r--r--compiler/rustc_typeck/src/check/generator_interior.rs134
-rw-r--r--compiler/rustc_typeck/src/expr_use_visitor.rs21
-rw-r--r--src/test/ui/generator/drop-tracking-yielding-in-match-guards.rs12
4 files changed, 87 insertions, 94 deletions
diff --git a/compiler/rustc_hir/src/hir.rs b/compiler/rustc_hir/src/hir.rs
index 81d544c7b96..57655365cca 100644
--- a/compiler/rustc_hir/src/hir.rs
+++ b/compiler/rustc_hir/src/hir.rs
@@ -1323,6 +1323,20 @@ pub enum Guard<'hir> {
     IfLet(&'hir Let<'hir>),
 }
 
+impl<'hir> Guard<'hir> {
+    /// Returns the body of the guard
+    ///
+    /// In other words, returns the e in either of the following:
+    ///
+    /// - `if e`
+    /// - `if let x = e`
+    pub fn body(&self) -> &'hir Expr<'hir> {
+        match self {
+            Guard::If(e) | Guard::IfLet(Let { init: e, .. }) => e,
+        }
+    }
+}
+
 #[derive(Debug, HashStable_Generic)]
 pub struct ExprField<'hir> {
     #[stable_hasher(ignore)]
diff --git a/compiler/rustc_typeck/src/check/generator_interior.rs b/compiler/rustc_typeck/src/check/generator_interior.rs
index 92a2584a6de..60d19405bcf 100644
--- a/compiler/rustc_typeck/src/check/generator_interior.rs
+++ b/compiler/rustc_typeck/src/check/generator_interior.rs
@@ -17,7 +17,6 @@ use rustc_middle::middle::region::{self, Scope, ScopeData, YieldData};
 use rustc_middle::ty::{self, Ty, TyCtxt};
 use rustc_span::symbol::sym;
 use rustc_span::Span;
-use smallvec::SmallVec;
 use tracing::debug;
 
 mod drop_ranges;
@@ -29,13 +28,6 @@ struct InteriorVisitor<'a, 'tcx> {
     expr_count: usize,
     kind: hir::GeneratorKind,
     prev_unresolved_span: Option<Span>,
-    /// Match arm guards have temporary borrows from the pattern bindings.
-    /// In case there is a yield point in a guard with a reference to such bindings,
-    /// such borrows can span across this yield point.
-    /// As such, we need to track these borrows and record them despite of the fact
-    /// that they may succeed the said yield point in the post-order.
-    guard_bindings: SmallVec<[SmallVec<[HirId; 4]>; 1]>,
-    guard_bindings_set: HirIdSet,
     linted_values: HirIdSet,
     drop_ranges: DropRanges,
 }
@@ -48,7 +40,6 @@ impl<'a, 'tcx> InteriorVisitor<'a, 'tcx> {
         scope: Option<region::Scope>,
         expr: Option<&'tcx Expr<'tcx>>,
         source_span: Span,
-        guard_borrowing_from_pattern: bool,
     ) {
         use rustc_span::DUMMY_SP;
 
@@ -89,8 +80,7 @@ impl<'a, 'tcx> InteriorVisitor<'a, 'tcx> {
                             // If it is a borrowing happening in the guard,
                             // it needs to be recorded regardless because they
                             // do live across this yield point.
-                            guard_borrowing_from_pattern
-                                || yield_data.expr_and_pat_count >= self.expr_count
+                            yield_data.expr_and_pat_count >= self.expr_count
                         })
                         .cloned()
                 })
@@ -196,8 +186,6 @@ pub fn resolve_interior<'a, 'tcx>(
         expr_count: 0,
         kind,
         prev_unresolved_span: None,
-        guard_bindings: <_>::default(),
-        guard_bindings_set: <_>::default(),
         linted_values: <_>::default(),
         drop_ranges: drop_ranges::compute_drop_ranges(fcx, def_id, body),
     };
@@ -284,15 +272,47 @@ impl<'a, 'tcx> Visitor<'tcx> for InteriorVisitor<'a, 'tcx> {
         let Arm { guard, pat, body, .. } = arm;
         self.visit_pat(pat);
         if let Some(ref g) = guard {
-            self.guard_bindings.push(<_>::default());
-            ArmPatCollector {
-                guard_bindings_set: &mut self.guard_bindings_set,
-                guard_bindings: self
-                    .guard_bindings
-                    .last_mut()
-                    .expect("should have pushed at least one earlier"),
+            {
+                // If there is a guard, we need to count all variables bound in the pattern as
+                // borrowed for the entire guard body, regardless of whether they are accessed.
+                // We do this by walking the pattern bindings and recording `&T` for any `x: T`
+                // that is bound.
+
+                struct ArmPatCollector<'a, 'b, 'tcx> {
+                    interior_visitor: &'a mut InteriorVisitor<'b, 'tcx>,
+                    scope: Scope,
+                }
+
+                impl<'a, 'b, 'tcx> Visitor<'tcx> for ArmPatCollector<'a, 'b, 'tcx> {
+                    fn visit_pat(&mut self, pat: &'tcx Pat<'tcx>) {
+                        intravisit::walk_pat(self, pat);
+                        if let PatKind::Binding(_, id, ident, ..) = pat.kind {
+                            let ty =
+                                self.interior_visitor.fcx.typeck_results.borrow().node_type(id);
+                            let tcx = self.interior_visitor.fcx.tcx;
+                            let ty = tcx.mk_ref(
+                                // Use `ReErased` as `resolve_interior` is going to replace all the
+                                // regions anyway.
+                                tcx.mk_region(ty::ReErased),
+                                ty::TypeAndMut { ty, mutbl: hir::Mutability::Not },
+                            );
+                            self.interior_visitor.record(
+                                ty,
+                                id,
+                                Some(self.scope),
+                                None,
+                                ident.span,
+                            );
+                        }
+                    }
+                }
+
+                ArmPatCollector {
+                    interior_visitor: self,
+                    scope: Scope { id: g.body().hir_id.local_id, data: ScopeData::Node },
+                }
+                .visit_pat(pat);
             }
-            .visit_pat(pat);
 
             match g {
                 Guard::If(ref e) => {
@@ -302,12 +322,6 @@ impl<'a, 'tcx> Visitor<'tcx> for InteriorVisitor<'a, 'tcx> {
                     self.visit_let_expr(l);
                 }
             }
-
-            let mut scope_var_ids =
-                self.guard_bindings.pop().expect("should have pushed at least one earlier");
-            for var_id in scope_var_ids.drain(..) {
-                self.guard_bindings_set.remove(&var_id);
-            }
         }
         self.visit_expr(body);
     }
@@ -320,13 +334,11 @@ impl<'a, 'tcx> Visitor<'tcx> for InteriorVisitor<'a, 'tcx> {
         if let PatKind::Binding(..) = pat.kind {
             let scope = self.region_scope_tree.var_scope(pat.hir_id.local_id).unwrap();
             let ty = self.fcx.typeck_results.borrow().pat_ty(pat);
-            self.record(ty, pat.hir_id, Some(scope), None, pat.span, false);
+            self.record(ty, pat.hir_id, Some(scope), None, pat.span);
         }
     }
 
     fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
-        let mut guard_borrowing_from_pattern = false;
-
         match &expr.kind {
             ExprKind::Call(callee, args) => match &callee.kind {
                 ExprKind::Path(qpath) => {
@@ -353,16 +365,6 @@ impl<'a, 'tcx> Visitor<'tcx> for InteriorVisitor<'a, 'tcx> {
                 }
                 _ => intravisit::walk_expr(self, expr),
             },
-            ExprKind::Path(qpath) => {
-                intravisit::walk_expr(self, expr);
-                let res = self.fcx.typeck_results.borrow().qpath_res(qpath, expr.hir_id);
-                match res {
-                    Res::Local(id) if self.guard_bindings_set.contains(&id) => {
-                        guard_borrowing_from_pattern = true;
-                    }
-                    _ => {}
-                }
-            }
             _ => intravisit::walk_expr(self, expr),
         }
 
@@ -391,14 +393,7 @@ impl<'a, 'tcx> Visitor<'tcx> for InteriorVisitor<'a, 'tcx> {
         // If there are adjustments, then record the final type --
         // this is the actual value that is being produced.
         if let Some(adjusted_ty) = self.fcx.typeck_results.borrow().expr_ty_adjusted_opt(expr) {
-            self.record(
-                adjusted_ty,
-                expr.hir_id,
-                scope,
-                Some(expr),
-                expr.span,
-                guard_borrowing_from_pattern,
-            );
+            self.record(adjusted_ty, expr.hir_id, scope, Some(expr), expr.span);
         }
 
         // Also record the unadjusted type (which is the only type if
@@ -426,54 +421,13 @@ impl<'a, 'tcx> Visitor<'tcx> for InteriorVisitor<'a, 'tcx> {
         // The type table might not have information for this expression
         // if it is in a malformed scope. (#66387)
         if let Some(ty) = self.fcx.typeck_results.borrow().expr_ty_opt(expr) {
-            if guard_borrowing_from_pattern {
-                // Match guards create references to all the bindings in the pattern that are used
-                // in the guard, e.g. `y if is_even(y) => ...` becomes `is_even(*r_y)` where `r_y`
-                // is a reference to `y`, so we must record a reference to the type of the binding.
-                let tcx = self.fcx.tcx;
-                let ref_ty = tcx.mk_ref(
-                    // Use `ReErased` as `resolve_interior` is going to replace all the regions anyway.
-                    tcx.mk_region(ty::ReErased),
-                    ty::TypeAndMut { ty, mutbl: hir::Mutability::Not },
-                );
-                self.record(
-                    ref_ty,
-                    expr.hir_id,
-                    scope,
-                    Some(expr),
-                    expr.span,
-                    guard_borrowing_from_pattern,
-                );
-            }
-            self.record(
-                ty,
-                expr.hir_id,
-                scope,
-                Some(expr),
-                expr.span,
-                guard_borrowing_from_pattern,
-            );
+            self.record(ty, expr.hir_id, scope, Some(expr), expr.span);
         } else {
             self.fcx.tcx.sess.delay_span_bug(expr.span, "no type for node");
         }
     }
 }
 
-struct ArmPatCollector<'a> {
-    guard_bindings_set: &'a mut HirIdSet,
-    guard_bindings: &'a mut SmallVec<[HirId; 4]>,
-}
-
-impl<'a, 'tcx> Visitor<'tcx> for ArmPatCollector<'a> {
-    fn visit_pat(&mut self, pat: &'tcx Pat<'tcx>) {
-        intravisit::walk_pat(self, pat);
-        if let PatKind::Binding(_, id, ..) = pat.kind {
-            self.guard_bindings.push(id);
-            self.guard_bindings_set.insert(id);
-        }
-    }
-}
-
 #[derive(Default)]
 pub struct SuspendCheckData<'a, 'tcx> {
     expr: Option<&'tcx Expr<'tcx>>,
diff --git a/compiler/rustc_typeck/src/expr_use_visitor.rs b/compiler/rustc_typeck/src/expr_use_visitor.rs
index 6de6b6ee479..ad44adb68c6 100644
--- a/compiler/rustc_typeck/src/expr_use_visitor.rs
+++ b/compiler/rustc_typeck/src/expr_use_visitor.rs
@@ -17,6 +17,7 @@ use rustc_middle::hir::place::ProjectionKind;
 use rustc_middle::mir::FakeReadCause;
 use rustc_middle::ty::{self, adjustment, AdtKind, Ty, TyCtxt};
 use rustc_target::abi::VariantIdx;
+use ty::BorrowKind::ImmBorrow;
 
 use crate::mem_categorization as mc;
 
@@ -621,7 +622,7 @@ impl<'a, 'tcx> ExprUseVisitor<'a, 'tcx> {
             FakeReadCause::ForMatchedPlace(closure_def_id),
             discr_place.hir_id,
         );
-        self.walk_pat(discr_place, arm.pat);
+        self.walk_pat(discr_place, arm.pat, arm.guard.is_some());
 
         if let Some(hir::Guard::If(e)) = arm.guard {
             self.consume_expr(e)
@@ -645,12 +646,17 @@ impl<'a, 'tcx> ExprUseVisitor<'a, 'tcx> {
             FakeReadCause::ForLet(closure_def_id),
             discr_place.hir_id,
         );
-        self.walk_pat(discr_place, pat);
+        self.walk_pat(discr_place, pat, false);
     }
 
     /// The core driver for walking a pattern
-    fn walk_pat(&mut self, discr_place: &PlaceWithHirId<'tcx>, pat: &hir::Pat<'_>) {
-        debug!("walk_pat(discr_place={:?}, pat={:?})", discr_place, pat);
+    fn walk_pat(
+        &mut self,
+        discr_place: &PlaceWithHirId<'tcx>,
+        pat: &hir::Pat<'_>,
+        has_guard: bool,
+    ) {
+        debug!("walk_pat(discr_place={:?}, pat={:?}, has_guard={:?})", discr_place, pat, has_guard);
 
         let tcx = self.tcx();
         let ExprUseVisitor { ref mc, body_owner: _, ref mut delegate } = *self;
@@ -671,6 +677,13 @@ impl<'a, 'tcx> ExprUseVisitor<'a, 'tcx> {
                         delegate.bind(binding_place, binding_place.hir_id);
                     }
 
+                    // Subtle: MIR desugaring introduces immutable borrows for each pattern
+                    // binding when lowering pattern guards to ensure that the guard does not
+                    // modify the scrutinee.
+                    if has_guard {
+                        delegate.borrow(place, discr_place.hir_id, ImmBorrow);
+                    }
+
                     // It is also a borrow or copy/move of the value being matched.
                     // In a cases of pattern like `let pat = upvar`, don't use the span
                     // of the pattern, as this just looks confusing, instead use the span
diff --git a/src/test/ui/generator/drop-tracking-yielding-in-match-guards.rs b/src/test/ui/generator/drop-tracking-yielding-in-match-guards.rs
new file mode 100644
index 00000000000..646365e4359
--- /dev/null
+++ b/src/test/ui/generator/drop-tracking-yielding-in-match-guards.rs
@@ -0,0 +1,12 @@
+// build-pass
+// edition:2018
+// compile-flags: -Zdrop-tracking
+
+#![feature(generators)]
+
+fn main() {
+    let _ = static |x: u8| match x {
+        y if { yield } == y + 1 => (),
+        _ => (),
+    };
+}