about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_smir/src/rustc_smir/context.rs33
-rw-r--r--compiler/rustc_smir/src/rustc_smir/convert/ty.rs8
-rw-r--r--compiler/rustc_smir/src/stable_mir/compiler_interface.rs23
-rw-r--r--compiler/rustc_smir/src/stable_mir/ty.rs15
-rw-r--r--tests/ui-fulldeps/stable-mir/check_variant.rs183
5 files changed, 254 insertions, 8 deletions
diff --git a/compiler/rustc_smir/src/rustc_smir/context.rs b/compiler/rustc_smir/src/rustc_smir/context.rs
index ef8c88355f6..baa4c0681e8 100644
--- a/compiler/rustc_smir/src/rustc_smir/context.rs
+++ b/compiler/rustc_smir/src/rustc_smir/context.rs
@@ -12,7 +12,8 @@ use rustc_middle::ty::layout::{
 };
 use rustc_middle::ty::print::{with_forced_trimmed_paths, with_no_trimmed_paths};
 use rustc_middle::ty::{
-    GenericPredicates, Instance, List, ScalarInt, TyCtxt, TypeVisitableExt, ValTree,
+    CoroutineArgsExt, GenericPredicates, Instance, List, ScalarInt, TyCtxt, TypeVisitableExt,
+    ValTree,
 };
 use rustc_middle::{mir, ty};
 use rustc_span::def_id::LOCAL_CRATE;
@@ -22,9 +23,9 @@ use stable_mir::mir::mono::{InstanceDef, StaticDef};
 use stable_mir::mir::{BinOp, Body, Place, UnOp};
 use stable_mir::target::{MachineInfo, MachineSize};
 use stable_mir::ty::{
-    AdtDef, AdtKind, Allocation, ClosureDef, ClosureKind, FieldDef, FnDef, ForeignDef,
-    ForeignItemKind, GenericArgs, IntrinsicDef, LineInfo, MirConst, PolyFnSig, RigidTy, Span, Ty,
-    TyConst, TyKind, UintTy, VariantDef,
+    AdtDef, AdtKind, Allocation, ClosureDef, ClosureKind, CoroutineDef, Discr, FieldDef, FnDef,
+    ForeignDef, ForeignItemKind, GenericArgs, IntrinsicDef, LineInfo, MirConst, PolyFnSig, RigidTy,
+    Span, Ty, TyConst, TyKind, UintTy, VariantDef, VariantIdx,
 };
 use stable_mir::{Crate, CrateDef, CrateItem, CrateNum, DefId, Error, Filename, ItemKind, Symbol};
 
@@ -447,6 +448,30 @@ impl<'tcx> SmirCtxt<'tcx> {
         def.internal(&mut *tables, tcx).variants().len()
     }
 
+    /// Discriminant for a given variant index of AdtDef
+    pub fn adt_discr_for_variant(&self, adt: AdtDef, variant: VariantIdx) -> Discr {
+        let mut tables = self.0.borrow_mut();
+        let tcx = tables.tcx;
+        let adt = adt.internal(&mut *tables, tcx);
+        let variant = variant.internal(&mut *tables, tcx);
+        adt.discriminant_for_variant(tcx, variant).stable(&mut *tables)
+    }
+
+    /// Discriminant for a given variand index and args of a coroutine
+    pub fn coroutine_discr_for_variant(
+        &self,
+        coroutine: CoroutineDef,
+        args: &GenericArgs,
+        variant: VariantIdx,
+    ) -> Discr {
+        let mut tables = self.0.borrow_mut();
+        let tcx = tables.tcx;
+        let coroutine = coroutine.def_id().internal(&mut *tables, tcx);
+        let args = args.internal(&mut *tables, tcx);
+        let variant = variant.internal(&mut *tables, tcx);
+        args.as_coroutine().discriminant_for_variant(coroutine, tcx, variant).stable(&mut *tables)
+    }
+
     /// The name of a variant.
     pub fn variant_name(&self, def: VariantDef) -> Symbol {
         let mut tables = self.0.borrow_mut();
diff --git a/compiler/rustc_smir/src/rustc_smir/convert/ty.rs b/compiler/rustc_smir/src/rustc_smir/convert/ty.rs
index 6a26f5f7997..b4239ddd896 100644
--- a/compiler/rustc_smir/src/rustc_smir/convert/ty.rs
+++ b/compiler/rustc_smir/src/rustc_smir/convert/ty.rs
@@ -960,3 +960,11 @@ impl<'tcx> Stable<'tcx> for ty::ImplTraitInTraitData {
         }
     }
 }
+
+impl<'tcx> Stable<'tcx> for rustc_middle::ty::util::Discr<'tcx> {
+    type T = stable_mir::ty::Discr;
+
+    fn stable(&self, tables: &mut Tables<'_>) -> Self::T {
+        stable_mir::ty::Discr { val: self.val, ty: self.ty.stable(tables) }
+    }
+}
diff --git a/compiler/rustc_smir/src/stable_mir/compiler_interface.rs b/compiler/rustc_smir/src/stable_mir/compiler_interface.rs
index 3967ad13eeb..2668fba9f4f 100644
--- a/compiler/rustc_smir/src/stable_mir/compiler_interface.rs
+++ b/compiler/rustc_smir/src/stable_mir/compiler_interface.rs
@@ -13,10 +13,10 @@ use stable_mir::mir::mono::{Instance, InstanceDef, StaticDef};
 use stable_mir::mir::{BinOp, Body, Place, UnOp};
 use stable_mir::target::MachineInfo;
 use stable_mir::ty::{
-    AdtDef, AdtKind, Allocation, ClosureDef, ClosureKind, FieldDef, FnDef, ForeignDef,
-    ForeignItemKind, ForeignModule, ForeignModuleDef, GenericArgs, GenericPredicates, Generics,
-    ImplDef, ImplTrait, IntrinsicDef, LineInfo, MirConst, PolyFnSig, RigidTy, Span, TraitDecl,
-    TraitDef, Ty, TyConst, TyConstId, TyKind, UintTy, VariantDef,
+    AdtDef, AdtKind, Allocation, ClosureDef, ClosureKind, CoroutineDef, Discr, FieldDef, FnDef,
+    ForeignDef, ForeignItemKind, ForeignModule, ForeignModuleDef, GenericArgs, GenericPredicates,
+    Generics, ImplDef, ImplTrait, IntrinsicDef, LineInfo, MirConst, PolyFnSig, RigidTy, Span,
+    TraitDecl, TraitDef, Ty, TyConst, TyConstId, TyKind, UintTy, VariantDef, VariantIdx,
 };
 use stable_mir::{
     AssocItems, Crate, CrateItem, CrateItems, CrateNum, DefId, Error, Filename, ImplTraitDecls,
@@ -230,6 +230,21 @@ impl<'tcx> SmirInterface<'tcx> {
         self.cx.adt_variants_len(def)
     }
 
+    /// Discriminant for a given variant index of AdtDef
+    pub(crate) fn adt_discr_for_variant(&self, adt: AdtDef, variant: VariantIdx) -> Discr {
+        self.cx.adt_discr_for_variant(adt, variant)
+    }
+
+    /// Discriminant for a given variand index and args of a coroutine
+    pub(crate) fn coroutine_discr_for_variant(
+        &self,
+        coroutine: CoroutineDef,
+        args: &GenericArgs,
+        variant: VariantIdx,
+    ) -> Discr {
+        self.cx.coroutine_discr_for_variant(coroutine, args, variant)
+    }
+
     /// The name of a variant.
     pub(crate) fn variant_name(&self, def: VariantDef) -> Symbol {
         self.cx.variant_name(def)
diff --git a/compiler/rustc_smir/src/stable_mir/ty.rs b/compiler/rustc_smir/src/stable_mir/ty.rs
index 2934af31cd5..4415cd6e2e3 100644
--- a/compiler/rustc_smir/src/stable_mir/ty.rs
+++ b/compiler/rustc_smir/src/stable_mir/ty.rs
@@ -756,6 +756,12 @@ crate_def! {
     pub CoroutineDef;
 }
 
+impl CoroutineDef {
+    pub fn discriminant_for_variant(&self, args: &GenericArgs, idx: VariantIdx) -> Discr {
+        with(|cx| cx.coroutine_discr_for_variant(*self, args, idx))
+    }
+}
+
 crate_def! {
     #[derive(Serialize)]
     pub CoroutineClosureDef;
@@ -831,6 +837,15 @@ impl AdtDef {
     pub fn repr(&self) -> ReprOptions {
         with(|cx| cx.adt_repr(*self))
     }
+
+    pub fn discriminant_for_variant(&self, idx: VariantIdx) -> Discr {
+        with(|cx| cx.adt_discr_for_variant(*self, idx))
+    }
+}
+
+pub struct Discr {
+    pub val: u128,
+    pub ty: Ty,
 }
 
 /// Definition of a variant, which can be either a struct / union field or an enum variant.
diff --git a/tests/ui-fulldeps/stable-mir/check_variant.rs b/tests/ui-fulldeps/stable-mir/check_variant.rs
new file mode 100644
index 00000000000..b0de3369830
--- /dev/null
+++ b/tests/ui-fulldeps/stable-mir/check_variant.rs
@@ -0,0 +1,183 @@
+//@ run-pass
+//! Test that users are able to use stable mir APIs to retrieve
+//! discriminant value and type for AdtDef and Coroutine variants
+
+//@ ignore-stage1
+//@ ignore-cross-compile
+//@ ignore-remote
+//@ edition: 2024
+
+#![feature(rustc_private)]
+#![feature(assert_matches)]
+
+extern crate rustc_middle;
+#[macro_use]
+extern crate rustc_smir;
+extern crate rustc_driver;
+extern crate rustc_interface;
+extern crate stable_mir;
+
+use std::io::Write;
+use std::ops::ControlFlow;
+
+use stable_mir::CrateItem;
+use stable_mir::crate_def::CrateDef;
+use stable_mir::mir::{AggregateKind, Rvalue, Statement, StatementKind};
+use stable_mir::ty::{IntTy, RigidTy, Ty};
+
+const CRATE_NAME: &str = "crate_variant_ty";
+
+/// Test if we can retrieve discriminant info for different types.
+fn test_def_tys() -> ControlFlow<()> {
+    check_adt_mono();
+    check_adt_poly();
+    check_adt_poly2();
+
+    ControlFlow::Continue(())
+}
+
+fn check_adt_mono() {
+    let mono = get_fn("mono").expect_body();
+
+    check_statement_is_aggregate_assign(
+        &mono.blocks[0].statements[0],
+        0,
+        RigidTy::Int(IntTy::Isize),
+    );
+    check_statement_is_aggregate_assign(
+        &mono.blocks[1].statements[0],
+        1,
+        RigidTy::Int(IntTy::Isize),
+    );
+    check_statement_is_aggregate_assign(
+        &mono.blocks[2].statements[0],
+        2,
+        RigidTy::Int(IntTy::Isize),
+    );
+}
+
+fn check_adt_poly() {
+    let poly = get_fn("poly").expect_body();
+
+    check_statement_is_aggregate_assign(
+        &poly.blocks[0].statements[0],
+        0,
+        RigidTy::Int(IntTy::Isize),
+    );
+    check_statement_is_aggregate_assign(
+        &poly.blocks[1].statements[0],
+        1,
+        RigidTy::Int(IntTy::Isize),
+    );
+    check_statement_is_aggregate_assign(
+        &poly.blocks[2].statements[0],
+        2,
+        RigidTy::Int(IntTy::Isize),
+    );
+}
+
+fn check_adt_poly2() {
+    let poly = get_fn("poly2").expect_body();
+
+    check_statement_is_aggregate_assign(
+        &poly.blocks[0].statements[0],
+        0,
+        RigidTy::Int(IntTy::Isize),
+    );
+    check_statement_is_aggregate_assign(
+        &poly.blocks[1].statements[0],
+        1,
+        RigidTy::Int(IntTy::Isize),
+    );
+    check_statement_is_aggregate_assign(
+        &poly.blocks[2].statements[0],
+        2,
+        RigidTy::Int(IntTy::Isize),
+    );
+}
+
+fn get_fn(name: &str) -> CrateItem {
+    stable_mir::all_local_items().into_iter().find(|it| it.name().eq(name)).unwrap()
+}
+
+fn check_statement_is_aggregate_assign(
+    statement: &Statement,
+    expected_discr_val: u128,
+    expected_discr_ty: RigidTy,
+) {
+    if let Statement { kind: StatementKind::Assign(_, rvalue), .. } = statement
+        && let Rvalue::Aggregate(aggregate, _) = rvalue
+        && let AggregateKind::Adt(adt_def, variant_idx, ..) = aggregate
+    {
+        let discr = adt_def.discriminant_for_variant(*variant_idx);
+
+        assert_eq!(discr.val, expected_discr_val);
+        assert_eq!(discr.ty, Ty::from_rigid_kind(expected_discr_ty));
+    } else {
+        unreachable!("Unexpected statement");
+    }
+}
+
+/// This test will generate and analyze a dummy crate using the stable mir.
+/// For that, it will first write the dummy crate into a file.
+/// Then it will create a `StableMir` using custom arguments and then
+/// it will run the compiler.
+fn main() {
+    let path = "defs_ty_input.rs";
+    generate_input(&path).unwrap();
+    let args = &[
+        "rustc".to_string(),
+        "-Cpanic=abort".to_string(),
+        "--crate-name".to_string(),
+        CRATE_NAME.to_string(),
+        path.to_string(),
+    ];
+    run!(args, test_def_tys).unwrap();
+}
+
+fn generate_input(path: &str) -> std::io::Result<()> {
+    let mut file = std::fs::File::create(path)?;
+    write!(
+        file,
+        r#"
+        use std::hint::black_box;
+
+        enum Mono {{
+            A,
+            B(i32),
+            C {{ a: i32, b: u32 }},
+        }}
+
+        enum Poly<T> {{
+            A,
+            B(T),
+            C {{ t: T }},
+        }}
+
+        pub fn main() {{
+            mono();
+            poly();
+            poly2::<i32>(1);
+        }}
+
+        fn mono() {{
+            black_box(Mono::A);
+            black_box(Mono::B(6));
+            black_box(Mono::C {{a: 1, b: 10 }});
+        }}
+
+        fn poly() {{
+            black_box(Poly::<i32>::A);
+            black_box(Poly::B(1i32));
+            black_box(Poly::C {{ t: 1i32 }});
+        }}
+
+        fn poly2<T: Copy>(t: T) {{
+            black_box(Poly::<T>::A);
+            black_box(Poly::B(t));
+            black_box(Poly::C {{ t: t }});
+        }}
+    "#
+    )?;
+    Ok(())
+}