about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src/shim.rs
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-11-16 20:18:13 +0000
committerMichael Goulet <michael@errs.io>2024-12-10 16:52:20 +0000
commita7fa4cbcb498b80b126a954b5944f19a11e28dec (patch)
tree462d7ced92d08e35a6cae86caeb804c8dbefb635 /compiler/rustc_mir_transform/src/shim.rs
parent3b057796264e637b31c28809322028a32d22ae6f (diff)
downloadrust-a7fa4cbcb498b80b126a954b5944f19a11e28dec.tar.gz
rust-a7fa4cbcb498b80b126a954b5944f19a11e28dec.zip
Implement projection and shim for AFIDT
Diffstat (limited to 'compiler/rustc_mir_transform/src/shim.rs')
-rw-r--r--compiler/rustc_mir_transform/src/shim.rs56
1 files changed, 53 insertions, 3 deletions
diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs
index b8383e734e2..722da3c420d 100644
--- a/compiler/rustc_mir_transform/src/shim.rs
+++ b/compiler/rustc_mir_transform/src/shim.rs
@@ -9,6 +9,7 @@ use rustc_index::{Idx, IndexVec};
 use rustc_middle::mir::patch::MirPatch;
 use rustc_middle::mir::*;
 use rustc_middle::query::Providers;
+use rustc_middle::ty::adjustment::PointerCoercion;
 use rustc_middle::ty::{
     self, CoroutineArgs, CoroutineArgsExt, EarlyBinder, GenericArgs, Ty, TyCtxt,
 };
@@ -710,6 +711,13 @@ fn build_call_shim<'tcx>(
     };
 
     let def_id = instance.def_id();
+
+    let rpitit_shim = if let ty::InstanceKind::ReifyShim(..) = instance {
+        tcx.return_position_impl_trait_in_trait_shim_data(def_id)
+    } else {
+        None
+    };
+
     let sig = tcx.fn_sig(def_id);
     let sig = sig.map_bound(|sig| tcx.instantiate_bound_regions_with_erased(sig));
 
@@ -765,9 +773,34 @@ fn build_call_shim<'tcx>(
     let mut local_decls = local_decls_for_sig(&sig, span);
     let source_info = SourceInfo::outermost(span);
 
+    let mut destination = Place::return_place();
+    if let Some((rpitit_def_id, fn_args)) = rpitit_shim {
+        let rpitit_args =
+            fn_args.instantiate_identity().extend_to(tcx, rpitit_def_id, |param, _| {
+                match param.kind {
+                    ty::GenericParamDefKind::Lifetime => tcx.lifetimes.re_erased.into(),
+                    ty::GenericParamDefKind::Type { .. }
+                    | ty::GenericParamDefKind::Const { .. } => {
+                        unreachable!("rpitit should have no addition ty/ct")
+                    }
+                }
+            });
+        let dyn_star_ty = Ty::new_dynamic(
+            tcx,
+            tcx.item_bounds_to_existential_predicates(rpitit_def_id, rpitit_args),
+            tcx.lifetimes.re_erased,
+            ty::DynStar,
+        );
+        destination = local_decls.push(local_decls[RETURN_PLACE].clone()).into();
+        local_decls[RETURN_PLACE].ty = dyn_star_ty;
+        let mut inputs_and_output = sig.inputs_and_output.to_vec();
+        *inputs_and_output.last_mut().unwrap() = dyn_star_ty;
+        sig.inputs_and_output = tcx.mk_type_list(&inputs_and_output);
+    }
+
     let rcvr_place = || {
         assert!(rcvr_adjustment.is_some());
-        Place::from(Local::new(1 + 0))
+        Place::from(Local::new(1))
     };
     let mut statements = vec![];
 
@@ -854,7 +887,7 @@ fn build_call_shim<'tcx>(
         TerminatorKind::Call {
             func: callee,
             args,
-            destination: Place::return_place(),
+            destination,
             target: Some(BasicBlock::new(1)),
             unwind: if let Some(Adjustment::RefMut) = rcvr_adjustment {
                 UnwindAction::Cleanup(BasicBlock::new(3))
@@ -882,7 +915,24 @@ fn build_call_shim<'tcx>(
         );
     }
     // BB #1/#2 - return
-    block(&mut blocks, vec![], TerminatorKind::Return, false);
+    // NOTE: If this is an RPITIT in dyn, we also want to coerce
+    // the return type of the function into a `dyn*`.
+    let stmts = if rpitit_shim.is_some() {
+        vec![Statement {
+            source_info,
+            kind: StatementKind::Assign(Box::new((
+                Place::return_place(),
+                Rvalue::Cast(
+                    CastKind::PointerCoercion(PointerCoercion::DynStar, CoercionSource::Implicit),
+                    Operand::Move(destination),
+                    sig.output(),
+                ),
+            ))),
+        }]
+    } else {
+        vec![]
+    };
+    block(&mut blocks, stmts, TerminatorKind::Return, false);
     if let Some(Adjustment::RefMut) = rcvr_adjustment {
         // BB #3 - drop if closure panics
         block(