about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMakai <m4kai410@gmail.com>2025-06-22 18:01:39 +0000
committerMakai <m4kai410@gmail.com>2025-06-22 18:01:39 +0000
commite40515a494dc1bbd571aafd4469657d8ea951a70 (patch)
tree3017ffc735052960706712adc04260d9914399e6
parenta30f1783fe136d92545423dd30b12eb619973cdb (diff)
downloadrust-e40515a494dc1bbd571aafd4469657d8ea951a70.tar.gz
rust-e40515a494dc1bbd571aafd4469657d8ea951a70.zip
add method to retrieve body of coroutine
-rw-r--r--compiler/rustc_smir/src/stable_mir/ty.rs6
-rw-r--r--tests/ui-fulldeps/stable-mir/check_coroutine_body.rs105
2 files changed, 111 insertions, 0 deletions
diff --git a/compiler/rustc_smir/src/stable_mir/ty.rs b/compiler/rustc_smir/src/stable_mir/ty.rs
index 4415cd6e2e3..92fa97566c5 100644
--- a/compiler/rustc_smir/src/stable_mir/ty.rs
+++ b/compiler/rustc_smir/src/stable_mir/ty.rs
@@ -757,6 +757,12 @@ crate_def! {
 }
 
 impl CoroutineDef {
+    /// Retrieves the body of the coroutine definition. Returns None if the body
+    /// isn't available.
+    pub fn body(&self) -> Option<Body> {
+        with(|cx| cx.has_body(self.0).then(|| cx.mir_body(self.0)))
+    }
+
     pub fn discriminant_for_variant(&self, args: &GenericArgs, idx: VariantIdx) -> Discr {
         with(|cx| cx.coroutine_discr_for_variant(*self, args, idx))
     }
diff --git a/tests/ui-fulldeps/stable-mir/check_coroutine_body.rs b/tests/ui-fulldeps/stable-mir/check_coroutine_body.rs
new file mode 100644
index 00000000000..67773492958
--- /dev/null
+++ b/tests/ui-fulldeps/stable-mir/check_coroutine_body.rs
@@ -0,0 +1,105 @@
+//@ run-pass
+//! Tests stable mir API for retrieving the body of a coroutine.
+
+//@ ignore-stage1
+//@ ignore-cross-compile
+//@ ignore-remote
+//@ edition: 2024
+
+#![feature(rustc_private)]
+#![feature(assert_matches)]
+
+extern crate rustc_middle;
+#[macro_use]
+extern crate rustc_smir;
+extern crate rustc_driver;
+extern crate rustc_interface;
+extern crate stable_mir;
+
+use std::io::Write;
+use std::ops::ControlFlow;
+
+use stable_mir::mir::Body;
+use stable_mir::ty::{RigidTy, TyKind};
+
+const CRATE_NAME: &str = "crate_coroutine_body";
+
+fn test_coroutine_body() -> ControlFlow<()> {
+    let crate_items = stable_mir::all_local_items();
+    if let Some(body) = crate_items.iter().find_map(|item| {
+        let item_ty = item.ty();
+        if let TyKind::RigidTy(RigidTy::Coroutine(def, ..)) = &item_ty.kind() {
+            if def.0.name() == "gbc::{closure#0}".to_string() {
+                def.body()
+            } else {
+                None
+            }
+        } else {
+            None
+        }
+    }) {
+        check_coroutine_body(body);
+    } else {
+        panic!("Cannot find `gbc::{{closure#0}}`. All local items are: {:#?}", crate_items);
+    }
+
+    ControlFlow::Continue(())
+}
+
+fn check_coroutine_body(body: Body) {
+    let ret_ty = &body.locals()[0].ty;
+    let local_3 = &body.locals()[3].ty;
+    let local_4 = &body.locals()[4].ty;
+
+    let TyKind::RigidTy(RigidTy::Adt(def, ..)) = &ret_ty.kind()
+    else {
+        panic!("Expected RigidTy::Adt, got: {:#?}", ret_ty);
+    };
+
+    assert_eq!("std::task::Poll", def.0.name());
+
+    let TyKind::RigidTy(RigidTy::Coroutine(def, ..)) = &local_3.kind()
+    else {
+        panic!("Expected RigidTy::Coroutine, got: {:#?}", local_3);
+    };
+
+    assert_eq!("gbc::{closure#0}::{closure#0}", def.0.name());
+
+    let TyKind::RigidTy(RigidTy::Coroutine(def, ..)) = &local_4.kind()
+    else {
+        panic!("Expected RigidTy::Coroutine, got: {:#?}", local_4);
+    };
+
+    assert_eq!("gbc::{closure#0}::{closure#0}", def.0.name());
+}
+
+fn main() {
+    let path = "coroutine_body.rs";
+    generate_input(&path).unwrap();
+    let args = &[
+        "rustc".to_string(),
+        "-Cpanic=abort".to_string(),
+        "--edition".to_string(),
+        "2024".to_string(),
+        "--crate-name".to_string(),
+        CRATE_NAME.to_string(),
+        path.to_string(),
+    ];
+    run!(args, test_coroutine_body).unwrap();
+}
+
+fn generate_input(path: &str) -> std::io::Result<()> {
+    let mut file = std::fs::File::create(path)?;
+    write!(
+        file,
+        r#"
+        async fn gbc() -> i32 {{
+            let a = async {{ 1 }}.await;
+            a
+        }}
+
+        fn main() {{}}
+    "#
+    )?;
+    Ok(())
+}