about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2023-01-28 01:05:29 +0000
committerbors <bors@rust-lang.org>2023-01-28 01:05:29 +0000
commit6cd6bad51fb34a0d89e97c27814041fe4d0838b5 (patch)
tree12df4507b73b7fb3515178a55f0ef54cf2583a8a /compiler/rustc_mir_transform/src
parent7d4df2d30eb342af1ef136d83d70d281f34adcd7 (diff)
parentd3d626920abf2a4c93bd50640a9d66ce9d5a9009 (diff)
downloadrust-6cd6bad51fb34a0d89e97c27814041fe4d0838b5.tar.gz
rust-6cd6bad51fb34a0d89e97c27814041fe4d0838b5.zip
Auto merge of #101692 - cjgillot:generator-lazy-witness, r=oli-obk
Compute generator saved locals on MIR

Generators are currently type-checked by introducing a `witness` type variable, which is unified with a `GeneratorWitness(captured types)` whose purpose is to ensure that the auto traits correctly migrate from the captured types to the `witness` type.  This requires computing the captured types on HIR during type-checking, only to re-do it on MIR later.

This PR proposes to drop the HIR-based computation, and only keep the MIR one.  This is done in 3 steps.
1. During type-checking, the `witness` type variable is never unified.  This allows to stall all the obligations that depend on it until the end of type-checking.  Then, the stalled obligations are marked as successful, and saved into the typeck results for later verification.
2. At type-checking writeback, `witness` is replaced by `GeneratorWitnessMIR(def_id, substs)`.  From this point on, all trait selection involving `GeneratorWitnessMIR` will fetch the MIR-computed locals, similar to what opaque types do.  There is no lifetime to be preserved here: we consider all the lifetimes appearing in this witness type to be higher-ranked.
3. After borrowck, the stashed obligations are verified against the actually computed types, in the `check_generator_obligations` query.  If any obligation was wrongly marked as fulfilled in step 1, it should be reported here.

There are still many issues:
- ~I am not too happy having to filter out some locals from the checked bounds, I think this is MIR building that introduces raw pointers polluting the analysis;~ solved by a check specific to static variables.
- the diagnostics for captured types don't show where they are used/dropped;
- I do not attempt to support chalk.

cc `@eholk` `@jyn514` for the drop-tracking work
r? `@oli-obk` as you warned me of potential unsoundness
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/generator.rs340
-rw-r--r--compiler/rustc_mir_transform/src/inline.rs4
-rw-r--r--compiler/rustc_mir_transform/src/lib.rs4
3 files changed, 317 insertions, 31 deletions
diff --git a/compiler/rustc_mir_transform/src/generator.rs b/compiler/rustc_mir_transform/src/generator.rs
index 39c61a34afc..e8871ff37f2 100644
--- a/compiler/rustc_mir_transform/src/generator.rs
+++ b/compiler/rustc_mir_transform/src/generator.rs
@@ -54,7 +54,8 @@ use crate::deref_separator::deref_finder;
 use crate::simplify;
 use crate::util::expand_aggregate;
 use crate::MirPass;
-use rustc_data_structures::fx::FxHashMap;
+use rustc_data_structures::fx::{FxHashMap, FxHashSet};
+use rustc_errors::pluralize;
 use rustc_hir as hir;
 use rustc_hir::lang_items::LangItem;
 use rustc_hir::GeneratorKind;
@@ -70,6 +71,9 @@ use rustc_mir_dataflow::impls::{
 };
 use rustc_mir_dataflow::storage::always_storage_live_locals;
 use rustc_mir_dataflow::{self, Analysis};
+use rustc_span::def_id::DefId;
+use rustc_span::symbol::sym;
+use rustc_span::Span;
 use rustc_target::abi::VariantIdx;
 use rustc_target::spec::PanicStrategy;
 use std::{iter, ops};
@@ -854,7 +858,7 @@ fn sanitize_witness<'tcx>(
     body: &Body<'tcx>,
     witness: Ty<'tcx>,
     upvars: Vec<Ty<'tcx>>,
-    saved_locals: &GeneratorSavedLocals,
+    layout: &GeneratorLayout<'tcx>,
 ) {
     let did = body.source.def_id();
     let param_env = tcx.param_env(did);
@@ -873,31 +877,36 @@ fn sanitize_witness<'tcx>(
         }
     };
 
-    for (local, decl) in body.local_decls.iter_enumerated() {
-        // Ignore locals which are internal or not saved between yields.
-        if !saved_locals.contains(local) || decl.internal {
+    let mut mismatches = Vec::new();
+    for fty in &layout.field_tys {
+        if fty.ignore_for_traits {
             continue;
         }
-        let decl_ty = tcx.normalize_erasing_regions(param_env, decl.ty);
+        let decl_ty = tcx.normalize_erasing_regions(param_env, fty.ty);
 
         // Sanity check that typeck knows about the type of locals which are
         // live across a suspension point
         if !allowed.contains(&decl_ty) && !allowed_upvars.contains(&decl_ty) {
-            span_bug!(
-                body.span,
-                "Broken MIR: generator contains type {} in MIR, \
-                       but typeck only knows about {} and {:?}",
-                decl_ty,
-                allowed,
-                allowed_upvars
-            );
+            mismatches.push(decl_ty);
         }
     }
+
+    if !mismatches.is_empty() {
+        span_bug!(
+            body.span,
+            "Broken MIR: generator contains type {:?} in MIR, \
+                       but typeck only knows about {} and {:?}",
+            mismatches,
+            allowed,
+            allowed_upvars
+        );
+    }
 }
 
 fn compute_layout<'tcx>(
+    tcx: TyCtxt<'tcx>,
     liveness: LivenessInfo,
-    body: &mut Body<'tcx>,
+    body: &Body<'tcx>,
 ) -> (
     FxHashMap<Local, (Ty<'tcx>, VariantIdx, usize)>,
     GeneratorLayout<'tcx>,
@@ -915,9 +924,33 @@ fn compute_layout<'tcx>(
     let mut locals = IndexVec::<GeneratorSavedLocal, _>::new();
     let mut tys = IndexVec::<GeneratorSavedLocal, _>::new();
     for (saved_local, local) in saved_locals.iter_enumerated() {
-        locals.push(local);
-        tys.push(body.local_decls[local].ty);
         debug!("generator saved local {:?} => {:?}", saved_local, local);
+
+        locals.push(local);
+        let decl = &body.local_decls[local];
+        debug!(?decl);
+
+        let ignore_for_traits = if tcx.sess.opts.unstable_opts.drop_tracking_mir {
+            match decl.local_info {
+                // Do not include raw pointers created from accessing `static` items, as those could
+                // well be re-created by another access to the same static.
+                Some(box LocalInfo::StaticRef { is_thread_local, .. }) => !is_thread_local,
+                // Fake borrows are only read by fake reads, so do not have any reality in
+                // post-analysis MIR.
+                Some(box LocalInfo::FakeBorrow) => true,
+                _ => false,
+            }
+        } else {
+            // FIXME(#105084) HIR-based drop tracking does not account for all the temporaries that
+            // MIR building may introduce. This leads to wrongly ignored types, but this is
+            // necessary for internal consistency and to avoid ICEs.
+            decl.internal
+        };
+        let decl =
+            GeneratorSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits };
+        debug!(?decl);
+
+        tys.push(decl);
     }
 
     // Leave empty variants for the UNRESUMED, RETURNED, and POISONED states.
@@ -947,7 +980,7 @@ fn compute_layout<'tcx>(
             // just use the first one here. That's fine; fields do not move
             // around inside generators, so it doesn't matter which variant
             // index we access them by.
-            remap.entry(locals[saved_local]).or_insert((tys[saved_local], variant_index, idx));
+            remap.entry(locals[saved_local]).or_insert((tys[saved_local].ty, variant_index, idx));
         }
         variant_fields.push(fields);
         variant_source_info.push(source_info_at_suspension_points[suspension_point_idx]);
@@ -957,6 +990,7 @@ fn compute_layout<'tcx>(
 
     let layout =
         GeneratorLayout { field_tys: tys, variant_fields, variant_source_info, storage_conflicts };
+    debug!(?layout);
 
     (remap, layout, storage_liveness)
 }
@@ -1351,6 +1385,52 @@ fn create_cases<'tcx>(
         .collect()
 }
 
+#[instrument(level = "debug", skip(tcx), ret)]
+pub(crate) fn mir_generator_witnesses<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    def_id: DefId,
+) -> GeneratorLayout<'tcx> {
+    let def_id = def_id.expect_local();
+
+    let (body, _) = tcx.mir_promoted(ty::WithOptConstParam::unknown(def_id));
+    let body = body.borrow();
+    let body = &*body;
+
+    // The first argument is the generator type passed by value
+    let gen_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
+
+    // Get the interior types and substs which typeck computed
+    let (upvars, interior, movable) = match *gen_ty.kind() {
+        ty::Generator(_, substs, movability) => {
+            let substs = substs.as_generator();
+            (
+                substs.upvar_tys().collect::<Vec<_>>(),
+                substs.witness(),
+                movability == hir::Movability::Movable,
+            )
+        }
+        _ => span_bug!(body.span, "unexpected generator type {}", gen_ty),
+    };
+
+    // When first entering the generator, move the resume argument into its new local.
+    let always_live_locals = always_storage_live_locals(&body);
+
+    let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
+
+    // Extract locals which are live across suspension point into `layout`
+    // `remap` gives a mapping from local indices onto generator struct indices
+    // `storage_liveness` tells us which locals have live storage at suspension points
+    let (_, generator_layout, _) = compute_layout(tcx, liveness_info, body);
+
+    if tcx.sess.opts.unstable_opts.drop_tracking_mir {
+        check_suspend_tys(tcx, &generator_layout, &body);
+    } else {
+        sanitize_witness(tcx, body, interior, upvars, &generator_layout);
+    }
+
+    generator_layout
+}
+
 impl<'tcx> MirPass<'tcx> for StateTransform {
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         let Some(yield_ty) = body.yield_ty() else {
@@ -1363,16 +1443,11 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
         // The first argument is the generator type passed by value
         let gen_ty = body.local_decls.raw[1].ty;
 
-        // Get the interior types and substs which typeck computed
-        let (upvars, interior, discr_ty, movable) = match *gen_ty.kind() {
+        // Get the discriminant type and substs which typeck computed
+        let (discr_ty, movable) = match *gen_ty.kind() {
             ty::Generator(_, substs, movability) => {
                 let substs = substs.as_generator();
-                (
-                    substs.upvar_tys().collect(),
-                    substs.witness(),
-                    substs.discr_ty(tcx),
-                    movability == hir::Movability::Movable,
-                )
+                (substs.discr_ty(tcx), movability == hir::Movability::Movable)
             }
             _ => {
                 tcx.sess
@@ -1434,8 +1509,6 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
         let liveness_info =
             locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
 
-        sanitize_witness(tcx, body, interior, upvars, &liveness_info.saved_locals);
-
         if tcx.sess.opts.unstable_opts.validate_mir {
             let mut vis = EnsureGeneratorFieldAssignmentsNeverAlias {
                 assigned_local: None,
@@ -1449,7 +1522,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
         // Extract locals which are live across suspension point into `layout`
         // `remap` gives a mapping from local indices onto generator struct indices
         // `storage_liveness` tells us which locals have live storage at suspension points
-        let (remap, layout, storage_liveness) = compute_layout(liveness_info, body);
+        let (remap, layout, storage_liveness) = compute_layout(tcx, liveness_info, body);
 
         let can_return = can_return(tcx, body, tcx.param_env(body.source.def_id()));
 
@@ -1631,3 +1704,212 @@ impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> {
         }
     }
 }
+
+fn check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &GeneratorLayout<'tcx>, body: &Body<'tcx>) {
+    let mut linted_tys = FxHashSet::default();
+
+    // We want a user-facing param-env.
+    let param_env = tcx.param_env(body.source.def_id());
+
+    for (variant, yield_source_info) in
+        layout.variant_fields.iter().zip(&layout.variant_source_info)
+    {
+        debug!(?variant);
+        for &local in variant {
+            let decl = &layout.field_tys[local];
+            debug!(?decl);
+
+            if !decl.ignore_for_traits && linted_tys.insert(decl.ty) {
+                let Some(hir_id) = decl.source_info.scope.lint_root(&body.source_scopes) else { continue };
+
+                check_must_not_suspend_ty(
+                    tcx,
+                    decl.ty,
+                    hir_id,
+                    param_env,
+                    SuspendCheckData {
+                        source_span: decl.source_info.span,
+                        yield_span: yield_source_info.span,
+                        plural_len: 1,
+                        ..Default::default()
+                    },
+                );
+            }
+        }
+    }
+}
+
+#[derive(Default)]
+struct SuspendCheckData<'a> {
+    source_span: Span,
+    yield_span: Span,
+    descr_pre: &'a str,
+    descr_post: &'a str,
+    plural_len: usize,
+}
+
+// Returns whether it emitted a diagnostic or not
+// Note that this fn and the proceeding one are based on the code
+// for creating must_use diagnostics
+//
+// Note that this technique was chosen over things like a `Suspend` marker trait
+// as it is simpler and has precedent in the compiler
+fn check_must_not_suspend_ty<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    ty: Ty<'tcx>,
+    hir_id: hir::HirId,
+    param_env: ty::ParamEnv<'tcx>,
+    data: SuspendCheckData<'_>,
+) -> bool {
+    if ty.is_unit() {
+        return false;
+    }
+
+    let plural_suffix = pluralize!(data.plural_len);
+
+    debug!("Checking must_not_suspend for {}", ty);
+
+    match *ty.kind() {
+        ty::Adt(..) if ty.is_box() => {
+            let boxed_ty = ty.boxed_ty();
+            let descr_pre = &format!("{}boxed ", data.descr_pre);
+            check_must_not_suspend_ty(
+                tcx,
+                boxed_ty,
+                hir_id,
+                param_env,
+                SuspendCheckData { descr_pre, ..data },
+            )
+        }
+        ty::Adt(def, _) => check_must_not_suspend_def(tcx, def.did(), hir_id, data),
+        // FIXME: support adding the attribute to TAITs
+        ty::Alias(ty::Opaque, ty::AliasTy { def_id: def, .. }) => {
+            let mut has_emitted = false;
+            for &(predicate, _) in tcx.explicit_item_bounds(def) {
+                // We only look at the `DefId`, so it is safe to skip the binder here.
+                if let ty::PredicateKind::Clause(ty::Clause::Trait(ref poly_trait_predicate)) =
+                    predicate.kind().skip_binder()
+                {
+                    let def_id = poly_trait_predicate.trait_ref.def_id;
+                    let descr_pre = &format!("{}implementer{} of ", data.descr_pre, plural_suffix);
+                    if check_must_not_suspend_def(
+                        tcx,
+                        def_id,
+                        hir_id,
+                        SuspendCheckData { descr_pre, ..data },
+                    ) {
+                        has_emitted = true;
+                        break;
+                    }
+                }
+            }
+            has_emitted
+        }
+        ty::Dynamic(binder, _, _) => {
+            let mut has_emitted = false;
+            for predicate in binder.iter() {
+                if let ty::ExistentialPredicate::Trait(ref trait_ref) = predicate.skip_binder() {
+                    let def_id = trait_ref.def_id;
+                    let descr_post = &format!(" trait object{}{}", plural_suffix, data.descr_post);
+                    if check_must_not_suspend_def(
+                        tcx,
+                        def_id,
+                        hir_id,
+                        SuspendCheckData { descr_post, ..data },
+                    ) {
+                        has_emitted = true;
+                        break;
+                    }
+                }
+            }
+            has_emitted
+        }
+        ty::Tuple(fields) => {
+            let mut has_emitted = false;
+            for (i, ty) in fields.iter().enumerate() {
+                let descr_post = &format!(" in tuple element {i}");
+                if check_must_not_suspend_ty(
+                    tcx,
+                    ty,
+                    hir_id,
+                    param_env,
+                    SuspendCheckData { descr_post, ..data },
+                ) {
+                    has_emitted = true;
+                }
+            }
+            has_emitted
+        }
+        ty::Array(ty, len) => {
+            let descr_pre = &format!("{}array{} of ", data.descr_pre, plural_suffix);
+            check_must_not_suspend_ty(
+                tcx,
+                ty,
+                hir_id,
+                param_env,
+                SuspendCheckData {
+                    descr_pre,
+                    plural_len: len.try_eval_usize(tcx, param_env).unwrap_or(0) as usize + 1,
+                    ..data
+                },
+            )
+        }
+        // If drop tracking is enabled, we want to look through references, since the referrent
+        // may not be considered live across the await point.
+        ty::Ref(_region, ty, _mutability) => {
+            let descr_pre = &format!("{}reference{} to ", data.descr_pre, plural_suffix);
+            check_must_not_suspend_ty(
+                tcx,
+                ty,
+                hir_id,
+                param_env,
+                SuspendCheckData { descr_pre, ..data },
+            )
+        }
+        _ => false,
+    }
+}
+
+fn check_must_not_suspend_def(
+    tcx: TyCtxt<'_>,
+    def_id: DefId,
+    hir_id: hir::HirId,
+    data: SuspendCheckData<'_>,
+) -> bool {
+    if let Some(attr) = tcx.get_attr(def_id, sym::must_not_suspend) {
+        let msg = format!(
+            "{}`{}`{} held across a suspend point, but should not be",
+            data.descr_pre,
+            tcx.def_path_str(def_id),
+            data.descr_post,
+        );
+        tcx.struct_span_lint_hir(
+            rustc_session::lint::builtin::MUST_NOT_SUSPEND,
+            hir_id,
+            data.source_span,
+            msg,
+            |lint| {
+                // add span pointing to the offending yield/await
+                lint.span_label(data.yield_span, "the value is held across this suspend point");
+
+                // Add optional reason note
+                if let Some(note) = attr.value_str() {
+                    // FIXME(guswynn): consider formatting this better
+                    lint.span_note(data.source_span, note.as_str());
+                }
+
+                // Add some quick suggestions on what to do
+                // FIXME: can `drop` work as a suggestion here as well?
+                lint.span_help(
+                    data.source_span,
+                    "consider using a block (`{ ... }`) \
+                    to shrink the value's scope, ending before the suspend point",
+                )
+            },
+        );
+
+        true
+    } else {
+        false
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs
index 69627fc5cb2..84640b703c8 100644
--- a/compiler/rustc_mir_transform/src/inline.rs
+++ b/compiler/rustc_mir_transform/src/inline.rs
@@ -947,12 +947,12 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
                             return;
                         };
 
-                        let Some(&f_ty) = layout.field_tys.get(local) else {
+                        let Some(f_ty) = layout.field_tys.get(local) else {
                             self.validation = Err("malformed MIR");
                             return;
                         };
 
-                        f_ty
+                        f_ty.ty
                     } else {
                         let Some(f_ty) = substs.as_generator().prefix_tys().nth(f.index()) else {
                             self.validation = Err("malformed MIR");
diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs
index 4a598862d10..fe3d5b1cce4 100644
--- a/compiler/rustc_mir_transform/src/lib.rs
+++ b/compiler/rustc_mir_transform/src/lib.rs
@@ -123,6 +123,7 @@ pub fn provide(providers: &mut Providers) {
         mir_drops_elaborated_and_const_checked,
         mir_for_ctfe,
         mir_for_ctfe_of_const_arg,
+        mir_generator_witnesses: generator::mir_generator_witnesses,
         optimized_mir,
         is_mir_available,
         is_ctfe_mir_available: |tcx, did| is_mir_available(tcx, did),
@@ -425,6 +426,9 @@ fn mir_drops_elaborated_and_const_checked(
         return tcx.mir_drops_elaborated_and_const_checked(def);
     }
 
+    if tcx.generator_kind(def.did).is_some() {
+        tcx.ensure().mir_generator_witnesses(def.did);
+    }
     let mir_borrowck = tcx.mir_borrowck_opt_const_arg(def);
 
     let is_fn_like = tcx.def_kind(def.did).is_fn_like();