about summary refs log tree commit diff
diff options
context:
space:
mode:
authorCelina G. Val <celinval@amazon.com>2023-11-07 14:07:32 -0800
committerCelina G. Val <celinval@amazon.com>2023-11-16 11:05:36 -0800
commite70839ac84aa3dc8b4525a11e64961e375f8f3ba (patch)
tree715dd66e6c40fa1efdee8df5dcede0b7e9c1afa8
parent0ea7ddcc35a2fcaa5da8a7dcfc118c9fb4a63b95 (diff)
downloadrust-e70839ac84aa3dc8b4525a11e64961e375f8f3ba.tar.gz
rust-e70839ac84aa3dc8b4525a11e64961e375f8f3ba.zip
Add more SMIR internal impl and callback return value
In cases like Kani, we will invoke the rustc_internal run command
directly for now. It would be handly to be able to have a callback
that can return a value.

We also need extra methods to convert stable constructs into internal
ones, so we can break down the transition into finer grain commits.
-rw-r--r--compiler/rustc_smir/src/rustc_internal/internal.rs136
-rw-r--r--compiler/rustc_smir/src/rustc_internal/mod.rs24
-rw-r--r--compiler/stable_mir/src/lib.rs20
3 files changed, 159 insertions, 21 deletions
diff --git a/compiler/rustc_smir/src/rustc_internal/internal.rs b/compiler/rustc_smir/src/rustc_internal/internal.rs
index 7cfdbbbf703..5bb3c1a0d4c 100644
--- a/compiler/rustc_smir/src/rustc_internal/internal.rs
+++ b/compiler/rustc_smir/src/rustc_internal/internal.rs
@@ -6,11 +6,23 @@
 // Prefer importing stable_mir over internal rustc constructs to make this file more readable.
 use crate::rustc_smir::Tables;
 use rustc_middle::ty::{self as rustc_ty, Ty as InternalTy};
-use stable_mir::ty::{Const, GenericArgKind, GenericArgs, Region, Ty};
-use stable_mir::DefId;
+use rustc_span::Symbol;
+use stable_mir::mir::mono::{Instance, MonoItem, StaticDef};
+use stable_mir::ty::{
+    Binder, BoundRegionKind, BoundTyKind, BoundVariableKind, ClosureKind, Const, GenericArgKind,
+    GenericArgs, Region, TraitRef, Ty,
+};
+use stable_mir::{AllocId, CrateItem, DefId};
 
 use super::RustcInternal;
 
+impl<'tcx> RustcInternal<'tcx> for CrateItem {
+    type T = rustc_span::def_id::DefId;
+    fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T {
+        self.0.internal(tables)
+    }
+}
+
 impl<'tcx> RustcInternal<'tcx> for DefId {
     type T = rustc_span::def_id::DefId;
     fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T {
@@ -38,8 +50,9 @@ impl<'tcx> RustcInternal<'tcx> for GenericArgKind {
 
 impl<'tcx> RustcInternal<'tcx> for Region {
     type T = rustc_ty::Region<'tcx>;
-    fn internal(&self, _tables: &mut Tables<'tcx>) -> Self::T {
-        todo!()
+    fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T {
+        // Cannot recover region. Use erased instead.
+        tables.tcx.lifetimes.re_erased
     }
 }
 
@@ -65,3 +78,118 @@ impl<'tcx> RustcInternal<'tcx> for Const {
         tables.constants[self.id]
     }
 }
+
+impl<'tcx> RustcInternal<'tcx> for MonoItem {
+    type T = rustc_middle::mir::mono::MonoItem<'tcx>;
+
+    fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T {
+        use rustc_middle::mir::mono as rustc_mono;
+        match self {
+            MonoItem::Fn(instance) => rustc_mono::MonoItem::Fn(instance.internal(tables)),
+            MonoItem::Static(def) => rustc_mono::MonoItem::Static(def.internal(tables)),
+            MonoItem::GlobalAsm(_) => {
+                unimplemented!()
+            }
+        }
+    }
+}
+
+impl<'tcx> RustcInternal<'tcx> for Instance {
+    type T = rustc_ty::Instance<'tcx>;
+
+    fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T {
+        tables.instances[self.def]
+    }
+}
+
+impl<'tcx> RustcInternal<'tcx> for StaticDef {
+    type T = rustc_span::def_id::DefId;
+
+    fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T {
+        self.0.internal(tables)
+    }
+}
+
+#[allow(rustc::usage_of_qualified_ty)]
+impl<'tcx, T> RustcInternal<'tcx> for Binder<T>
+where
+    T: RustcInternal<'tcx>,
+    T::T: rustc_ty::TypeVisitable<rustc_ty::TyCtxt<'tcx>>,
+{
+    type T = rustc_ty::Binder<'tcx, T::T>;
+
+    fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T {
+        rustc_ty::Binder::bind_with_vars(
+            self.value.internal(tables),
+            tables.tcx.mk_bound_variable_kinds_from_iter(
+                self.bound_vars.iter().map(|bound| bound.internal(tables)),
+            ),
+        )
+    }
+}
+
+impl<'tcx> RustcInternal<'tcx> for BoundVariableKind {
+    type T = rustc_ty::BoundVariableKind;
+
+    fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T {
+        match self {
+            BoundVariableKind::Ty(kind) => rustc_ty::BoundVariableKind::Ty(match kind {
+                BoundTyKind::Anon => rustc_ty::BoundTyKind::Anon,
+                BoundTyKind::Param(def, symbol) => {
+                    rustc_ty::BoundTyKind::Param(def.0.internal(tables), Symbol::intern(&symbol))
+                }
+            }),
+            BoundVariableKind::Region(kind) => rustc_ty::BoundVariableKind::Region(match kind {
+                BoundRegionKind::BrAnon => rustc_ty::BoundRegionKind::BrAnon,
+                BoundRegionKind::BrNamed(def, symbol) => rustc_ty::BoundRegionKind::BrNamed(
+                    def.0.internal(tables),
+                    Symbol::intern(&symbol),
+                ),
+                BoundRegionKind::BrEnv => rustc_ty::BoundRegionKind::BrEnv,
+            }),
+            BoundVariableKind::Const => rustc_ty::BoundVariableKind::Const,
+        }
+    }
+}
+
+impl<'tcx> RustcInternal<'tcx> for TraitRef {
+    type T = rustc_ty::TraitRef<'tcx>;
+
+    fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T {
+        rustc_ty::TraitRef::new(
+            tables.tcx,
+            self.def_id.0.internal(tables),
+            self.args().internal(tables),
+        )
+    }
+}
+
+impl<'tcx> RustcInternal<'tcx> for AllocId {
+    type T = rustc_middle::mir::interpret::AllocId;
+    fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T {
+        tables.alloc_ids[*self]
+    }
+}
+
+impl<'tcx> RustcInternal<'tcx> for ClosureKind {
+    type T = rustc_ty::ClosureKind;
+
+    fn internal(&self, _tables: &mut Tables<'tcx>) -> Self::T {
+        match self {
+            ClosureKind::Fn => rustc_ty::ClosureKind::Fn,
+            ClosureKind::FnMut => rustc_ty::ClosureKind::FnMut,
+            ClosureKind::FnOnce => rustc_ty::ClosureKind::FnOnce,
+        }
+    }
+}
+
+impl<'tcx, T> RustcInternal<'tcx> for &T
+where
+    T: RustcInternal<'tcx>,
+{
+    type T = T::T;
+
+    fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T {
+        (*self).internal(tables)
+    }
+}
diff --git a/compiler/rustc_smir/src/rustc_internal/mod.rs b/compiler/rustc_smir/src/rustc_internal/mod.rs
index f0b368bec39..c82f948f195 100644
--- a/compiler/rustc_smir/src/rustc_internal/mod.rs
+++ b/compiler/rustc_smir/src/rustc_internal/mod.rs
@@ -13,6 +13,7 @@ use rustc_span::def_id::{CrateNum, DefId};
 use rustc_span::Span;
 use scoped_tls::scoped_thread_local;
 use stable_mir::ty::IndexedVal;
+use stable_mir::Error;
 use std::cell::Cell;
 use std::cell::RefCell;
 use std::fmt::Debug;
@@ -21,11 +22,11 @@ use std::ops::Index;
 
 mod internal;
 
-pub fn stable<'tcx, S: Stable<'tcx>>(item: &S) -> S::T {
+pub fn stable<'tcx, S: Stable<'tcx>>(item: S) -> S::T {
     with_tables(|tables| item.stable(tables))
 }
 
-pub fn internal<'tcx, S: RustcInternal<'tcx>>(item: &S) -> S::T {
+pub fn internal<'tcx, S: RustcInternal<'tcx>>(item: S) -> S::T {
     with_tables(|tables| item.internal(tables))
 }
 
@@ -144,12 +145,13 @@ pub fn crate_num(item: &stable_mir::Crate) -> CrateNum {
 // datastructures and stable MIR datastructures
 scoped_thread_local! (static TLV: Cell<*const ()>);
 
-pub(crate) fn init<'tcx>(tables: &TablesWrapper<'tcx>, f: impl FnOnce()) {
+pub(crate) fn init<'tcx, F, T>(tables: &TablesWrapper<'tcx>, f: F) -> T
+where
+    F: FnOnce() -> T,
+{
     assert!(!TLV.is_set());
     let ptr = tables as *const _ as *const ();
-    TLV.set(&Cell::new(ptr), || {
-        f();
-    });
+    TLV.set(&Cell::new(ptr), || f())
 }
 
 /// Loads the current context and calls a function with it.
@@ -165,7 +167,10 @@ pub(crate) fn with_tables<'tcx, R>(f: impl FnOnce(&mut Tables<'tcx>) -> R) -> R
     })
 }
 
-pub fn run(tcx: TyCtxt<'_>, f: impl FnOnce()) {
+pub fn run<F, T>(tcx: TyCtxt<'_>, f: F) -> Result<T, Error>
+where
+    F: FnOnce() -> T,
+{
     let tables = TablesWrapper(RefCell::new(Tables {
         tcx,
         def_ids: IndexMap::default(),
@@ -175,7 +180,7 @@ pub fn run(tcx: TyCtxt<'_>, f: impl FnOnce()) {
         instances: IndexMap::default(),
         constants: IndexMap::default(),
     }));
-    stable_mir::run(&tables, || init(&tables, f));
+    stable_mir::run(&tables, || init(&tables, f))
 }
 
 #[macro_export]
@@ -241,7 +246,8 @@ macro_rules! run {
                 queries.global_ctxt().unwrap().enter(|tcx| {
                     rustc_internal::run(tcx, || {
                         self.result = Some((self.callback)(tcx));
-                    });
+                    })
+                    .unwrap();
                     if self.result.as_ref().is_some_and(|val| val.is_continue()) {
                         Compilation::Continue
                     } else {
diff --git a/compiler/stable_mir/src/lib.rs b/compiler/stable_mir/src/lib.rs
index f316671b278..63e9d54544b 100644
--- a/compiler/stable_mir/src/lib.rs
+++ b/compiler/stable_mir/src/lib.rs
@@ -47,7 +47,7 @@ pub type Symbol = String;
 pub type CrateNum = usize;
 
 /// A unique identification number for each item accessible for the current compilation unit.
-#[derive(Clone, Copy, PartialEq, Eq)]
+#[derive(Clone, Copy, PartialEq, Eq, Hash)]
 pub struct DefId(usize);
 
 impl Debug for DefId {
@@ -240,12 +240,16 @@ pub trait Context {
 // datastructures and stable MIR datastructures
 scoped_thread_local! (static TLV: Cell<*const ()>);
 
-pub fn run(context: &dyn Context, f: impl FnOnce()) {
-    assert!(!TLV.is_set());
-    let ptr: *const () = &context as *const &_ as _;
-    TLV.set(&Cell::new(ptr), || {
-        f();
-    });
+pub fn run<F, T>(context: &dyn Context, f: F) -> Result<T, Error>
+where
+    F: FnOnce() -> T,
+{
+    if TLV.is_set() {
+        Err(Error::from("StableMIR already running"))
+    } else {
+        let ptr: *const () = &context as *const &_ as _;
+        TLV.set(&Cell::new(ptr), || Ok(f()))
+    }
 }
 
 /// Loads the current context and calls a function with it.
@@ -260,7 +264,7 @@ pub fn with<R>(f: impl FnOnce(&dyn Context) -> R) -> R {
 }
 
 /// A type that provides internal information but that can still be used for debug purpose.
-#[derive(Clone, Eq, PartialEq)]
+#[derive(Clone, PartialEq, Eq, Hash)]
 pub struct Opaque(String);
 
 impl std::fmt::Display for Opaque {