about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_mir/src/transform/dest_prop.rs113
-rw-r--r--src/test/ui/dest-prop/skeptic-miscompile.rs24
2 files changed, 84 insertions, 53 deletions
diff --git a/compiler/rustc_mir/src/transform/dest_prop.rs b/compiler/rustc_mir/src/transform/dest_prop.rs
index d8f62e35062..cd370837405 100644
--- a/compiler/rustc_mir/src/transform/dest_prop.rs
+++ b/compiler/rustc_mir/src/transform/dest_prop.rs
@@ -166,20 +166,24 @@ impl<'tcx> MirPass<'tcx> for DestinationPropagation {
         let mut replacements = Replacements::new(body.local_decls.len());
         for candidate @ CandidateAssignment { dest, src, loc } in candidates {
             // Merge locals that don't conflict.
-            if conflicts.contains(dest.local, src) {
+            if !conflicts.can_unify(dest.local, src) {
                 debug!("at assignment {:?}, conflict {:?} vs. {:?}", loc, dest.local, src);
                 continue;
             }
 
+            if replacements.for_src(candidate.src).is_some() {
+                debug!("src {:?} already has replacement", candidate.src);
+                continue;
+            }
+
             if !tcx.consider_optimizing(|| {
                 format!("DestinationPropagation {:?} {:?}", source.def_id(), candidate)
             }) {
                 break;
             }
 
-            if replacements.push(candidate).is_ok() {
-                conflicts.unify(candidate.src, candidate.dest.local);
-            }
+            replacements.push(candidate);
+            conflicts.unify(candidate.src, candidate.dest.local);
         }
 
         replacements.flatten(tcx);
@@ -220,61 +224,21 @@ struct Replacements<'tcx> {
 
     /// Whose locals' live ranges to kill.
     kill: BitSet<Local>,
-
-    /// Tracks locals that have already been merged together to prevent cycles.
-    unified_locals: InPlaceUnificationTable<UnifyLocal>,
 }
 
 impl Replacements<'tcx> {
     fn new(locals: usize) -> Self {
-        Self {
-            map: IndexVec::from_elem_n(None, locals),
-            kill: BitSet::new_empty(locals),
-            unified_locals: {
-                let mut table = InPlaceUnificationTable::new();
-                for local in 0..locals {
-                    assert_eq!(table.new_key(()), UnifyLocal(Local::from_usize(local)));
-                }
-                table
-            },
-        }
+        Self { map: IndexVec::from_elem_n(None, locals), kill: BitSet::new_empty(locals) }
     }
 
-    fn push(&mut self, candidate: CandidateAssignment<'tcx>) -> Result<(), ()> {
-        if self.unified_locals.unioned(candidate.src, candidate.dest.local) {
-            // Candidate conflicts with previous replacement (ie. could possibly form a cycle and
-            // hang).
-
-            let replacement = self.map[candidate.src].as_mut().unwrap();
-
-            // If the current replacement is for the same `dest` local, there are 2 or more
-            // equivalent `src = dest;` assignments. This is fine, the replacer will `nop` out all
-            // of them.
-            if replacement.local == candidate.dest.local {
-                assert_eq!(replacement.projection, candidate.dest.projection);
-            }
-
-            // We still return `Err` in any case, as `src` and `dest` do not need to be unified
-            // *again*.
-            trace!("push({:?}): already unified", candidate);
-            return Err(());
-        }
-
+    fn push(&mut self, candidate: CandidateAssignment<'tcx>) {
+        trace!("Replacements::push({:?})", candidate);
         let entry = &mut self.map[candidate.src];
-        if entry.is_some() {
-            // We're already replacing `src` with something else, so this candidate is out.
-            trace!("push({:?}): src already has replacement", candidate);
-            return Err(());
-        }
-
-        self.unified_locals.union(candidate.src, candidate.dest.local);
+        assert!(entry.is_none());
 
         *entry = Some(candidate.dest);
         self.kill.insert(candidate.src);
         self.kill.insert(candidate.dest.local);
-
-        trace!("push({:?}): accepted", candidate);
-        Ok(())
     }
 
     /// Applies the stored replacements to all replacements, until no replacements would result in
@@ -410,6 +374,9 @@ struct Conflicts<'a> {
 
     /// Preallocated `BitSet` used by `unify`.
     unify_cache: BitSet<Local>,
+
+    /// Tracks locals that have been merged together to prevent cycles and propagate conflicts.
+    unified_locals: InPlaceUnificationTable<UnifyLocal>,
 }
 
 impl Conflicts<'a> {
@@ -495,6 +462,15 @@ impl Conflicts<'a> {
             relevant_locals,
             matrix: conflicts,
             unify_cache: BitSet::new_empty(body.local_decls.len()),
+            unified_locals: {
+                let mut table = InPlaceUnificationTable::new();
+                // Pre-fill table with all locals (this creates N nodes / "connected" components,
+                // "graph"-ically speaking).
+                for local in 0..body.local_decls.len() {
+                    assert_eq!(table.new_key(()), UnifyLocal(Local::from_usize(local)));
+                }
+                table
+            },
         };
 
         let mut live_and_init_locals = Vec::new();
@@ -761,11 +737,31 @@ impl Conflicts<'a> {
         }
     }
 
-    fn contains(&self, a: Local, b: Local) -> bool {
-        self.matrix.contains(a, b)
+    /// Checks whether `a` and `b` may be merged. Returns `false` if there's a conflict.
+    fn can_unify(&mut self, a: Local, b: Local) -> bool {
+        // After some locals have been unified, their conflicts are only tracked in the root key,
+        // so look that up.
+        let a = self.unified_locals.find(a).0;
+        let b = self.unified_locals.find(b).0;
+
+        if a == b {
+            // Already merged (part of the same connected component).
+            return false;
+        }
+
+        if self.matrix.contains(a, b) {
+            // Conflict (derived via dataflow, intra-statement conflicts, or inherited from another
+            // local during unification).
+            return false;
+        }
+
+        true
     }
 
     /// Merges the conflicts of `a` and `b`, so that each one inherits all conflicts of the other.
+    /// 
+    /// `can_unify` must have returned `true` for the same locals, or this may panic or lead to
+    /// miscompiles.
     ///
     /// This is called when the pass makes the decision to unify `a` and `b` (or parts of `a` and
     /// `b`) and is needed to ensure that future unification decisions take potentially newly
@@ -781,13 +777,24 @@ impl Conflicts<'a> {
     /// `_2` with `_0`, which also doesn't have a conflict in the above list. However `_2` is now
     /// `_3`, which does conflict with `_0`.
     fn unify(&mut self, a: Local, b: Local) {
-        // FIXME: This might be somewhat slow. Conflict graphs are undirected, maybe we can use
-        // something with union-find to speed this up?
-
         trace!("unify({:?}, {:?})", a, b);
+
+        // Get the root local of the connected components. The root local stores the conflicts of
+        // all locals in the connected component (and *is stored* as the conflicting local of other
+        // locals).
+        let a = self.unified_locals.find(a).0;
+        let b = self.unified_locals.find(b).0;
+        assert_ne!(a, b);
+
+        trace!("roots: a={:?}, b={:?}", a, b);
         trace!("{:?} conflicts: {:?}", a, self.matrix.iter(a).format(", "));
         trace!("{:?} conflicts: {:?}", b, self.matrix.iter(b).format(", "));
 
+        self.unified_locals.union(a, b);
+
+        let root = self.unified_locals.find(a).0;
+        assert!(root == a || root == b);
+
         // Make all locals that conflict with `a` also conflict with `b`, and vice versa.
         self.unify_cache.clear();
         for conflicts_with_a in self.matrix.iter(a) {
diff --git a/src/test/ui/dest-prop/skeptic-miscompile.rs b/src/test/ui/dest-prop/skeptic-miscompile.rs
new file mode 100644
index 00000000000..c27a1f04532
--- /dev/null
+++ b/src/test/ui/dest-prop/skeptic-miscompile.rs
@@ -0,0 +1,24 @@
+// run-pass
+
+// compile-flags: -Zmir-opt-level=2
+
+trait IterExt: Iterator {
+    fn fold_ex<B, F>(mut self, init: B, mut f: F) -> B
+    where
+        Self: Sized,
+        F: FnMut(B, Self::Item) -> B,
+    {
+        let mut accum = init;
+        while let Some(x) = self.next() {
+            accum = f(accum, x);
+        }
+        accum
+    }
+}
+
+impl<T: Iterator> IterExt for T {}
+
+fn main() {
+    let test = &["\n"];
+    test.iter().fold_ex(String::new(), |_, b| b.to_string());
+}