about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/coroutine.rs7
-rw-r--r--compiler/rustc_mir_transform/src/gvn.rs2
-rw-r--r--compiler/rustc_mir_transform/src/inline.rs2
-rw-r--r--compiler/rustc_mir_transform/src/inline/cycle.rs2
-rw-r--r--compiler/rustc_mir_transform/src/jump_threading.rs19
-rw-r--r--compiler/rustc_mir_transform/src/lib.rs2
-rw-r--r--compiler/rustc_mir_transform/src/promote_consts.rs2
-rw-r--r--compiler/rustc_mir_transform/src/shim.rs49
8 files changed, 52 insertions, 33 deletions
diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs
index 05674792426..4c00038365b 100644
--- a/compiler/rustc_mir_transform/src/coroutine.rs
+++ b/compiler/rustc_mir_transform/src/coroutine.rs
@@ -208,11 +208,8 @@ const UNRESUMED: usize = CoroutineArgs::UNRESUMED;
 const RETURNED: usize = CoroutineArgs::RETURNED;
 /// Coroutine has panicked and is poisoned.
 const POISONED: usize = CoroutineArgs::POISONED;
-
-/// Number of variants to reserve in coroutine state. Corresponds to
-/// `UNRESUMED` (beginning of a coroutine) and `RETURNED`/`POISONED`
-/// (end of a coroutine) states.
-const RESERVED_VARIANTS: usize = 3;
+/// Number of reserved variants of coroutine state.
+const RESERVED_VARIANTS: usize = CoroutineArgs::RESERVED_VARIANTS;
 
 /// A `yield` point in the coroutine.
 struct SuspensionPoint<'tcx> {
diff --git a/compiler/rustc_mir_transform/src/gvn.rs b/compiler/rustc_mir_transform/src/gvn.rs
index 3dbdeb615cf..2b7d9be6d35 100644
--- a/compiler/rustc_mir_transform/src/gvn.rs
+++ b/compiler/rustc_mir_transform/src/gvn.rs
@@ -8,7 +8,7 @@
 //! `Value` is interned as a `VnIndex`, which allows us to cheaply compute identical values.
 //!
 //! From those assignments, we construct a mapping `VnIndex -> Vec<(Local, Location)>` of available
-//! values, the locals in which they are stored, and a the assignment location.
+//! values, the locals in which they are stored, and the assignment location.
 //!
 //! In a second pass, we traverse all (non SSA) assignments `x = rvalue` and operands. For each
 //! one, we compute the `VnIndex` of the rvalue. If this `VnIndex` is associated to a constant, we
diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs
index 6fa31c1174d..5075e072754 100644
--- a/compiler/rustc_mir_transform/src/inline.rs
+++ b/compiler/rustc_mir_transform/src/inline.rs
@@ -389,7 +389,7 @@ impl<'tcx> Inliner<'tcx> {
                 // To resolve an instance its args have to be fully normalized.
                 let args = self.tcx.try_normalize_erasing_regions(self.param_env, args).ok()?;
                 let callee =
-                    Instance::resolve(self.tcx, self.param_env, def_id, args).ok().flatten()?;
+                    Instance::try_resolve(self.tcx, self.param_env, def_id, args).ok().flatten()?;
 
                 if let InstanceKind::Virtual(..) | InstanceKind::Intrinsic(_) = callee.def {
                     return None;
diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs
index 35bcd24ce95..d4477563e3a 100644
--- a/compiler/rustc_mir_transform/src/inline/cycle.rs
+++ b/compiler/rustc_mir_transform/src/inline/cycle.rs
@@ -53,7 +53,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>(
                 trace!(?caller, ?param_env, ?args, "cannot normalize, skipping");
                 continue;
             };
-            let Ok(Some(callee)) = ty::Instance::resolve(tcx, param_env, callee, args) else {
+            let Ok(Some(callee)) = ty::Instance::try_resolve(tcx, param_env, callee, args) else {
                 trace!(?callee, "cannot resolve, skipping");
                 continue;
             };
diff --git a/compiler/rustc_mir_transform/src/jump_threading.rs b/compiler/rustc_mir_transform/src/jump_threading.rs
index 27e506a920b..85a61f95ca3 100644
--- a/compiler/rustc_mir_transform/src/jump_threading.rs
+++ b/compiler/rustc_mir_transform/src/jump_threading.rs
@@ -47,6 +47,7 @@ use rustc_middle::mir::visit::Visitor;
 use rustc_middle::mir::*;
 use rustc_middle::ty::layout::LayoutOf;
 use rustc_middle::ty::{self, ScalarInt, TyCtxt};
+use rustc_mir_dataflow::lattice::HasBottom;
 use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
 use rustc_span::DUMMY_SP;
 use rustc_target::abi::{TagEncoding, Variants};
@@ -158,9 +159,17 @@ impl Condition {
     }
 }
 
-#[derive(Copy, Clone, Debug, Default)]
+#[derive(Copy, Clone, Debug)]
 struct ConditionSet<'a>(&'a [Condition]);
 
+impl HasBottom for ConditionSet<'_> {
+    const BOTTOM: Self = ConditionSet(&[]);
+
+    fn is_bottom(&self) -> bool {
+        self.0.is_empty()
+    }
+}
+
 impl<'a> ConditionSet<'a> {
     fn iter(self) -> impl Iterator<Item = Condition> + 'a {
         self.0.iter().copied()
@@ -177,7 +186,7 @@ impl<'a> ConditionSet<'a> {
 
 impl<'tcx, 'a> TOFinder<'tcx, 'a> {
     fn is_empty(&self, state: &State<ConditionSet<'a>>) -> bool {
-        state.all(|cs| cs.0.is_empty())
+        state.all_bottom()
     }
 
     /// Recursion entry point to find threading opportunities.
@@ -198,7 +207,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
         debug!(?discr);
 
         let cost = CostChecker::new(self.tcx, self.param_env, None, self.body);
-        let mut state = State::new(ConditionSet::default(), self.map);
+        let mut state = State::new_reachable();
 
         let conds = if let Some((value, then, else_)) = targets.as_static_if() {
             let value = ScalarInt::try_from_uint(value, discr_layout.size)?;
@@ -255,7 +264,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
             //   _1 = 5 // Whatever happens here, it won't change the result of a `SwitchInt`.
             //   _1 = 6
             if let Some((lhs, tail)) = self.mutated_statement(stmt) {
-                state.flood_with_tail_elem(lhs.as_ref(), tail, self.map, ConditionSet::default());
+                state.flood_with_tail_elem(lhs.as_ref(), tail, self.map, ConditionSet::BOTTOM);
             }
         }
 
@@ -609,7 +618,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
         // We can recurse through this terminator.
         let mut state = state();
         if let Some(place_to_flood) = place_to_flood {
-            state.flood_with(place_to_flood.as_ref(), self.map, ConditionSet::default());
+            state.flood_with(place_to_flood.as_ref(), self.map, ConditionSet::BOTTOM);
         }
         self.find_opportunity(bb, state, cost.clone(), depth + 1);
     }
diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs
index f7056702cb4..5d253d7384d 100644
--- a/compiler/rustc_mir_transform/src/lib.rs
+++ b/compiler/rustc_mir_transform/src/lib.rs
@@ -519,7 +519,7 @@ fn run_runtime_lowering_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         &add_subtyping_projections::Subtyper, // calling this after reveal_all ensures that we don't deal with opaque types
         &elaborate_drops::ElaborateDrops,
         // This will remove extraneous landing pads which are no longer
-        // necessary as well as well as forcing any call in a non-unwinding
+        // necessary as well as forcing any call in a non-unwinding
         // function calling a possibly-unwinding function to abort the process.
         &abort_unwinding_calls::AbortUnwindingCalls,
         // AddMovesForPackedDrops needs to run after drop
diff --git a/compiler/rustc_mir_transform/src/promote_consts.rs b/compiler/rustc_mir_transform/src/promote_consts.rs
index 3f4d2b65ff2..736647fb64b 100644
--- a/compiler/rustc_mir_transform/src/promote_consts.rs
+++ b/compiler/rustc_mir_transform/src/promote_consts.rs
@@ -816,7 +816,7 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
                     mut func, mut args, call_source: desugar, fn_span, ..
                 } => {
                     // This promoted involves a function call, so it may fail to evaluate.
-                    // Let's make sure it is added to `required_consts` so that that failure cannot get lost.
+                    // Let's make sure it is added to `required_consts` so that failure cannot get lost.
                     self.add_to_required = true;
 
                     self.visit_operand(&mut func, loc);
diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs
index 25577e88e28..6835a39cf36 100644
--- a/compiler/rustc_mir_transform/src/shim.rs
+++ b/compiler/rustc_mir_transform/src/shim.rs
@@ -1,18 +1,17 @@
 use rustc_hir as hir;
 use rustc_hir::def_id::DefId;
 use rustc_hir::lang_items::LangItem;
+use rustc_index::{Idx, IndexVec};
 use rustc_middle::mir::*;
 use rustc_middle::query::Providers;
 use rustc_middle::ty::GenericArgs;
 use rustc_middle::ty::{self, CoroutineArgs, CoroutineArgsExt, EarlyBinder, Ty, TyCtxt};
 use rustc_middle::{bug, span_bug};
-use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT};
-
-use rustc_index::{Idx, IndexVec};
-
 use rustc_span::{source_map::Spanned, Span, DUMMY_SP};
+use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT};
 use rustc_target::spec::abi::Abi;
 
+use std::assert_matches::assert_matches;
 use std::fmt;
 use std::iter;
 
@@ -1020,21 +1019,19 @@ fn build_construct_coroutine_by_move_shim<'tcx>(
     receiver_by_ref: bool,
 ) -> Body<'tcx> {
     let mut self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
+    let mut self_local: Place<'tcx> = Local::from_usize(1).into();
     let ty::CoroutineClosure(_, args) = *self_ty.kind() else {
         bug!();
     };
 
-    // We use `&mut Self` here because we only need to emit an ABI-compatible shim body,
-    // rather than match the signature exactly (which might take `&self` instead).
+    // We use `&Self` here because we only need to emit an ABI-compatible shim body,
+    // rather than match the signature exactly (which might take `&mut self` instead).
     //
-    // The self type here is a coroutine-closure, not a coroutine, and we never read from
-    // it because it never has any captures, because this is only true in the Fn/FnMut
-    // implementation, not the AsyncFn/AsyncFnMut implementation, which is implemented only
-    // if the coroutine-closure has no captures.
+    // We adjust the `self_local` to be a deref since we want to copy fields out of
+    // a reference to the closure.
     if receiver_by_ref {
-        // Triple-check that there's no captures here.
-        assert_eq!(args.as_coroutine_closure().tupled_upvars_ty(), tcx.types.unit);
-        self_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, self_ty);
+        self_local = tcx.mk_place_deref(self_local);
+        self_ty = Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, self_ty);
     }
 
     let poly_sig = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
@@ -1067,11 +1064,27 @@ fn build_construct_coroutine_by_move_shim<'tcx>(
         fields.push(Operand::Move(Local::from_usize(idx + 1).into()));
     }
     for (idx, ty) in args.as_coroutine_closure().upvar_tys().iter().enumerate() {
-        fields.push(Operand::Move(tcx.mk_place_field(
-            Local::from_usize(1).into(),
-            FieldIdx::from_usize(idx),
-            ty,
-        )));
+        if receiver_by_ref {
+            // The only situation where it's possible is when we capture immuatable references,
+            // since those don't need to be reborrowed with the closure's env lifetime. Since
+            // references are always `Copy`, just emit a copy.
+            assert_matches!(
+                ty.kind(),
+                ty::Ref(_, _, hir::Mutability::Not),
+                "field should be captured by immutable ref if we have an `Fn` instance"
+            );
+            fields.push(Operand::Copy(tcx.mk_place_field(
+                self_local,
+                FieldIdx::from_usize(idx),
+                ty,
+            )));
+        } else {
+            fields.push(Operand::Move(tcx.mk_place_field(
+                self_local,
+                FieldIdx::from_usize(idx),
+                ty,
+            )));
+        }
     }
 
     let source_info = SourceInfo::outermost(span);