about summary refs log tree commit diff
diff options
context:
space:
mode:
authorDylan MacKenzie <ecstaticmorse@gmail.com>2020-06-10 14:24:28 -0700
committerDylan MacKenzie <ecstaticmorse@gmail.com>2020-06-19 10:33:09 -0700
commit8fc2eeb1c81aef42855257adc11791dcffe803cb (patch)
treeb264e358cf7dc507a1f8cb605f31d6d3ae5e90ef
parent72417d84fb51495a4f1d007fb2397a0b2609ab63 (diff)
downloadrust-8fc2eeb1c81aef42855257adc11791dcffe803cb.tar.gz
rust-8fc2eeb1c81aef42855257adc11791dcffe803cb.zip
Use newtype to map from `Local` to `GeneratorSavedLocal`
-rw-r--r--src/librustc_mir/transform/generator.rs108
1 files changed, 64 insertions, 44 deletions
diff --git a/src/librustc_mir/transform/generator.rs b/src/librustc_mir/transform/generator.rs
index c8702eeae1d..1445315c6b1 100644
--- a/src/librustc_mir/transform/generator.rs
+++ b/src/librustc_mir/transform/generator.rs
@@ -72,7 +72,7 @@ use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
 use rustc_target::abi::VariantIdx;
 use rustc_target::spec::PanicStrategy;
 use std::borrow::Cow;
-use std::iter;
+use std::{iter, ops};
 
 pub struct StateTransform;
 
@@ -417,11 +417,7 @@ fn replace_local<'tcx>(
 
 struct LivenessInfo {
     /// Which locals are live across any suspension point.
-    ///
-    /// GeneratorSavedLocal is indexed in terms of the elements in this set;
-    /// i.e. GeneratorSavedLocal::new(1) corresponds to the second local
-    /// included in this set.
-    live_locals: BitSet<Local>,
+    saved_locals: GeneratorSavedLocals,
 
     /// The set of saved locals live at each suspension point.
     live_locals_at_suspension_points: Vec<BitSet<GeneratorSavedLocal>>,
@@ -524,49 +520,75 @@ fn locals_live_across_suspend_points(
             live_locals_at_suspension_points.push(live_locals);
         }
     }
+
     debug!("live_locals_anywhere = {:?}", live_locals_at_any_suspension_point);
+    let saved_locals = GeneratorSavedLocals(live_locals_at_any_suspension_point);
 
     // Renumber our liveness_map bitsets to include only the locals we are
     // saving.
     let live_locals_at_suspension_points = live_locals_at_suspension_points
         .iter()
-        .map(|live_here| renumber_bitset(&live_here, &live_locals_at_any_suspension_point))
+        .map(|live_here| saved_locals.renumber_bitset(&live_here))
         .collect();
 
     let storage_conflicts = compute_storage_conflicts(
         body_ref,
-        &live_locals_at_any_suspension_point,
+        &saved_locals,
         always_live_locals.clone(),
         requires_storage_results,
     );
 
     LivenessInfo {
-        live_locals: live_locals_at_any_suspension_point,
+        saved_locals,
         live_locals_at_suspension_points,
         storage_conflicts,
         storage_liveness: storage_liveness_map,
     }
 }
 
-/// Renumbers the items present in `stored_locals` and applies the renumbering
-/// to 'input`.
+/// The set of `Local`s that must be saved across yield points.
 ///
-/// For example, if `stored_locals = [1, 3, 5]`, this would be renumbered to
-/// `[0, 1, 2]`. Thus, if `input = [3, 5]` we would return `[1, 2]`.
-fn renumber_bitset(
-    input: &BitSet<Local>,
-    stored_locals: &BitSet<Local>,
-) -> BitSet<GeneratorSavedLocal> {
-    assert!(stored_locals.superset(&input), "{:?} not a superset of {:?}", stored_locals, input);
-    let mut out = BitSet::new_empty(stored_locals.count());
-    for (idx, local) in stored_locals.iter().enumerate() {
-        let saved_local = GeneratorSavedLocal::from(idx);
-        if input.contains(local) {
-            out.insert(saved_local);
+/// `GeneratorSavedLocal` is indexed in terms of the elements in this set;
+/// i.e. `GeneratorSavedLocal::new(1)` corresponds to the second local
+/// included in this set.
+struct GeneratorSavedLocals(BitSet<Local>);
+
+impl GeneratorSavedLocals {
+    /// Returns an iterator over each `GeneratorSavedLocal` along with the `Local` it corresponds
+    /// to.
+    fn iter_enumerated(&self) -> impl '_ + Iterator<Item = (GeneratorSavedLocal, Local)> {
+        self.iter().enumerate().map(|(i, l)| (GeneratorSavedLocal::from(i), l))
+    }
+
+    /// Transforms a `BitSet<Local>` that contains only locals saved across yield points to the
+    /// equivalent `BitSet<GeneratorSavedLocal>`.
+    fn renumber_bitset(&self, input: &BitSet<Local>) -> BitSet<GeneratorSavedLocal> {
+        assert!(self.superset(&input), "{:?} not a superset of {:?}", self.0, input);
+        let mut out = BitSet::new_empty(self.count());
+        for (saved_local, local) in self.iter_enumerated() {
+            if input.contains(local) {
+                out.insert(saved_local);
+            }
+        }
+        out
+    }
+
+    fn get(&self, local: Local) -> Option<GeneratorSavedLocal> {
+        if !self.contains(local) {
+            return None;
         }
+
+        let idx = self.iter().take_while(|&l| l < local).count();
+        Some(GeneratorSavedLocal::new(idx))
+    }
+}
+
+impl ops::Deref for GeneratorSavedLocals {
+    type Target = BitSet<Local>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.0
     }
-    debug!("renumber_bitset({:?}, {:?}) => {:?}", input, stored_locals, out);
-    out
 }
 
 /// For every saved local, looks for which locals are StorageLive at the same
@@ -575,11 +597,11 @@ fn renumber_bitset(
 /// computation; see `GeneratorLayout` for more.
 fn compute_storage_conflicts(
     body: &'mir Body<'tcx>,
-    stored_locals: &BitSet<Local>,
+    saved_locals: &GeneratorSavedLocals,
     always_live_locals: storage::AlwaysLiveLocals,
     requires_storage: dataflow::Results<'tcx, MaybeRequiresStorage<'mir, 'tcx>>,
 ) -> BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal> {
-    assert_eq!(body.local_decls.len(), stored_locals.domain_size());
+    assert_eq!(body.local_decls.len(), saved_locals.domain_size());
 
     debug!("compute_storage_conflicts({:?})", body.span);
     debug!("always_live = {:?}", always_live_locals);
@@ -587,12 +609,12 @@ fn compute_storage_conflicts(
     // Locals that are always live or ones that need to be stored across
     // suspension points are not eligible for overlap.
     let mut ineligible_locals = always_live_locals.into_inner();
-    ineligible_locals.intersect(stored_locals);
+    ineligible_locals.intersect(saved_locals);
 
     // Compute the storage conflicts for all eligible locals.
     let mut visitor = StorageConflictVisitor {
         body,
-        stored_locals: &stored_locals,
+        saved_locals: &saved_locals,
         local_conflicts: BitMatrix::from_row_n(&ineligible_locals, body.local_decls.len()),
     };
 
@@ -609,16 +631,14 @@ fn compute_storage_conflicts(
     // However, in practice these bitsets are not usually large. The layout code
     // also needs to keep track of how many conflicts each local has, so it's
     // simpler to keep it this way for now.
-    let mut storage_conflicts = BitMatrix::new(stored_locals.count(), stored_locals.count());
-    for (idx_a, local_a) in stored_locals.iter().enumerate() {
-        let saved_local_a = GeneratorSavedLocal::new(idx_a);
+    let mut storage_conflicts = BitMatrix::new(saved_locals.count(), saved_locals.count());
+    for (saved_local_a, local_a) in saved_locals.iter_enumerated() {
         if ineligible_locals.contains(local_a) {
             // Conflicts with everything.
             storage_conflicts.insert_all_into_row(saved_local_a);
         } else {
             // Keep overlap information only for stored locals.
-            for (idx_b, local_b) in stored_locals.iter().enumerate() {
-                let saved_local_b = GeneratorSavedLocal::new(idx_b);
+            for (saved_local_b, local_b) in saved_locals.iter_enumerated() {
                 if local_conflicts.contains(local_a, local_b) {
                     storage_conflicts.insert(saved_local_a, saved_local_b);
                 }
@@ -630,7 +650,7 @@ fn compute_storage_conflicts(
 
 struct StorageConflictVisitor<'mir, 'tcx, 's> {
     body: &'mir Body<'tcx>,
-    stored_locals: &'s BitSet<Local>,
+    saved_locals: &'s GeneratorSavedLocals,
     // FIXME(tmandry): Consider using sparse bitsets here once we have good
     // benchmarks for generators.
     local_conflicts: BitMatrix<Local, Local>,
@@ -666,7 +686,7 @@ impl<'body, 'tcx, 's> StorageConflictVisitor<'body, 'tcx, 's> {
         }
 
         let mut eligible_storage_live = flow_state.clone();
-        eligible_storage_live.intersect(&self.stored_locals);
+        eligible_storage_live.intersect(&self.saved_locals);
 
         for local in eligible_storage_live.iter() {
             self.local_conflicts.union_row_with(&eligible_storage_live, local);
@@ -678,7 +698,7 @@ impl<'body, 'tcx, 's> StorageConflictVisitor<'body, 'tcx, 's> {
     }
 }
 
-/// Validates the typeck view of the generator against the actual set of types retained between
+/// Validates the typeck view of the generator against the actual set of types saved between
 /// yield points.
 fn sanitize_witness<'tcx>(
     tcx: TyCtxt<'tcx>,
@@ -686,7 +706,7 @@ fn sanitize_witness<'tcx>(
     did: DefId,
     witness: Ty<'tcx>,
     upvars: &Vec<Ty<'tcx>>,
-    retained: &BitSet<Local>,
+    saved_locals: &GeneratorSavedLocals,
 ) {
     let allowed_upvars = tcx.erase_regions(upvars);
     let allowed = match witness.kind {
@@ -703,8 +723,8 @@ fn sanitize_witness<'tcx>(
     let param_env = tcx.param_env(did);
 
     for (local, decl) in body.local_decls.iter_enumerated() {
-        // Ignore locals which are internal or not retained between yields.
-        if !retained.contains(local) || decl.internal {
+        // Ignore locals which are internal or not saved between yields.
+        if !saved_locals.contains(local) || decl.internal {
             continue;
         }
         let decl_ty = tcx.normalize_erasing_regions(param_env, decl.ty);
@@ -738,21 +758,21 @@ fn compute_layout<'tcx>(
 ) {
     // Use a liveness analysis to compute locals which are live across a suspension point
     let LivenessInfo {
-        live_locals,
+        saved_locals,
         live_locals_at_suspension_points,
         storage_conflicts,
         storage_liveness,
     } = locals_live_across_suspend_points(tcx, body, source, always_live_locals, movable);
 
-    sanitize_witness(tcx, body, source.def_id(), interior, upvars, &live_locals);
+    sanitize_witness(tcx, body, source.def_id(), interior, upvars, &saved_locals);
 
     // Gather live local types and their indices.
     let mut locals = IndexVec::<GeneratorSavedLocal, _>::new();
     let mut tys = IndexVec::<GeneratorSavedLocal, _>::new();
-    for (idx, local) in live_locals.iter().enumerate() {
+    for (saved_local, local) in saved_locals.iter_enumerated() {
         locals.push(local);
         tys.push(body.local_decls[local].ty);
-        debug!("generator saved local {:?} => {:?}", GeneratorSavedLocal::from(idx), local);
+        debug!("generator saved local {:?} => {:?}", saved_local, local);
     }
 
     // Leave empty variants for the UNRESUMED, RETURNED, and POISONED states.