about summary refs log tree commit diff
diff options
context:
space:
mode:
authorEric Holk <ericholk@microsoft.com>2021-10-25 17:01:24 -0700
committerEric Holk <ericholk@microsoft.com>2022-01-18 14:25:24 -0800
commitf246c0b116cdbbad570c23c5745aa01f6f3f64a0 (patch)
tree632f368ab1d6bf21171f8051f1c7751c39d901bb
parentf664cfc47cdfaa83a1fd35e6e6a3fcdb692286ae (diff)
downloadrust-f246c0b116cdbbad570c23c5745aa01f6f3f64a0.tar.gz
rust-f246c0b116cdbbad570c23c5745aa01f6f3f64a0.zip
Attribute drop to parent expression of the consume point
This is needed to handle cases like `[a, b.await, c]`. `ExprUseVisitor`
considers `a` to be consumed when it is passed to the array, but the
array is not quite live yet at that point. This means we were missing
the `a` value across the await point. Attributing drops to the parent
expression means we do not consider the value consumed until the
consuming expression has finished.

Issue #57478
-rw-r--r--compiler/rustc_typeck/src/check/generator_interior.rs70
-rw-r--r--src/test/ui/async-await/unresolved_type_param.rs8
-rw-r--r--src/test/ui/async-await/unresolved_type_param.stderr26
-rw-r--r--src/test/ui/lint/must_not_suspend/dedup.rs2
-rw-r--r--src/test/ui/lint/must_not_suspend/dedup.stderr12
5 files changed, 87 insertions, 31 deletions
diff --git a/compiler/rustc_typeck/src/check/generator_interior.rs b/compiler/rustc_typeck/src/check/generator_interior.rs
index baeb78139ac..92dea92a0bc 100644
--- a/compiler/rustc_typeck/src/check/generator_interior.rs
+++ b/compiler/rustc_typeck/src/check/generator_interior.rs
@@ -6,7 +6,7 @@
 use crate::expr_use_visitor::{self, ExprUseVisitor};
 
 use super::FnCtxt;
-use hir::HirIdMap;
+use hir::{HirIdMap, Node};
 use rustc_data_structures::fx::{FxHashSet, FxIndexSet};
 use rustc_errors::pluralize;
 use rustc_hir as hir;
@@ -15,6 +15,7 @@ use rustc_hir::def_id::DefId;
 use rustc_hir::hir_id::HirIdSet;
 use rustc_hir::intravisit::{self, Visitor};
 use rustc_hir::{Arm, Expr, ExprKind, Guard, HirId, Pat, PatKind};
+use rustc_middle::hir::map::Map;
 use rustc_middle::hir::place::{Place, PlaceBase};
 use rustc_middle::middle::region::{self, YieldData};
 use rustc_middle::ty::{self, Ty, TyCtxt};
@@ -225,6 +226,7 @@ pub fn resolve_interior<'a, 'tcx>(
 
     let mut visitor = {
         let mut drop_range_visitor = DropRangeVisitor {
+            hir: fcx.tcx.hir(),
             consumed_places: <_>::default(),
             borrowed_places: <_>::default(),
             drop_ranges: vec![<_>::default()],
@@ -664,19 +666,28 @@ fn check_must_not_suspend_def(
 }
 
 /// This struct facilitates computing the ranges for which a place is uninitialized.
-struct DropRangeVisitor {
-    consumed_places: HirIdSet,
+struct DropRangeVisitor<'tcx> {
+    hir: Map<'tcx>,
+    /// Maps a HirId to a set of HirIds that are dropped by that node.
+    consumed_places: HirIdMap<HirIdSet>,
     borrowed_places: HirIdSet,
     drop_ranges: Vec<HirIdMap<DropRange>>,
     expr_count: usize,
 }
 
-impl DropRangeVisitor {
+impl DropRangeVisitor<'tcx> {
+    fn mark_consumed(&mut self, consumer: HirId, target: HirId) {
+        if !self.consumed_places.contains_key(&consumer) {
+            self.consumed_places.insert(consumer, <_>::default());
+        }
+        self.consumed_places.get_mut(&consumer).map(|places| places.insert(target));
+    }
+
     fn record_drop(&mut self, hir_id: HirId) {
         let drop_ranges = self.drop_ranges.last_mut().unwrap();
         if self.borrowed_places.contains(&hir_id) {
             debug!("not marking {:?} as dropped because it is borrowed at some point", hir_id);
-        } else if self.consumed_places.contains(&hir_id) {
+        } else {
             debug!("marking {:?} as dropped at {}", hir_id, self.expr_count);
             drop_ranges.insert(hir_id, DropRange { dropped_at: self.expr_count });
         }
@@ -700,15 +711,24 @@ impl DropRangeVisitor {
     /// ExprUseVisitor's consume callback doesn't go deep enough for our purposes in all
     /// expressions. This method consumes a little deeper into the expression when needed.
     fn consume_expr(&mut self, expr: &hir::Expr<'_>) {
-        self.record_drop(expr.hir_id);
-        match expr.kind {
-            hir::ExprKind::Path(hir::QPath::Resolved(
-                _,
-                hir::Path { res: hir::def::Res::Local(hir_id), .. },
-            )) => {
-                self.record_drop(*hir_id);
+        debug!("consuming expr {:?}, count={}", expr.hir_id, self.expr_count);
+        let places = self
+            .consumed_places
+            .get(&expr.hir_id)
+            .map_or(vec![], |places| places.iter().cloned().collect());
+        for place in places {
+            self.record_drop(place);
+            if let Some(Node::Expr(expr)) = self.hir.find(place) {
+                match expr.kind {
+                    hir::ExprKind::Path(hir::QPath::Resolved(
+                        _,
+                        hir::Path { res: hir::def::Res::Local(hir_id), .. },
+                    )) => {
+                        self.record_drop(*hir_id);
+                    }
+                    _ => (),
+                }
             }
-            _ => (),
         }
     }
 }
@@ -721,15 +741,19 @@ fn place_hir_id(place: &Place<'_>) -> Option<HirId> {
     }
 }
 
-impl<'tcx> expr_use_visitor::Delegate<'tcx> for DropRangeVisitor {
+impl<'tcx> expr_use_visitor::Delegate<'tcx> for DropRangeVisitor<'tcx> {
     fn consume(
         &mut self,
         place_with_id: &expr_use_visitor::PlaceWithHirId<'tcx>,
         diag_expr_id: hir::HirId,
     ) {
-        debug!("consume {:?}; diag_expr_id={:?}", place_with_id, diag_expr_id);
-        self.consumed_places.insert(place_with_id.hir_id);
-        place_hir_id(&place_with_id.place).map(|place| self.consumed_places.insert(place));
+        let parent = match self.hir.find_parent_node(place_with_id.hir_id) {
+            Some(parent) => parent,
+            None => place_with_id.hir_id,
+        };
+        debug!("consume {:?}; diag_expr_id={:?}, using parent {:?}", place_with_id, diag_expr_id, parent);
+        self.mark_consumed(parent, place_with_id.hir_id);
+        place_hir_id(&place_with_id.place).map(|place| self.mark_consumed(parent, place));
     }
 
     fn borrow(
@@ -757,7 +781,7 @@ impl<'tcx> expr_use_visitor::Delegate<'tcx> for DropRangeVisitor {
     }
 }
 
-impl<'tcx> Visitor<'tcx> for DropRangeVisitor {
+impl<'tcx> Visitor<'tcx> for DropRangeVisitor<'tcx> {
     type Map = intravisit::ErasedMap<'tcx>;
 
     fn nested_visit_map(&mut self) -> NestedVisitorMap<Self::Map> {
@@ -766,20 +790,20 @@ impl<'tcx> Visitor<'tcx> for DropRangeVisitor {
 
     fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
         match expr.kind {
-            ExprKind::AssignOp(_, lhs, rhs) => {
+            ExprKind::AssignOp(_op, lhs, rhs) => {
                 // These operations are weird because their order of evaluation depends on whether
                 // the operator is overloaded. In a perfect world, we'd just ask the type checker
                 // whether this is a method call, but we also need to match the expression IDs
                 // from RegionResolutionVisitor. RegionResolutionVisitor doesn't know the order,
                 // so it runs both orders and picks the most conservative. We'll mirror that here.
                 let mut old_count = self.expr_count;
-                intravisit::walk_expr(self, lhs);
-                intravisit::walk_expr(self, rhs);
+                self.visit_expr(lhs);
+                self.visit_expr(rhs);
 
                 self.push_drop_scope();
                 std::mem::swap(&mut old_count, &mut self.expr_count);
-                intravisit::walk_expr(self, rhs);
-                intravisit::walk_expr(self, lhs);
+                self.visit_expr(rhs);
+                self.visit_expr(lhs);
 
                 // We should have visited the same number of expressions in either order.
                 assert_eq!(old_count, self.expr_count);
diff --git a/src/test/ui/async-await/unresolved_type_param.rs b/src/test/ui/async-await/unresolved_type_param.rs
index 79c043b701d..d313691b388 100644
--- a/src/test/ui/async-await/unresolved_type_param.rs
+++ b/src/test/ui/async-await/unresolved_type_param.rs
@@ -8,8 +8,16 @@ async fn bar<T>() -> () {}
 async fn foo() {
     bar().await;
     //~^ ERROR type inside `async fn` body must be known in this context
+    //~| ERROR type inside `async fn` body must be known in this context
+    //~| ERROR type inside `async fn` body must be known in this context
     //~| NOTE cannot infer type for type parameter `T`
+    //~| NOTE cannot infer type for type parameter `T`
+    //~| NOTE cannot infer type for type parameter `T`
+    //~| NOTE the type is part of the `async fn` body because of this `await`
     //~| NOTE the type is part of the `async fn` body because of this `await`
+    //~| NOTE the type is part of the `async fn` body because of this `await`
+    //~| NOTE in this expansion of desugaring of `await`
+    //~| NOTE in this expansion of desugaring of `await`
     //~| NOTE in this expansion of desugaring of `await`
 }
 fn main() {}
diff --git a/src/test/ui/async-await/unresolved_type_param.stderr b/src/test/ui/async-await/unresolved_type_param.stderr
index 853e53ed69d..6a268bcda62 100644
--- a/src/test/ui/async-await/unresolved_type_param.stderr
+++ b/src/test/ui/async-await/unresolved_type_param.stderr
@@ -10,6 +10,30 @@ note: the type is part of the `async fn` body because of this `await`
 LL |     bar().await;
    |          ^^^^^^
 
-error: aborting due to previous error
+error[E0698]: type inside `async fn` body must be known in this context
+  --> $DIR/unresolved_type_param.rs:9:5
+   |
+LL |     bar().await;
+   |     ^^^ cannot infer type for type parameter `T` declared on the function `bar`
+   |
+note: the type is part of the `async fn` body because of this `await`
+  --> $DIR/unresolved_type_param.rs:9:5
+   |
+LL |     bar().await;
+   |     ^^^^^^^^^^^
+
+error[E0698]: type inside `async fn` body must be known in this context
+  --> $DIR/unresolved_type_param.rs:9:5
+   |
+LL |     bar().await;
+   |     ^^^ cannot infer type for type parameter `T` declared on the function `bar`
+   |
+note: the type is part of the `async fn` body because of this `await`
+  --> $DIR/unresolved_type_param.rs:9:5
+   |
+LL |     bar().await;
+   |     ^^^^^^^^^^^
+
+error: aborting due to 3 previous errors
 
 For more information about this error, try `rustc --explain E0698`.
diff --git a/src/test/ui/lint/must_not_suspend/dedup.rs b/src/test/ui/lint/must_not_suspend/dedup.rs
index 040fff5a5a5..81a08579bb7 100644
--- a/src/test/ui/lint/must_not_suspend/dedup.rs
+++ b/src/test/ui/lint/must_not_suspend/dedup.rs
@@ -13,7 +13,7 @@ async fn wheeee<T>(t: T) {
 }
 
 async fn yes() {
-    wheeee(No {}).await; //~ ERROR `No` held across
+    wheeee(&No {}).await; //~ ERROR `No` held across
 }
 
 fn main() {
diff --git a/src/test/ui/lint/must_not_suspend/dedup.stderr b/src/test/ui/lint/must_not_suspend/dedup.stderr
index bc1b611299a..d1513747452 100644
--- a/src/test/ui/lint/must_not_suspend/dedup.stderr
+++ b/src/test/ui/lint/must_not_suspend/dedup.stderr
@@ -1,8 +1,8 @@
 error: `No` held across a suspend point, but should not be
-  --> $DIR/dedup.rs:16:12
+  --> $DIR/dedup.rs:16:13
    |
-LL |     wheeee(No {}).await;
-   |            ^^^^^ ------ the value is held across this suspend point
+LL |     wheeee(&No {}).await;
+   |     --------^^^^^------- the value is held across this suspend point
    |
 note: the lint level is defined here
   --> $DIR/dedup.rs:3:9
@@ -10,10 +10,10 @@ note: the lint level is defined here
 LL | #![deny(must_not_suspend)]
    |         ^^^^^^^^^^^^^^^^
 help: consider using a block (`{ ... }`) to shrink the value's scope, ending before the suspend point
-  --> $DIR/dedup.rs:16:12
+  --> $DIR/dedup.rs:16:13
    |
-LL |     wheeee(No {}).await;
-   |            ^^^^^
+LL |     wheeee(&No {}).await;
+   |             ^^^^^
 
 error: aborting due to previous error