about summary refs log tree commit diff
path: root/src/tools/miri/genmc-sys/cpp
diff options
context:
space:
mode:
authorRalf Jung <post@ralfj.de>2025-09-12 16:10:02 +0000
committerGitHub <noreply@github.com>2025-09-12 16:10:02 +0000
commitfc7eb3c28d2be162dd32951811ce7852bb1a2f6a (patch)
tree6070178a641374d861db214b942b9a113eef0dd2 /src/tools/miri/genmc-sys/cpp
parentaf16b80cb927da25f3617212b42b9626959127a9 (diff)
parenta70d78a55286a85cf5e06958929f5a6071fa1c67 (diff)
downloadrust-fc7eb3c28d2be162dd32951811ce7852bb1a2f6a.tar.gz
rust-fc7eb3c28d2be162dd32951811ce7852bb1a2f6a.zip
Merge pull request #4578 from Patrick-6/miri-genmc-cas
Add compare_exchange support for GenMC mode
Diffstat (limited to 'src/tools/miri/genmc-sys/cpp')
-rw-r--r--src/tools/miri/genmc-sys/cpp/include/MiriInterface.hpp82
-rw-r--r--src/tools/miri/genmc-sys/cpp/src/MiriInterface/EventHandling.cpp65
-rw-r--r--src/tools/miri/genmc-sys/cpp/src/MiriInterface/Setup.cpp8
3 files changed, 128 insertions, 27 deletions
diff --git a/src/tools/miri/genmc-sys/cpp/include/MiriInterface.hpp b/src/tools/miri/genmc-sys/cpp/include/MiriInterface.hpp
index b0769375843..662eb0e173c 100644
--- a/src/tools/miri/genmc-sys/cpp/include/MiriInterface.hpp
+++ b/src/tools/miri/genmc-sys/cpp/include/MiriInterface.hpp
@@ -34,6 +34,7 @@ struct SchedulingResult;
 struct LoadResult;
 struct StoreResult;
 struct ReadModifyWriteResult;
+struct CompareExchangeResult;
 
 // GenMC uses `int` for its thread IDs.
 using ThreadId = int;
@@ -100,6 +101,17 @@ struct MiriGenmcShim : private GenMCDriver {
         GenmcScalar rhs_value,
         GenmcScalar old_val
     );
+    [[nodiscard]] CompareExchangeResult handle_compare_exchange(
+        ThreadId thread_id,
+        uint64_t address,
+        uint64_t size,
+        GenmcScalar expected_value,
+        GenmcScalar new_value,
+        GenmcScalar old_val,
+        MemOrdering success_ordering,
+        MemOrdering fail_load_ordering,
+        bool can_fail_spuriously
+    );
     [[nodiscard]] StoreResult handle_store(
         ThreadId thread_id,
         uint64_t address,
@@ -231,15 +243,15 @@ constexpr auto get_global_alloc_static_mask() -> uint64_t {
 namespace GenmcScalarExt {
 inline GenmcScalar uninit() {
     return GenmcScalar {
-        /* value: */ 0,
-        /* is_init: */ false,
+        .value = 0,
+        .is_init = false,
     };
 }
 
 inline GenmcScalar from_sval(SVal sval) {
     return GenmcScalar {
-        /* value: */ sval.get(),
-        /* is_init: */ true,
+        .value = sval.get(),
+        .is_init = true,
     };
 }
 
@@ -252,22 +264,22 @@ inline SVal to_sval(GenmcScalar scalar) {
 namespace LoadResultExt {
 inline LoadResult no_value() {
     return LoadResult {
-        /* error: */ std::unique_ptr<std::string>(nullptr),
-        /* has_value: */ false,
-        /* read_value: */ GenmcScalarExt::uninit(),
+        .error = std::unique_ptr<std::string>(nullptr),
+        .has_value = false,
+        .read_value = GenmcScalarExt::uninit(),
     };
 }
 
 inline LoadResult from_value(SVal read_value) {
-    return LoadResult { /* error: */ std::unique_ptr<std::string>(nullptr),
-                        /* has_value: */ true,
-                        /* read_value: */ GenmcScalarExt::from_sval(read_value) };
+    return LoadResult { .error = std::unique_ptr<std::string>(nullptr),
+                        .has_value = true,
+                        .read_value = GenmcScalarExt::from_sval(read_value) };
 }
 
 inline LoadResult from_error(std::unique_ptr<std::string> error) {
-    return LoadResult { /* error: */ std::move(error),
-                        /* has_value: */ false,
-                        /* read_value: */ GenmcScalarExt::uninit() };
+    return LoadResult { .error = std::move(error),
+                        .has_value = false,
+                        .read_value = GenmcScalarExt::uninit() };
 }
 } // namespace LoadResultExt
 
@@ -278,26 +290,50 @@ inline StoreResult ok(bool is_coherence_order_maximal_write) {
 }
 
 inline StoreResult from_error(std::unique_ptr<std::string> error) {
-    return StoreResult { /* error: */ std::move(error),
-                         /* is_coherence_order_maximal_write: */ false };
+    return StoreResult { .error = std::move(error), .is_coherence_order_maximal_write = false };
 }
 } // namespace StoreResultExt
 
 namespace ReadModifyWriteResultExt {
 inline ReadModifyWriteResult
 ok(SVal old_value, SVal new_value, bool is_coherence_order_maximal_write) {
-    return ReadModifyWriteResult { /* error: */ std::unique_ptr<std::string>(nullptr),
-                                   /* old_value: */ GenmcScalarExt::from_sval(old_value),
-                                   /* new_value: */ GenmcScalarExt::from_sval(new_value),
-                                   is_coherence_order_maximal_write };
+    return ReadModifyWriteResult { .error = std::unique_ptr<std::string>(nullptr),
+                                   .old_value = GenmcScalarExt::from_sval(old_value),
+                                   .new_value = GenmcScalarExt::from_sval(new_value),
+                                   .is_coherence_order_maximal_write =
+                                       is_coherence_order_maximal_write };
 }
 
 inline ReadModifyWriteResult from_error(std::unique_ptr<std::string> error) {
-    return ReadModifyWriteResult { /* error: */ std::move(error),
-                                   /* old_value: */ GenmcScalarExt::uninit(),
-                                   /* new_value: */ GenmcScalarExt::uninit(),
-                                   /* is_coherence_order_maximal_write: */ false };
+    return ReadModifyWriteResult { .error = std::move(error),
+                                   .old_value = GenmcScalarExt::uninit(),
+                                   .new_value = GenmcScalarExt::uninit(),
+                                   .is_coherence_order_maximal_write = false };
 }
 } // namespace ReadModifyWriteResultExt
 
+namespace CompareExchangeResultExt {
+inline CompareExchangeResult success(SVal old_value, bool is_coherence_order_maximal_write) {
+    return CompareExchangeResult { .error = nullptr,
+                                   .old_value = GenmcScalarExt::from_sval(old_value),
+                                   .is_success = true,
+                                   .is_coherence_order_maximal_write =
+                                       is_coherence_order_maximal_write };
+}
+
+inline CompareExchangeResult failure(SVal old_value) {
+    return CompareExchangeResult { .error = nullptr,
+                                   .old_value = GenmcScalarExt::from_sval(old_value),
+                                   .is_success = false,
+                                   .is_coherence_order_maximal_write = false };
+}
+
+inline CompareExchangeResult from_error(std::unique_ptr<std::string> error) {
+    return CompareExchangeResult { .error = std::move(error),
+                                   .old_value = GenmcScalarExt::uninit(),
+                                   .is_success = false,
+                                   .is_coherence_order_maximal_write = false };
+}
+} // namespace CompareExchangeResultExt
+
 #endif /* GENMC_MIRI_INTERFACE_HPP */
diff --git a/src/tools/miri/genmc-sys/cpp/src/MiriInterface/EventHandling.cpp b/src/tools/miri/genmc-sys/cpp/src/MiriInterface/EventHandling.cpp
index cd28e0d148f..05c82641df9 100644
--- a/src/tools/miri/genmc-sys/cpp/src/MiriInterface/EventHandling.cpp
+++ b/src/tools/miri/genmc-sys/cpp/src/MiriInterface/EventHandling.cpp
@@ -155,6 +155,71 @@ void MiriGenmcShim::handle_fence(ThreadId thread_id, MemOrdering ord) {
     );
 }
 
+[[nodiscard]] auto MiriGenmcShim::handle_compare_exchange(
+    ThreadId thread_id,
+    uint64_t address,
+    uint64_t size,
+    GenmcScalar expected_value,
+    GenmcScalar new_value,
+    GenmcScalar old_val,
+    MemOrdering success_ordering,
+    MemOrdering fail_load_ordering,
+    bool can_fail_spuriously
+) -> CompareExchangeResult {
+    // NOTE: Both the store and load events should get the same `ordering`, it should not be split
+    // into a load and a store component. This means we can have for example `AcqRel` loads and
+    // stores, but this is intended for CAS operations.
+
+    // FIXME(GenMC): properly handle failure memory ordering.
+
+    auto expectedVal = GenmcScalarExt::to_sval(expected_value);
+    auto new_val = GenmcScalarExt::to_sval(new_value);
+
+    const auto load_ret = handle_load_reset_if_none<EventLabel::EventLabelKind::CasRead>(
+        thread_id,
+        success_ordering,
+        SAddr(address),
+        ASize(size),
+        AType::Unsigned, // The type is only used for printing.
+        expectedVal,
+        new_val
+    );
+    if (const auto* err = std::get_if<VerificationError>(&load_ret))
+        return CompareExchangeResultExt::from_error(format_error(*err));
+    const auto* ret_val = std::get_if<SVal>(&load_ret);
+    ERROR_ON(nullptr == ret_val, "Unimplemented: load returned unexpected result.");
+    const auto read_old_val = *ret_val;
+    if (read_old_val != expectedVal)
+        return CompareExchangeResultExt::failure(read_old_val);
+
+    // FIXME(GenMC): Add support for modelling spurious failures.
+
+    const auto storePos = inc_pos(thread_id);
+    const auto store_ret = GenMCDriver::handleStore<EventLabel::EventLabelKind::CasWrite>(
+        storePos,
+        success_ordering,
+        SAddr(address),
+        ASize(size),
+        AType::Unsigned, // The type is only used for printing.
+        new_val
+    );
+    if (const auto* err = std::get_if<VerificationError>(&store_ret))
+        return CompareExchangeResultExt::from_error(format_error(*err));
+    const auto* store_ret_val = std::get_if<std::monostate>(&store_ret);
+    ERROR_ON(
+        nullptr == store_ret_val,
+        "Unimplemented: compare-exchange store returned unexpected result."
+    );
+
+    // FIXME(genmc,mixed-accesses): Use the value that GenMC returns from handleStore (once
+    // available).
+    const auto& g = getExec().getGraph();
+    return CompareExchangeResultExt::success(
+        read_old_val,
+        /* is_coherence_order_maximal_write */ g.co_max(SAddr(address))->getPos() == storePos
+    );
+}
+
 /**** Memory (de)allocation ****/
 
 auto MiriGenmcShim::handle_malloc(ThreadId thread_id, uint64_t size, uint64_t alignment)
diff --git a/src/tools/miri/genmc-sys/cpp/src/MiriInterface/Setup.cpp b/src/tools/miri/genmc-sys/cpp/src/MiriInterface/Setup.cpp
index 5a53fee0592..a17a83aa06e 100644
--- a/src/tools/miri/genmc-sys/cpp/src/MiriInterface/Setup.cpp
+++ b/src/tools/miri/genmc-sys/cpp/src/MiriInterface/Setup.cpp
@@ -152,10 +152,10 @@ static auto to_genmc_verbosity_level(const LogLevel log_level) -> VerbosityLevel
         // Miri already ensures that memory accesses are valid, so this check doesn't matter.
         // We check that the address is static, but skip checking if it is part of an actual
         // allocation.
-        /* isStaticallyAllocated: */ [](SAddr addr) { return addr.isStatic(); },
+        .isStaticallyAllocated = [](SAddr addr) { return addr.isStatic(); },
         // FIXME(genmc,error reporting): Once a proper a proper API for passing such information is
         // implemented in GenMC, Miri should use it to improve the produced error messages.
-        /* getStaticName: */ [](SAddr addr) { return "[UNKNOWN STATIC]"; },
+        .getStaticName = [](SAddr addr) { return "[UNKNOWN STATIC]"; },
         // This function is called to get the initial value stored at the given address.
         //
         // From a Miri perspective, this API doesn't work very well: most memory starts out
@@ -177,10 +177,10 @@ static auto to_genmc_verbosity_level(const LogLevel log_level) -> VerbosityLevel
         // FIXME(genmc): implement proper support for uninitialized memory in GenMC. Ideally, the
         // initial value getter would return an `optional<SVal>`, since the memory location may be
         // uninitialized.
-        /* initValGetter: */ [](const AAccess& a) { return SVal(0xDEAD); },
+        .initValGetter = [](const AAccess& a) { return SVal(0xDEAD); },
         // Miri serves non-atomic loads from its own memory and these GenMC checks are wrong in
         // that case. This should no longer be required with proper mixed-size access support.
-        /* skipUninitLoadChecks: */ [](MemOrdering ord) { return ord == MemOrdering::NotAtomic; },
+        .skipUninitLoadChecks = [](MemOrdering ord) { return ord == MemOrdering::NotAtomic; },
     };
     driver->setInterpCallbacks(std::move(interpreter_callbacks));