about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock1
-rw-r--r--compiler/rustc_typeck/Cargo.toml1
-rw-r--r--compiler/rustc_typeck/src/check/generator_interior.rs224
-rw-r--r--src/test/ui/generator/drop-if.rs22
4 files changed, 220 insertions, 28 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 529e17b158f..dfe3db5907a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4383,6 +4383,7 @@ dependencies = [
 name = "rustc_typeck"
 version = "0.0.0"
 dependencies = [
+ "itertools 0.9.0",
  "rustc_arena",
  "rustc_ast",
  "rustc_attr",
diff --git a/compiler/rustc_typeck/Cargo.toml b/compiler/rustc_typeck/Cargo.toml
index 7e570e151c5..1da106d7002 100644
--- a/compiler/rustc_typeck/Cargo.toml
+++ b/compiler/rustc_typeck/Cargo.toml
@@ -8,6 +8,7 @@ test = false
 doctest = false
 
 [dependencies]
+itertools = "0.9"
 rustc_arena = { path = "../rustc_arena" }
 tracing = "0.1"
 rustc_macros = { path = "../rustc_macros" }
diff --git a/compiler/rustc_typeck/src/check/generator_interior.rs b/compiler/rustc_typeck/src/check/generator_interior.rs
index 92dea92a0bc..6144cbbd8dd 100644
--- a/compiler/rustc_typeck/src/check/generator_interior.rs
+++ b/compiler/rustc_typeck/src/check/generator_interior.rs
@@ -3,10 +3,13 @@
 //! is calculated in `rustc_const_eval::transform::generator` and may be a subset of the
 //! types computed here.
 
+use std::mem;
+
 use crate::expr_use_visitor::{self, ExprUseVisitor};
 
 use super::FnCtxt;
 use hir::{HirIdMap, Node};
+use itertools::Itertools;
 use rustc_data_structures::fx::{FxHashSet, FxIndexSet};
 use rustc_errors::pluralize;
 use rustc_hir as hir;
@@ -24,6 +27,9 @@ use rustc_span::Span;
 use smallvec::SmallVec;
 use tracing::debug;
 
+#[cfg(test)]
+mod tests;
+
 struct InteriorVisitor<'a, 'tcx> {
     fcx: &'a FnCtxt<'a, 'tcx>,
     types: FxIndexSet<ty::GeneratorInteriorTypeCause<'tcx>>,
@@ -80,7 +86,9 @@ impl<'a, 'tcx> InteriorVisitor<'a, 'tcx> {
                             );
 
                             match self.drop_ranges.get(&hir_id) {
-                                Some(range) if range.contains(yield_data.expr_and_pat_count) => {
+                                Some(range)
+                                    if range.is_dropped_at(yield_data.expr_and_pat_count) =>
+                                {
                                     debug!("value is dropped at yield point; not recording");
                                     return false;
                                 }
@@ -229,7 +237,7 @@ pub fn resolve_interior<'a, 'tcx>(
             hir: fcx.tcx.hir(),
             consumed_places: <_>::default(),
             borrowed_places: <_>::default(),
-            drop_ranges: vec![<_>::default()],
+            drop_ranges: <_>::default(),
             expr_count: 0,
         };
 
@@ -254,7 +262,7 @@ pub fn resolve_interior<'a, 'tcx>(
             guard_bindings: <_>::default(),
             guard_bindings_set: <_>::default(),
             linted_values: <_>::default(),
-            drop_ranges: drop_range_visitor.drop_ranges.pop().unwrap(),
+            drop_ranges: drop_range_visitor.drop_ranges,
         }
     };
     intravisit::walk_body(&mut visitor, body);
@@ -671,7 +679,7 @@ struct DropRangeVisitor<'tcx> {
     /// Maps a HirId to a set of HirIds that are dropped by that node.
     consumed_places: HirIdMap<HirIdSet>,
     borrowed_places: HirIdSet,
-    drop_ranges: Vec<HirIdMap<DropRange>>,
+    drop_ranges: HirIdMap<DropRange>,
     expr_count: usize,
 }
 
@@ -684,28 +692,42 @@ impl DropRangeVisitor<'tcx> {
     }
 
     fn record_drop(&mut self, hir_id: HirId) {
-        let drop_ranges = self.drop_ranges.last_mut().unwrap();
+        let drop_ranges = &mut self.drop_ranges;
         if self.borrowed_places.contains(&hir_id) {
             debug!("not marking {:?} as dropped because it is borrowed at some point", hir_id);
         } else {
             debug!("marking {:?} as dropped at {}", hir_id, self.expr_count);
-            drop_ranges.insert(hir_id, DropRange { dropped_at: self.expr_count });
+            drop_ranges.insert(hir_id, DropRange::new(self.expr_count));
         }
     }
 
-    fn push_drop_scope(&mut self) {
-        self.drop_ranges.push(<_>::default());
+    fn swap_drop_ranges(&mut self, mut other: HirIdMap<DropRange>) -> HirIdMap<DropRange> {
+        mem::swap(&mut self.drop_ranges, &mut other);
+        other
     }
 
-    fn pop_and_merge_drop_scope(&mut self) {
-        let mut old_last = self.drop_ranges.pop().unwrap();
-        let drop_ranges = self.drop_ranges.last_mut().unwrap();
-        for (k, v) in old_last.drain() {
-            match drop_ranges.get(&k).cloned() {
-                Some(v2) => drop_ranges.insert(k, v.intersect(&v2)),
-                None => drop_ranges.insert(k, v),
-            };
-        }
+    #[allow(dead_code)]
+    fn fork_drop_ranges(&self) -> HirIdMap<DropRange> {
+        self.drop_ranges.iter().map(|(k, v)| (*k, v.fork_at(self.expr_count))).collect()
+    }
+
+    fn intersect_drop_ranges(&mut self, drops: HirIdMap<DropRange>) {
+        drops.into_iter().for_each(|(k, v)| match self.drop_ranges.get_mut(&k) {
+            Some(ranges) => *ranges = ranges.intersect(&v),
+            None => {
+                self.drop_ranges.insert(k, v);
+            }
+        })
+    }
+
+    #[allow(dead_code)]
+    fn merge_drop_ranges(&mut self, drops: HirIdMap<DropRange>) {
+        drops.into_iter().for_each(|(k, v)| {
+            if !self.drop_ranges.contains_key(&k) {
+                self.drop_ranges.insert(k, DropRange { events: vec![] });
+            }
+            self.drop_ranges.get_mut(&k).unwrap().merge_with(&v, self.expr_count);
+        });
     }
 
     /// ExprUseVisitor's consume callback doesn't go deep enough for our purposes in all
@@ -751,7 +773,10 @@ impl<'tcx> expr_use_visitor::Delegate<'tcx> for DropRangeVisitor<'tcx> {
             Some(parent) => parent,
             None => place_with_id.hir_id,
         };
-        debug!("consume {:?}; diag_expr_id={:?}, using parent {:?}", place_with_id, diag_expr_id, parent);
+        debug!(
+            "consume {:?}; diag_expr_id={:?}, using parent {:?}",
+            place_with_id, diag_expr_id, parent
+        );
         self.mark_consumed(parent, place_with_id.hir_id);
         place_hir_id(&place_with_id.place).map(|place| self.mark_consumed(parent, place));
     }
@@ -800,7 +825,7 @@ impl<'tcx> Visitor<'tcx> for DropRangeVisitor<'tcx> {
                 self.visit_expr(lhs);
                 self.visit_expr(rhs);
 
-                self.push_drop_scope();
+                let old_drops = self.swap_drop_ranges(<_>::default());
                 std::mem::swap(&mut old_count, &mut self.expr_count);
                 self.visit_expr(rhs);
                 self.visit_expr(lhs);
@@ -808,7 +833,39 @@ impl<'tcx> Visitor<'tcx> for DropRangeVisitor<'tcx> {
                 // We should have visited the same number of expressions in either order.
                 assert_eq!(old_count, self.expr_count);
 
-                self.pop_and_merge_drop_scope();
+                self.intersect_drop_ranges(old_drops);
+            }
+            ExprKind::If(test, if_true, if_false) => {
+                self.visit_expr(test);
+
+                match if_false {
+                    Some(if_false) => {
+                        let mut true_ranges = self.fork_drop_ranges();
+                        let mut false_ranges = self.fork_drop_ranges();
+
+                        true_ranges = self.swap_drop_ranges(true_ranges);
+                        self.visit_expr(if_true);
+                        true_ranges = self.swap_drop_ranges(true_ranges);
+
+                        false_ranges = self.swap_drop_ranges(false_ranges);
+                        self.visit_expr(if_false);
+                        false_ranges = self.swap_drop_ranges(false_ranges);
+
+                        self.merge_drop_ranges(true_ranges);
+                        self.merge_drop_ranges(false_ranges);
+                    }
+                    None => {
+                        let mut true_ranges = self.fork_drop_ranges();
+                        debug!("true branch drop range fork: {:?}", true_ranges);
+                        true_ranges = self.swap_drop_ranges(true_ranges);
+                        self.visit_expr(if_true);
+                        true_ranges = self.swap_drop_ranges(true_ranges);
+                        debug!("true branch computed drop_ranges: {:?}", true_ranges);
+                        debug!("drop ranges before merging: {:?}", self.drop_ranges);
+                        self.merge_drop_ranges(true_ranges);
+                        debug!("drop ranges after merging: {:?}", self.drop_ranges);
+                    }
+                }
             }
             _ => intravisit::walk_expr(self, expr),
         }
@@ -825,20 +882,131 @@ impl<'tcx> Visitor<'tcx> for DropRangeVisitor<'tcx> {
     }
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug, PartialEq, Eq)]
+enum Event {
+    Drop(usize),
+    Reinit(usize),
+}
+
+impl Event {
+    fn location(&self) -> usize {
+        match *self {
+            Event::Drop(i) | Event::Reinit(i) => i,
+        }
+    }
+}
+
+impl PartialOrd for Event {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        self.location().partial_cmp(&other.location())
+    }
+}
+
+#[derive(Clone, Debug, PartialEq, Eq)]
 struct DropRange {
-    /// The post-order id of the point where this expression is dropped.
-    ///
-    /// We can consider the value dropped at any post-order id greater than dropped_at.
-    dropped_at: usize,
+    events: Vec<Event>,
 }
 
 impl DropRange {
+    fn new(begin: usize) -> Self {
+        Self { events: vec![Event::Drop(begin)] }
+    }
+
     fn intersect(&self, other: &Self) -> Self {
-        Self { dropped_at: self.dropped_at.max(other.dropped_at) }
+        let mut events = vec![];
+        self.events
+            .iter()
+            .merge_join_by(other.events.iter(), |a, b| a.partial_cmp(b).unwrap())
+            .fold((false, false), |(left, right), event| match event {
+                itertools::EitherOrBoth::Both(_, _) => todo!(),
+                itertools::EitherOrBoth::Left(e) => match e {
+                    Event::Drop(i) => {
+                        if !left && right {
+                            events.push(Event::Drop(*i));
+                        }
+                        (true, right)
+                    }
+                    Event::Reinit(i) => {
+                        if left && !right {
+                            events.push(Event::Reinit(*i));
+                        }
+                        (false, right)
+                    }
+                },
+                itertools::EitherOrBoth::Right(e) => match e {
+                    Event::Drop(i) => {
+                        if left && !right {
+                            events.push(Event::Drop(*i));
+                        }
+                        (left, true)
+                    }
+                    Event::Reinit(i) => {
+                        if !left && right {
+                            events.push(Event::Reinit(*i));
+                        }
+                        (left, false)
+                    }
+                },
+            });
+        Self { events }
+    }
+
+    fn is_dropped_at(&self, id: usize) -> bool {
+        match self.events.iter().try_fold(false, |is_dropped, event| {
+            if event.location() < id {
+                Ok(match event {
+                    Event::Drop(_) => true,
+                    Event::Reinit(_) => false,
+                })
+            } else {
+                Err(is_dropped)
+            }
+        }) {
+            Ok(is_dropped) | Err(is_dropped) => is_dropped,
+        }
+    }
+
+    #[allow(dead_code)]
+    fn drop(&mut self, location: usize) {
+        self.events.push(Event::Drop(location))
+    }
+
+    #[allow(dead_code)]
+    fn reinit(&mut self, location: usize) {
+        self.events.push(Event::Reinit(location));
+    }
+
+    /// Merges another range with this one. Meant to be used at control flow join points.
+    ///
+    /// After merging, the value will be dead at the end of the range only if it was dead
+    /// at the end of both self and other.
+    ///
+    /// Assumes that all locations in each range are less than joinpoint
+    #[allow(dead_code)]
+    fn merge_with(&mut self, other: &DropRange, join_point: usize) {
+        let mut events: Vec<_> =
+            self.events.iter().merge(other.events.iter()).dedup().cloned().collect();
+
+        events.push(if self.is_dropped_at(join_point) && other.is_dropped_at(join_point) {
+            Event::Drop(join_point)
+        } else {
+            Event::Reinit(join_point)
+        });
+
+        self.events = events;
     }
 
-    fn contains(&self, id: usize) -> bool {
-        id > self.dropped_at
+    /// Creates a new DropRange from this one at the split point.
+    ///
+    /// Used to model branching control flow.
+    #[allow(dead_code)]
+    fn fork_at(&self, split_point: usize) -> Self {
+        Self {
+            events: vec![if self.is_dropped_at(split_point) {
+                Event::Drop(split_point)
+            } else {
+                Event::Reinit(split_point)
+            }],
+        }
     }
 }
diff --git a/src/test/ui/generator/drop-if.rs b/src/test/ui/generator/drop-if.rs
new file mode 100644
index 00000000000..40f01f78662
--- /dev/null
+++ b/src/test/ui/generator/drop-if.rs
@@ -0,0 +1,22 @@
+// build-pass
+
+// This test case is reduced from src/test/ui/drop/dynamic-drop-async.rs
+
+#![feature(generators)]
+
+struct Ptr;
+impl<'a> Drop for Ptr {
+    fn drop(&mut self) {
+    }
+}
+
+fn main() {
+    let arg = true;
+    let _ = || {
+        let arr = [Ptr];
+        if arg {
+            drop(arr);
+        }
+        yield
+    };
+}