about summary refs log tree commit diff
path: root/compiler/rustc_trait_selection
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_trait_selection')
-rw-r--r--compiler/rustc_trait_selection/src/traits/project.rs58
-rw-r--r--compiler/rustc_trait_selection/src/traits/select/confirmation.rs74
2 files changed, 131 insertions, 1 deletions
diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs
index 49c34550f8e..9b8317fda75 100644
--- a/compiler/rustc_trait_selection/src/traits/project.rs
+++ b/compiler/rustc_trait_selection/src/traits/project.rs
@@ -7,8 +7,8 @@ use rustc_data_structures::stack::ensure_sufficient_stack;
 use rustc_errors::ErrorGuaranteed;
 use rustc_hir::def::DefKind;
 use rustc_hir::lang_items::LangItem;
-use rustc_infer::infer::DefineOpaqueTypes;
 use rustc_infer::infer::resolve::OpportunisticRegionResolver;
+use rustc_infer::infer::{DefineOpaqueTypes, RegionVariableOrigin};
 use rustc_infer::traits::{ObligationCauseCode, PredicateObligations};
 use rustc_middle::traits::select::OverflowError;
 use rustc_middle::traits::{BuiltinImplSource, ImplSource, ImplSourceUserDefinedData};
@@ -18,6 +18,7 @@ use rustc_middle::ty::visit::TypeVisitableExt;
 use rustc_middle::ty::{self, Term, Ty, TyCtxt, TypingMode, Upcast};
 use rustc_middle::{bug, span_bug};
 use rustc_span::symbol::sym;
+use thin_vec::thin_vec;
 use tracing::{debug, instrument};
 
 use super::{
@@ -61,6 +62,9 @@ enum ProjectionCandidate<'tcx> {
     /// Bounds specified on an object type
     Object(ty::PolyProjectionPredicate<'tcx>),
 
+    /// Built-in bound for a dyn async fn in trait
+    ObjectRpitit,
+
     /// From an "impl" (or a "pseudo-impl" returned by select)
     Select(Selection<'tcx>),
 }
@@ -827,6 +831,17 @@ fn assemble_candidates_from_object_ty<'cx, 'tcx>(
         env_predicates,
         false,
     );
+
+    // `dyn Trait` automagically project their AFITs to `dyn* Future`.
+    if tcx.is_impl_trait_in_trait(obligation.predicate.def_id)
+        && let Some(out_trait_def_id) = data.principal_def_id()
+        && let rpitit_trait_def_id = tcx.parent(obligation.predicate.def_id)
+        && tcx
+            .supertrait_def_ids(out_trait_def_id)
+            .any(|trait_def_id| trait_def_id == rpitit_trait_def_id)
+    {
+        candidate_set.push_candidate(ProjectionCandidate::ObjectRpitit);
+    }
 }
 
 #[instrument(
@@ -1247,6 +1262,8 @@ fn confirm_candidate<'cx, 'tcx>(
         ProjectionCandidate::Select(impl_source) => {
             confirm_select_candidate(selcx, obligation, impl_source)
         }
+
+        ProjectionCandidate::ObjectRpitit => confirm_object_rpitit_candidate(selcx, obligation),
     };
 
     // When checking for cycle during evaluation, we compare predicates with
@@ -2034,6 +2051,45 @@ fn confirm_impl_candidate<'cx, 'tcx>(
     }
 }
 
+fn confirm_object_rpitit_candidate<'cx, 'tcx>(
+    selcx: &mut SelectionContext<'cx, 'tcx>,
+    obligation: &ProjectionTermObligation<'tcx>,
+) -> Progress<'tcx> {
+    let tcx = selcx.tcx();
+    let mut obligations = thin_vec![];
+
+    // Compute an intersection lifetime for all the input components of this GAT.
+    let intersection =
+        selcx.infcx.next_region_var(RegionVariableOrigin::MiscVariable(obligation.cause.span));
+    for component in obligation.predicate.args {
+        match component.unpack() {
+            ty::GenericArgKind::Lifetime(lt) => {
+                obligations.push(obligation.with(tcx, ty::OutlivesPredicate(lt, intersection)));
+            }
+            ty::GenericArgKind::Type(ty) => {
+                obligations.push(obligation.with(tcx, ty::OutlivesPredicate(ty, intersection)));
+            }
+            ty::GenericArgKind::Const(_ct) => {
+                // Consts have no outlives...
+            }
+        }
+    }
+
+    Progress {
+        term: Ty::new_dynamic(
+            tcx,
+            tcx.item_bounds_to_existential_predicates(
+                obligation.predicate.def_id,
+                obligation.predicate.args,
+            ),
+            intersection,
+            ty::DynStar,
+        )
+        .into(),
+        obligations,
+    }
+}
+
 // Get obligations corresponding to the predicates from the where-clause of the
 // associated type itself.
 fn assoc_ty_own_obligations<'cx, 'tcx>(
diff --git a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
index 2fbe2e1e323..3664121ac4b 100644
--- a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
@@ -19,6 +19,7 @@ use rustc_middle::traits::{BuiltinImplSource, SignatureMismatchData};
 use rustc_middle::ty::{self, GenericArgsRef, ToPolyTraitRef, Ty, TyCtxt, Upcast};
 use rustc_middle::{bug, span_bug};
 use rustc_span::def_id::DefId;
+use rustc_type_ir::elaborate;
 use tracing::{debug, instrument};
 
 use super::SelectionCandidate::{self, *};
@@ -624,6 +625,12 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         for assoc_type in assoc_types {
             let defs: &ty::Generics = tcx.generics_of(assoc_type);
 
+            // When `async_fn_in_dyn_trait` is enabled, we don't need to check the
+            // RPITIT for compatibility, since it's not provided by the user.
+            if tcx.features().async_fn_in_dyn_trait() && tcx.is_impl_trait_in_trait(assoc_type) {
+                continue;
+            }
+
             if !defs.own_params.is_empty() {
                 tcx.dcx().span_delayed_bug(
                     obligation.cause.span,
@@ -1175,6 +1182,33 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
                     ty::ClauseKind::TypeOutlives(outlives).upcast(tcx),
                 ));
 
+                // Require that all AFIT will return something that can be coerced into `dyn*`
+                // -- a shim will be responsible for doing the actual coercion to `dyn*`.
+                if let Some(principal) = data.principal() {
+                    for supertrait in
+                        elaborate::supertraits(tcx, principal.with_self_ty(tcx, source))
+                    {
+                        if tcx.is_trait_alias(supertrait.def_id()) {
+                            continue;
+                        }
+
+                        for &assoc_item in tcx.associated_item_def_ids(supertrait.def_id()) {
+                            if !tcx.is_impl_trait_in_trait(assoc_item) {
+                                continue;
+                            }
+
+                            let pointer_like_goal = pointer_like_goal_for_rpitit(
+                                tcx,
+                                supertrait,
+                                assoc_item,
+                                &obligation.cause,
+                            );
+
+                            nested.push(predicate_to_obligation(pointer_like_goal.upcast(tcx)));
+                        }
+                    }
+                }
+
                 ImplSource::Builtin(BuiltinImplSource::Misc, nested)
             }
 
@@ -1280,3 +1314,43 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         })
     }
 }
+
+/// Compute a goal that some RPITIT (right now, only RPITITs corresponding to Futures)
+/// implements the `PointerLike` trait, which is a requirement for the RPITIT to be
+/// coercible to `dyn* Future`, which is itself a requirement for the RPITIT's parent
+/// trait to be coercible to `dyn Trait`.
+///
+/// We do this given a supertrait's substitutions, and then augment the substitutions
+/// with bound variables to compute the goal universally. Given that `PointerLike` has
+/// no region requirements (at least for the built-in pointer types), this shouldn't
+/// *really* matter, but it is the best choice for soundness.
+fn pointer_like_goal_for_rpitit<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    supertrait: ty::PolyTraitRef<'tcx>,
+    rpitit_item: DefId,
+    cause: &ObligationCause<'tcx>,
+) -> ty::PolyTraitRef<'tcx> {
+    let mut bound_vars = supertrait.bound_vars().to_vec();
+
+    let args = supertrait.skip_binder().args.extend_to(tcx, rpitit_item, |arg, _| match arg.kind {
+        ty::GenericParamDefKind::Lifetime => {
+            let kind = ty::BoundRegionKind::Named(arg.def_id, tcx.item_name(arg.def_id));
+            bound_vars.push(ty::BoundVariableKind::Region(kind));
+            ty::Region::new_bound(tcx, ty::INNERMOST, ty::BoundRegion {
+                var: ty::BoundVar::from_usize(bound_vars.len() - 1),
+                kind,
+            })
+            .into()
+        }
+        ty::GenericParamDefKind::Type { .. } | ty::GenericParamDefKind::Const { .. } => {
+            unreachable!()
+        }
+    });
+
+    ty::Binder::bind_with_vars(
+        ty::TraitRef::new(tcx, tcx.require_lang_item(LangItem::PointerLike, Some(cause.span)), [
+            Ty::new_projection_from_args(tcx, rpitit_item, args),
+        ]),
+        tcx.mk_bound_variable_kinds(&bound_vars),
+    )
+}