about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/tools/miri/src/concurrency/thread.rs17
-rw-r--r--src/tools/miri/src/concurrency/weak_memory.rs10
-rw-r--r--src/tools/miri/src/lib.rs2
-rw-r--r--src/tools/miri/src/machine.rs27
-rw-r--r--src/tools/miri/src/shims/env.rs7
-rw-r--r--src/tools/miri/src/shims/panic.rs6
-rw-r--r--src/tools/miri/src/shims/tls.rs6
-rw-r--r--src/tools/miri/src/shims/unix/fs.rs4
-rw-r--r--src/tools/miri/src/stacked_borrows/mod.rs6
-rw-r--r--src/tools/miri/src/tag_gc.rs168
10 files changed, 148 insertions, 105 deletions
diff --git a/src/tools/miri/src/concurrency/thread.rs b/src/tools/miri/src/concurrency/thread.rs
index 1b05088b3d5..01ae4320f3b 100644
--- a/src/tools/miri/src/concurrency/thread.rs
+++ b/src/tools/miri/src/concurrency/thread.rs
@@ -182,15 +182,15 @@ impl<'mir, 'tcx> Thread<'mir, 'tcx> {
 }
 
 impl VisitMachineValues for Thread<'_, '_> {
-    fn visit_machine_values(&self, visit: &mut impl FnMut(&Operand<Provenance>)) {
+    fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
         let Thread { panic_payload, last_error, stack, state: _, thread_name: _, join_status: _ } =
             self;
 
         if let Some(payload) = panic_payload {
-            visit(&Operand::Immediate(Immediate::Scalar(*payload)))
+            visit.visit(*payload);
         }
         if let Some(error) = last_error {
-            visit(&Operand::Indirect(**error))
+            visit.visit(**error);
         }
         for frame in stack {
             frame.visit_machine_values(visit)
@@ -199,7 +199,7 @@ impl VisitMachineValues for Thread<'_, '_> {
 }
 
 impl VisitMachineValues for Frame<'_, '_, Provenance, FrameData<'_>> {
-    fn visit_machine_values(&self, visit: &mut impl FnMut(&Operand<Provenance>)) {
+    fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
         let Frame {
             return_place,
             locals,
@@ -213,12 +213,12 @@ impl VisitMachineValues for Frame<'_, '_, Provenance, FrameData<'_>> {
 
         // Return place.
         if let Place::Ptr(mplace) = **return_place {
-            visit(&Operand::Indirect(mplace));
+            visit.visit(mplace);
         }
         // Locals.
         for local in locals.iter() {
             if let LocalValue::Live(value) = &local.value {
-                visit(value);
+                visit.visit(value);
             }
         }
 
@@ -299,7 +299,7 @@ impl<'mir, 'tcx> Default for ThreadManager<'mir, 'tcx> {
 }
 
 impl VisitMachineValues for ThreadManager<'_, '_> {
-    fn visit_machine_values(&self, visit: &mut impl FnMut(&Operand<Provenance>)) {
+    fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
         let ThreadManager {
             threads,
             thread_local_alloc_ids,
@@ -313,8 +313,7 @@ impl VisitMachineValues for ThreadManager<'_, '_> {
             thread.visit_machine_values(visit);
         }
         for ptr in thread_local_alloc_ids.borrow().values().copied() {
-            let ptr: Pointer<Option<Provenance>> = ptr.into();
-            visit(&Operand::Indirect(MemPlace::from_ptr(ptr)));
+            visit.visit(ptr);
         }
         // FIXME: Do we need to do something for TimeoutCallback? That's a Box<dyn>, not sure what
         // to do.
diff --git a/src/tools/miri/src/concurrency/weak_memory.rs b/src/tools/miri/src/concurrency/weak_memory.rs
index becd61f4fea..15c6c8e9c0e 100644
--- a/src/tools/miri/src/concurrency/weak_memory.rs
+++ b/src/tools/miri/src/concurrency/weak_memory.rs
@@ -108,19 +108,15 @@ pub struct StoreBufferAlloc {
     store_buffers: RefCell<RangeObjectMap<StoreBuffer>>,
 }
 
-impl VisitProvenance for StoreBufferAlloc {
-    fn visit_provenance(&self, visitor: &mut impl FnMut(SbTag)) {
+impl VisitMachineValues for StoreBufferAlloc {
+    fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
         for val in self
             .store_buffers
             .borrow()
             .iter()
             .flat_map(|buf| buf.buffer.iter().map(|element| &element.val))
         {
-            if let Scalar::Ptr(ptr, _) = val {
-                if let Provenance::Concrete { sb, .. } = ptr.provenance {
-                    visitor(sb);
-                }
-            }
+            visit.visit(val);
         }
     }
 }
diff --git a/src/tools/miri/src/lib.rs b/src/tools/miri/src/lib.rs
index e60e1f15b5f..245bdc51a8a 100644
--- a/src/tools/miri/src/lib.rs
+++ b/src/tools/miri/src/lib.rs
@@ -112,7 +112,7 @@ pub use crate::range_map::RangeMap;
 pub use crate::stacked_borrows::{
     CallId, EvalContextExt as StackedBorEvalContextExt, Item, Permission, SbTag, Stack, Stacks,
 };
-pub use crate::tag_gc::{EvalContextExt as _, VisitMachineValues, VisitProvenance};
+pub use crate::tag_gc::{EvalContextExt as _, ProvenanceVisitor, VisitMachineValues};
 
 /// Insert rustc arguments at the beginning of the argument list that Miri wants to be
 /// set per default, for maximal validation power.
diff --git a/src/tools/miri/src/machine.rs b/src/tools/miri/src/machine.rs
index 35d5c0d9a87..523aad22aa5 100644
--- a/src/tools/miri/src/machine.rs
+++ b/src/tools/miri/src/machine.rs
@@ -64,7 +64,7 @@ impl<'tcx> std::fmt::Debug for FrameData<'tcx> {
 }
 
 impl VisitMachineValues for FrameData<'_> {
-    fn visit_machine_values(&self, visit: &mut impl FnMut(&Operand<Provenance>)) {
+    fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
         let FrameData { catch_unwind, stacked_borrows: _, timing: _ } = self;
 
         if let Some(catch_unwind) = catch_unwind {
@@ -261,6 +261,20 @@ pub struct AllocExtra {
     pub weak_memory: Option<weak_memory::AllocExtra>,
 }
 
+impl VisitMachineValues for AllocExtra {
+    fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
+        let AllocExtra { stacked_borrows, data_race: _, weak_memory } = self;
+
+        if let Some(stacked_borrows) = stacked_borrows {
+            stacked_borrows.borrow().visit_machine_values(visit);
+        }
+
+        if let Some(weak_memory) = weak_memory {
+            weak_memory.visit_machine_values(visit);
+        }
+    }
+}
+
 /// Precomputed layouts of primitive types
 pub struct PrimitiveLayouts<'tcx> {
     pub unit: TyAndLayout<'tcx>,
@@ -602,7 +616,7 @@ impl<'mir, 'tcx> MiriMachine<'mir, 'tcx> {
 }
 
 impl VisitMachineValues for MiriMachine<'_, '_> {
-    fn visit_machine_values(&self, visit: &mut impl FnMut(&Operand<Provenance>)) {
+    fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
         let MiriMachine {
             threads,
             tls,
@@ -621,17 +635,16 @@ impl VisitMachineValues for MiriMachine<'_, '_> {
         dir_handler.visit_machine_values(visit);
 
         if let Some(argc) = argc {
-            visit(&Operand::Indirect(*argc));
+            visit.visit(argc);
         }
         if let Some(argv) = argv {
-            visit(&Operand::Indirect(*argv));
+            visit.visit(argv);
         }
         if let Some(cmd_line) = cmd_line {
-            visit(&Operand::Indirect(*cmd_line));
+            visit.visit(cmd_line);
         }
         for ptr in extern_statics.values().copied() {
-            let ptr: Pointer<Option<Provenance>> = ptr.into();
-            visit(&Operand::Indirect(MemPlace::from_ptr(ptr)));
+            visit.visit(ptr);
         }
     }
 }
diff --git a/src/tools/miri/src/shims/env.rs b/src/tools/miri/src/shims/env.rs
index ad2d2eaab34..d922014c383 100644
--- a/src/tools/miri/src/shims/env.rs
+++ b/src/tools/miri/src/shims/env.rs
@@ -37,15 +37,14 @@ pub struct EnvVars<'tcx> {
 }
 
 impl VisitMachineValues for EnvVars<'_> {
-    fn visit_machine_values(&self, visit: &mut impl FnMut(&Operand<Provenance>)) {
+    fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
         let EnvVars { map, environ } = self;
 
         for ptr in map.values() {
-            visit(&Operand::Indirect(MemPlace::from_ptr(*ptr)));
+            visit.visit(*ptr);
         }
-
         if let Some(env) = environ {
-            visit(&Operand::Indirect(**env));
+            visit.visit(**env);
         }
     }
 }
diff --git a/src/tools/miri/src/shims/panic.rs b/src/tools/miri/src/shims/panic.rs
index be14892f696..0d681d3e09b 100644
--- a/src/tools/miri/src/shims/panic.rs
+++ b/src/tools/miri/src/shims/panic.rs
@@ -36,10 +36,10 @@ pub struct CatchUnwindData<'tcx> {
 }
 
 impl VisitMachineValues for CatchUnwindData<'_> {
-    fn visit_machine_values(&self, visit: &mut impl FnMut(&Operand<Provenance>)) {
+    fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
         let CatchUnwindData { catch_fn, data, dest: _, ret: _ } = self;
-        visit(&Operand::Indirect(MemPlace::from_ptr(*catch_fn)));
-        visit(&Operand::Immediate(Immediate::Scalar(*data)));
+        visit.visit(catch_fn);
+        visit.visit(data);
     }
 }
 
diff --git a/src/tools/miri/src/shims/tls.rs b/src/tools/miri/src/shims/tls.rs
index d1cee307d77..568eb6fa910 100644
--- a/src/tools/miri/src/shims/tls.rs
+++ b/src/tools/miri/src/shims/tls.rs
@@ -236,14 +236,14 @@ impl<'tcx> TlsData<'tcx> {
 }
 
 impl VisitMachineValues for TlsData<'_> {
-    fn visit_machine_values(&self, visit: &mut impl FnMut(&Operand<Provenance>)) {
+    fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
         let TlsData { keys, macos_thread_dtors, next_key: _, dtors_running: _ } = self;
 
         for scalar in keys.values().flat_map(|v| v.data.values()) {
-            visit(&Operand::Immediate(Immediate::Scalar(*scalar)));
+            visit.visit(scalar);
         }
         for (_, scalar) in macos_thread_dtors.values() {
-            visit(&Operand::Immediate(Immediate::Scalar(*scalar)));
+            visit.visit(scalar);
         }
     }
 }
diff --git a/src/tools/miri/src/shims/unix/fs.rs b/src/tools/miri/src/shims/unix/fs.rs
index 59d24e00dc1..5024b2ab45f 100644
--- a/src/tools/miri/src/shims/unix/fs.rs
+++ b/src/tools/miri/src/shims/unix/fs.rs
@@ -463,11 +463,11 @@ impl Default for DirHandler {
 }
 
 impl VisitMachineValues for DirHandler {
-    fn visit_machine_values(&self, visit: &mut impl FnMut(&Operand<Provenance>)) {
+    fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
         let DirHandler { streams, next_id: _ } = self;
 
         for dir in streams.values() {
-            visit(&Operand::Indirect(MemPlace::from_ptr(dir.entry)));
+            visit.visit(dir.entry);
         }
     }
 }
diff --git a/src/tools/miri/src/stacked_borrows/mod.rs b/src/tools/miri/src/stacked_borrows/mod.rs
index ab90e358449..b40358e2c15 100644
--- a/src/tools/miri/src/stacked_borrows/mod.rs
+++ b/src/tools/miri/src/stacked_borrows/mod.rs
@@ -513,10 +513,10 @@ impl Stacks {
     }
 }
 
-impl VisitProvenance for Stacks {
-    fn visit_provenance(&self, visit: &mut impl FnMut(SbTag)) {
+impl VisitMachineValues for Stacks {
+    fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
         for tag in self.exposed_tags.iter().copied() {
-            visit(tag);
+            visit.visit(tag);
         }
     }
 }
diff --git a/src/tools/miri/src/tag_gc.rs b/src/tools/miri/src/tag_gc.rs
index 0a8d5d00cfb..e2273f055dd 100644
--- a/src/tools/miri/src/tag_gc.rs
+++ b/src/tools/miri/src/tag_gc.rs
@@ -3,34 +3,120 @@ use rustc_data_structures::fx::FxHashSet;
 use crate::*;
 
 pub trait VisitMachineValues {
-    fn visit_machine_values(&self, visit: &mut impl FnMut(&Operand<Provenance>));
+    fn visit_machine_values(&self, visit: &mut ProvenanceVisitor);
 }
 
-pub trait VisitProvenance {
-    fn visit_provenance(&self, visit: &mut impl FnMut(SbTag));
+pub trait MachineValue {
+    fn visit_provenance(&self, tags: &mut FxHashSet<SbTag>);
+}
+
+pub struct ProvenanceVisitor {
+    tags: FxHashSet<SbTag>,
+}
+
+impl ProvenanceVisitor {
+    pub fn visit<V>(&mut self, v: V)
+    where
+        V: MachineValue,
+    {
+        v.visit_provenance(&mut self.tags);
+    }
+}
+
+impl<T: MachineValue> MachineValue for &T {
+    fn visit_provenance(&self, tags: &mut FxHashSet<SbTag>) {
+        (**self).visit_provenance(tags);
+    }
+}
+
+impl MachineValue for Operand<Provenance> {
+    fn visit_provenance(&self, tags: &mut FxHashSet<SbTag>) {
+        match self {
+            Operand::Immediate(Immediate::Scalar(s)) => {
+                s.visit_provenance(tags);
+            }
+            Operand::Immediate(Immediate::ScalarPair(s1, s2)) => {
+                s1.visit_provenance(tags);
+                s2.visit_provenance(tags);
+            }
+            Operand::Immediate(Immediate::Uninit) => {}
+            Operand::Indirect(p) => {
+                p.visit_provenance(tags);
+            }
+        }
+    }
+}
+
+impl MachineValue for Scalar<Provenance> {
+    fn visit_provenance(&self, tags: &mut FxHashSet<SbTag>) {
+        if let Scalar::Ptr(ptr, _) = self {
+            if let Provenance::Concrete { sb, .. } = ptr.provenance {
+                tags.insert(sb);
+            }
+        }
+    }
+}
+
+impl MachineValue for MemPlace<Provenance> {
+    fn visit_provenance(&self, tags: &mut FxHashSet<SbTag>) {
+        if let Some(Provenance::Concrete { sb, .. }) = self.ptr.provenance {
+            tags.insert(sb);
+        }
+    }
+}
+
+impl MachineValue for SbTag {
+    fn visit_provenance(&self, tags: &mut FxHashSet<SbTag>) {
+        tags.insert(*self);
+    }
+}
+
+impl MachineValue for Pointer<Provenance> {
+    fn visit_provenance(&self, tags: &mut FxHashSet<SbTag>) {
+        let (prov, _offset) = self.into_parts();
+        if let Provenance::Concrete { sb, .. } = prov {
+            tags.insert(sb);
+        }
+    }
+}
+
+impl MachineValue for Pointer<Option<Provenance>> {
+    fn visit_provenance(&self, tags: &mut FxHashSet<SbTag>) {
+        let (prov, _offset) = self.into_parts();
+        if let Some(Provenance::Concrete { sb, .. }) = prov {
+            tags.insert(sb);
+        }
+    }
+}
+
+impl VisitMachineValues for Allocation<Provenance, AllocExtra> {
+    fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
+        for (_size, prov) in self.provenance().iter() {
+            if let Provenance::Concrete { sb, .. } = prov {
+                visit.visit(*sb);
+            }
+        }
+
+        self.extra.visit_machine_values(visit);
+    }
 }
 
 impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {}
 pub trait EvalContextExt<'mir, 'tcx: 'mir>: MiriInterpCxExt<'mir, 'tcx> {
-    /// Generic GC helper to visit everything that can store a value. The `acc` offers some chance to
-    /// accumulate everything.
-    fn visit_all_machine_values<T>(
-        &self,
-        acc: &mut T,
-        mut visit_operand: impl FnMut(&mut T, &Operand<Provenance>),
-        mut visit_alloc: impl FnMut(&mut T, &Allocation<Provenance, AllocExtra>),
-    ) {
+    /// GC helper to visit everything that can store provenance. The `ProvenanceVisitor` knows how
+    /// to extract provenance from the interpreter data types.
+    fn visit_all_machine_values(&self, acc: &mut ProvenanceVisitor) {
         let this = self.eval_context_ref();
 
         // Memory.
         this.memory.alloc_map().iter(|it| {
             for (_id, (_kind, alloc)) in it {
-                visit_alloc(acc, alloc);
+                alloc.visit_machine_values(acc);
             }
         });
 
         // And all the other machine values.
-        this.machine.visit_machine_values(&mut |op| visit_operand(acc, op));
+        this.machine.visit_machine_values(acc);
     }
 
     fn garbage_collect_tags(&mut self) -> InterpResult<'tcx> {
@@ -40,59 +126,9 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: MiriInterpCxExt<'mir, 'tcx> {
             return Ok(());
         }
 
-        let mut tags = FxHashSet::default();
-
-        let visit_scalar = |tags: &mut FxHashSet<SbTag>, s: &Scalar<Provenance>| {
-            if let Scalar::Ptr(ptr, _) = s {
-                if let Provenance::Concrete { sb, .. } = ptr.provenance {
-                    tags.insert(sb);
-                }
-            }
-        };
-
-        let visit_provenance = |tags: &mut FxHashSet<SbTag>, tag: SbTag| {
-            tags.insert(tag);
-        };
-
-        this.visit_all_machine_values(
-            &mut tags,
-            |tags, op| {
-                match op {
-                    Operand::Immediate(Immediate::Scalar(s)) => {
-                        visit_scalar(tags, s);
-                    }
-                    Operand::Immediate(Immediate::ScalarPair(s1, s2)) => {
-                        visit_scalar(tags, s1);
-                        visit_scalar(tags, s2);
-                    }
-                    Operand::Immediate(Immediate::Uninit) => {}
-                    Operand::Indirect(MemPlace { ptr, .. }) => {
-                        if let Some(Provenance::Concrete { sb, .. }) = ptr.provenance {
-                            tags.insert(sb);
-                        }
-                    }
-                }
-            },
-            |tags, alloc| {
-                for (_size, prov) in alloc.provenance().iter() {
-                    if let Provenance::Concrete { sb, .. } = prov {
-                        tags.insert(*sb);
-                    }
-                }
-
-                let stacks =
-                    alloc.extra.stacked_borrows.as_ref().expect(
-                        "we should not even enter the tag GC if Stacked Borrows is disabled",
-                    );
-                stacks.borrow().visit_provenance(&mut |tag| visit_provenance(tags, tag));
-
-                if let Some(store_buffers) = alloc.extra.weak_memory.as_ref() {
-                    store_buffers.visit_provenance(&mut |tag| visit_provenance(tags, tag));
-                }
-            },
-        );
-
-        self.remove_unreachable_tags(tags);
+        let mut visitor = ProvenanceVisitor { tags: FxHashSet::default() };
+        this.visit_all_machine_values(&mut visitor);
+        self.remove_unreachable_tags(visitor.tags);
 
         Ok(())
     }