about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src/lib.rs')
-rw-r--r--compiler/rustc_mir_transform/src/lib.rs71
1 files changed, 68 insertions, 3 deletions
diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs
index 0b674b38f32..162f7d969b1 100644
--- a/compiler/rustc_mir_transform/src/lib.rs
+++ b/compiler/rustc_mir_transform/src/lib.rs
@@ -10,6 +10,7 @@
 #![feature(trusted_step)]
 #![feature(try_blocks)]
 #![feature(yeet_expr)]
+#![feature(if_let_guard)]
 #![recursion_limit = "256"]
 
 #[macro_use]
@@ -27,10 +28,13 @@ use rustc_hir::intravisit::{self, Visitor};
 use rustc_index::vec::IndexVec;
 use rustc_middle::mir::visit::Visitor as _;
 use rustc_middle::mir::{
-    traversal, AnalysisPhase, Body, ConstQualifs, MirPass, MirPhase, Promoted, RuntimePhase,
+    traversal, AnalysisPhase, Body, ConstQualifs, Constant, LocalDecl, MirPass, MirPhase, Operand,
+    Place, ProjectionElem, Promoted, RuntimePhase, Rvalue, SourceInfo, Statement, StatementKind,
+    TerminatorKind,
 };
 use rustc_middle::ty::query::Providers;
 use rustc_middle::ty::{self, TyCtxt, TypeVisitable};
+use rustc_span::sym;
 
 #[macro_use]
 mod pass_manager;
@@ -140,6 +144,64 @@ pub fn provide(providers: &mut Providers) {
     };
 }
 
+fn remap_mir_for_const_eval_select<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    mut body: Body<'tcx>,
+    context: hir::Constness,
+) -> Body<'tcx> {
+    for bb in body.basic_blocks.as_mut().iter_mut() {
+        let terminator = bb.terminator.as_mut().expect("invalid terminator");
+        match terminator.kind {
+            TerminatorKind::Call {
+                func: Operand::Constant(box Constant { ref literal, .. }),
+                ref mut args,
+                destination,
+                target,
+                cleanup,
+                fn_span,
+                ..
+            } if let ty::FnDef(def_id, _) = *literal.ty().kind()
+                && tcx.item_name(def_id) == sym::const_eval_select
+                && tcx.is_intrinsic(def_id) =>
+            {
+                let [tupled_args, called_in_const, called_at_rt]: [_; 3] = std::mem::take(args).try_into().unwrap();
+                let ty = tupled_args.ty(&body.local_decls, tcx);
+                let fields = ty.tuple_fields();
+                let num_args = fields.len();
+                let func = if context == hir::Constness::Const { called_in_const } else { called_at_rt };
+                let (method, place): (fn(Place<'tcx>) -> Operand<'tcx>, Place<'tcx>) = match tupled_args {
+                    Operand::Constant(_) => {
+                        // there is no good way of extracting a tuple arg from a constant (const generic stuff)
+                        // so we just create a temporary and deconstruct that.
+                        let local = body.local_decls.push(LocalDecl::new(ty, fn_span));
+                        bb.statements.push(Statement {
+                            source_info: SourceInfo::outermost(fn_span),
+                            kind: StatementKind::Assign(Box::new((local.into(), Rvalue::Use(tupled_args.clone())))),
+                        });
+                        (Operand::Move, local.into())
+                    }
+                    Operand::Move(place) => (Operand::Move, place),
+                    Operand::Copy(place) => (Operand::Copy, place),
+                };
+                let place_elems = place.projection;
+                let arguments = (0..num_args).map(|x| {
+                    let mut place_elems = place_elems.to_vec();
+                    place_elems.push(ProjectionElem::Field(x.into(), fields[x]));
+                    let projection = tcx.intern_place_elems(&place_elems);
+                    let place = Place {
+                        local: place.local,
+                        projection,
+                    };
+                    method(place)
+                }).collect();
+                terminator.kind = TerminatorKind::Call { func, args: arguments, destination, target, cleanup, from_hir_call: false, fn_span };
+            }
+            _ => {}
+        }
+    }
+    body
+}
+
 fn is_mir_available(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
     let def_id = def_id.expect_local();
     tcx.mir_keys(()).contains(&def_id)
@@ -325,7 +387,9 @@ fn inner_mir_for_ctfe(tcx: TyCtxt<'_>, def: ty::WithOptConstParam<LocalDefId>) -
         .body_const_context(def.did)
         .expect("mir_for_ctfe should not be used for runtime functions");
 
-    let mut body = tcx.mir_drops_elaborated_and_const_checked(def).borrow().clone();
+    let body = tcx.mir_drops_elaborated_and_const_checked(def).borrow().clone();
+
+    let mut body = remap_mir_for_const_eval_select(tcx, body, hir::Constness::Const);
 
     match context {
         // Do not const prop functions, either they get executed at runtime or exported to metadata,
@@ -558,8 +622,9 @@ fn inner_optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> Body<'_> {
         Some(other) => panic!("do not use `optimized_mir` for constants: {:?}", other),
     }
     debug!("about to call mir_drops_elaborated...");
-    let mut body =
+    let body =
         tcx.mir_drops_elaborated_and_const_checked(ty::WithOptConstParam::unknown(did)).steal();
+    let mut body = remap_mir_for_const_eval_select(tcx, body, hir::Constness::NotConst);
     debug!("body: {:#?}", body);
     run_optimization_passes(tcx, &mut body);