about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_hir_analysis/src/check/wfcheck.rs2
-rw-r--r--compiler/rustc_middle/src/ty/util.rs22
-rw-r--r--compiler/rustc_trait_selection/src/traits/project.rs36
-rw-r--r--compiler/rustc_ty_utils/src/ty.rs61
-rw-r--r--tests/ui/async-await/in-trait/async-default-fn-overridden.rs66
-rw-r--r--tests/ui/async-await/in-trait/async-default-fn-overridden.stderr11
6 files changed, 161 insertions, 37 deletions
diff --git a/compiler/rustc_hir_analysis/src/check/wfcheck.rs b/compiler/rustc_hir_analysis/src/check/wfcheck.rs
index 66c3904af96..5743f086f89 100644
--- a/compiler/rustc_hir_analysis/src/check/wfcheck.rs
+++ b/compiler/rustc_hir_analysis/src/check/wfcheck.rs
@@ -1599,7 +1599,7 @@ fn check_return_position_impl_trait_in_trait_bounds<'tcx>(
     {
         for arg in fn_output.walk() {
             if let ty::GenericArgKind::Type(ty) = arg.unpack()
-                && let ty::Alias(ty::Projection, proj) = ty.kind()
+                && let ty::Alias(ty::Opaque, proj) = ty.kind()
                 && tcx.def_kind(proj.def_id) == DefKind::ImplTraitPlaceholder
                 && tcx.impl_trait_in_trait_parent(proj.def_id) == fn_def_id.to_def_id()
             {
diff --git a/compiler/rustc_middle/src/ty/util.rs b/compiler/rustc_middle/src/ty/util.rs
index a34ee1a99a1..ca46cf29919 100644
--- a/compiler/rustc_middle/src/ty/util.rs
+++ b/compiler/rustc_middle/src/ty/util.rs
@@ -4,7 +4,7 @@ use crate::middle::codegen_fn_attrs::CodegenFnAttrFlags;
 use crate::mir;
 use crate::ty::layout::IntegerExt;
 use crate::ty::{
-    self, ir::TypeFolder, DefIdTree, FallibleTypeFolder, Ty, TyCtxt, TypeFoldable,
+    self, ir::TypeFolder, DefIdTree, FallibleTypeFolder, ToPredicate, Ty, TyCtxt, TypeFoldable,
     TypeSuperFoldable,
 };
 use crate::ty::{GenericArgKind, SubstsRef};
@@ -865,6 +865,26 @@ impl<'tcx> TypeFolder<TyCtxt<'tcx>> for OpaqueTypeExpander<'tcx> {
         }
         t
     }
+
+    fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
+        if let ty::PredicateKind::Clause(clause) = p.kind().skip_binder()
+            && let ty::Clause::Projection(projection_pred) = clause
+        {
+            p.kind()
+                .rebind(ty::ProjectionPredicate {
+                    projection_ty: projection_pred.projection_ty.fold_with(self),
+                    // Don't fold the term on the RHS of the projection predicate.
+                    // This is because for default trait methods with RPITITs, we
+                    // install a `NormalizesTo(Projection(RPITIT) -> Opaque(RPITIT))`
+                    // predicate, which would trivially cause a cycle when we do
+                    // anything that requires `ParamEnv::with_reveal_all_normalized`.
+                    term: projection_pred.term,
+                })
+                .to_predicate(self.tcx)
+        } else {
+            p.super_fold_with(self)
+        }
+    }
 }
 
 impl<'tcx> Ty<'tcx> {
diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs
index 9b3249e58e8..1c66fb257eb 100644
--- a/compiler/rustc_trait_selection/src/traits/project.rs
+++ b/compiler/rustc_trait_selection/src/traits/project.rs
@@ -90,15 +90,7 @@ enum ProjectionCandidate<'tcx> {
     /// From an "impl" (or a "pseudo-impl" returned by select)
     Select(Selection<'tcx>),
 
-    ImplTraitInTrait(ImplTraitInTraitCandidate<'tcx>),
-}
-
-#[derive(PartialEq, Eq, Debug)]
-enum ImplTraitInTraitCandidate<'tcx> {
-    // The `impl Trait` from a trait function's default body
-    Trait,
-    // A concrete type provided from a trait's `impl Trait` from an impl
-    Impl(ImplSourceUserDefinedData<'tcx, PredicateObligation<'tcx>>),
+    ImplTraitInTrait(ImplSourceUserDefinedData<'tcx, PredicateObligation<'tcx>>),
 }
 
 enum ProjectionCandidateSet<'tcx> {
@@ -1292,17 +1284,6 @@ fn assemble_candidate_for_impl_trait_in_trait<'cx, 'tcx>(
     let tcx = selcx.tcx();
     if tcx.def_kind(obligation.predicate.def_id) == DefKind::ImplTraitPlaceholder {
         let trait_fn_def_id = tcx.impl_trait_in_trait_parent(obligation.predicate.def_id);
-        // If we are trying to project an RPITIT with trait's default `Self` parameter,
-        // then we must be within a default trait body.
-        if obligation.predicate.self_ty()
-            == ty::InternalSubsts::identity_for_item(tcx, obligation.predicate.def_id).type_at(0)
-            && tcx.associated_item(trait_fn_def_id).defaultness(tcx).has_value()
-        {
-            candidate_set.push_candidate(ProjectionCandidate::ImplTraitInTrait(
-                ImplTraitInTraitCandidate::Trait,
-            ));
-            return;
-        }
 
         let trait_def_id = tcx.parent(trait_fn_def_id);
         let trait_substs =
@@ -1313,9 +1294,7 @@ fn assemble_candidate_for_impl_trait_in_trait<'cx, 'tcx>(
         let _ = selcx.infcx.commit_if_ok(|_| {
             match selcx.select(&obligation.with(tcx, trait_predicate)) {
                 Ok(Some(super::ImplSource::UserDefined(data))) => {
-                    candidate_set.push_candidate(ProjectionCandidate::ImplTraitInTrait(
-                        ImplTraitInTraitCandidate::Impl(data),
-                    ));
+                    candidate_set.push_candidate(ProjectionCandidate::ImplTraitInTrait(data));
                     Ok(())
                 }
                 Ok(None) => {
@@ -1777,18 +1756,9 @@ fn confirm_candidate<'cx, 'tcx>(
         ProjectionCandidate::Select(impl_source) => {
             confirm_select_candidate(selcx, obligation, impl_source)
         }
-        ProjectionCandidate::ImplTraitInTrait(ImplTraitInTraitCandidate::Impl(data)) => {
+        ProjectionCandidate::ImplTraitInTrait(data) => {
             confirm_impl_trait_in_trait_candidate(selcx, obligation, data)
         }
-        // If we're projecting an RPITIT for a default trait body, that's just
-        // the same def-id, but as an opaque type (with regular RPIT semantics).
-        ProjectionCandidate::ImplTraitInTrait(ImplTraitInTraitCandidate::Trait) => Progress {
-            term: selcx
-                .tcx()
-                .mk_opaque(obligation.predicate.def_id, obligation.predicate.substs)
-                .into(),
-            obligations: vec![],
-        },
     };
 
     // When checking for cycle during evaluation, we compare predicates with
diff --git a/compiler/rustc_ty_utils/src/ty.rs b/compiler/rustc_ty_utils/src/ty.rs
index 2c50b766d21..f1af0073e4d 100644
--- a/compiler/rustc_ty_utils/src/ty.rs
+++ b/compiler/rustc_ty_utils/src/ty.rs
@@ -1,8 +1,12 @@
-use rustc_data_structures::fx::FxIndexSet;
+use rustc_data_structures::fx::{FxHashSet, FxIndexSet};
 use rustc_hir as hir;
+use rustc_hir::def::DefKind;
 use rustc_index::bit_set::BitSet;
+#[cfg(not(bootstrap))]
+use rustc_middle::ty::ir::TypeVisitable;
 use rustc_middle::ty::{
-    self, Binder, EarlyBinder, Predicate, PredicateKind, ToPredicate, Ty, TyCtxt,
+    self, ir::TypeVisitor, Binder, EarlyBinder, Predicate, PredicateKind, ToPredicate, Ty, TyCtxt,
+    TypeSuperVisitable,
 };
 use rustc_session::config::TraitSolver;
 use rustc_span::def_id::{DefId, CRATE_DEF_ID};
@@ -136,6 +140,19 @@ fn param_env(tcx: TyCtxt<'_>, def_id: DefId) -> ty::ParamEnv<'_> {
         predicates.extend(environment);
     }
 
+    if tcx.def_kind(def_id) == DefKind::AssocFn
+        && tcx.associated_item(def_id).container == ty::AssocItemContainer::TraitContainer
+    {
+        let sig = tcx.fn_sig(def_id).subst_identity();
+        sig.visit_with(&mut ImplTraitInTraitFinder {
+            tcx,
+            fn_def_id: def_id,
+            bound_vars: sig.bound_vars(),
+            predicates: &mut predicates,
+            seen: FxHashSet::default(),
+        });
+    }
+
     let local_did = def_id.as_local();
     let hir_id = local_did.map(|def_id| tcx.hir().local_def_id_to_hir_id(def_id));
 
@@ -222,6 +239,46 @@ fn param_env(tcx: TyCtxt<'_>, def_id: DefId) -> ty::ParamEnv<'_> {
     traits::normalize_param_env_or_error(tcx, unnormalized_env, cause)
 }
 
+/// Walk through a function type, gathering all RPITITs and installing a
+/// `NormalizesTo(Projection(RPITIT) -> Opaque(RPITIT))` predicate into the
+/// predicates list. This allows us to observe that an RPITIT projects to
+/// its corresponding opaque within the body of a default-body trait method.
+struct ImplTraitInTraitFinder<'a, 'tcx> {
+    tcx: TyCtxt<'tcx>,
+    predicates: &'a mut Vec<Predicate<'tcx>>,
+    fn_def_id: DefId,
+    bound_vars: &'tcx ty::List<ty::BoundVariableKind>,
+    seen: FxHashSet<DefId>,
+}
+
+impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for ImplTraitInTraitFinder<'_, 'tcx> {
+    fn visit_ty(&mut self, ty: Ty<'tcx>) -> std::ops::ControlFlow<Self::BreakTy> {
+        if let ty::Alias(ty::Projection, alias_ty) = *ty.kind()
+            && self.tcx.def_kind(alias_ty.def_id) == DefKind::ImplTraitPlaceholder
+            && self.tcx.impl_trait_in_trait_parent(alias_ty.def_id) == self.fn_def_id
+            && self.seen.insert(alias_ty.def_id)
+        {
+            self.predicates.push(
+                ty::Binder::bind_with_vars(
+                    ty::ProjectionPredicate {
+                        projection_ty: alias_ty,
+                        term: self.tcx.mk_alias(ty::Opaque, alias_ty).into(),
+                    },
+                    self.bound_vars,
+                )
+                .to_predicate(self.tcx),
+            );
+
+            for bound in self.tcx.item_bounds(alias_ty.def_id).subst_iter(self.tcx, alias_ty.substs)
+            {
+                bound.visit_with(self);
+            }
+        }
+
+        ty.super_visit_with(self)
+    }
+}
+
 /// Elaborate the environment.
 ///
 /// Collect a list of `Predicate`'s used for building the `ParamEnv`. Adds `TypeWellFormedFromEnv`'s
diff --git a/tests/ui/async-await/in-trait/async-default-fn-overridden.rs b/tests/ui/async-await/in-trait/async-default-fn-overridden.rs
new file mode 100644
index 00000000000..0fd1a2703db
--- /dev/null
+++ b/tests/ui/async-await/in-trait/async-default-fn-overridden.rs
@@ -0,0 +1,66 @@
+// run-pass
+// edition:2021
+
+#![feature(async_fn_in_trait)]
+//~^ WARN the feature `async_fn_in_trait` is incomplete and may not be safe to use
+
+use std::future::Future;
+
+trait AsyncTrait {
+    async fn default_impl() {
+        assert!(false);
+    }
+
+    async fn call_default_impl() {
+        Self::default_impl().await
+    }
+}
+
+struct AsyncType;
+
+impl AsyncTrait for AsyncType {
+    async fn default_impl() {
+        // :)
+    }
+}
+
+async fn async_main() {
+    // Should not assert false
+    AsyncType::call_default_impl().await;
+}
+
+// ------------------------------------------------------------------------- //
+// Implementation Details Below...
+
+use std::pin::Pin;
+use std::task::*;
+
+pub fn noop_waker() -> Waker {
+    let raw = RawWaker::new(std::ptr::null(), &NOOP_WAKER_VTABLE);
+
+    // SAFETY: the contracts for RawWaker and RawWakerVTable are upheld
+    unsafe { Waker::from_raw(raw) }
+}
+
+const NOOP_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);
+
+unsafe fn noop_clone(_p: *const ()) -> RawWaker {
+    RawWaker::new(std::ptr::null(), &NOOP_WAKER_VTABLE)
+}
+
+unsafe fn noop(_p: *const ()) {}
+
+fn main() {
+    let mut fut = async_main();
+
+    // Poll loop, just to test the future...
+    let waker = noop_waker();
+    let ctx = &mut Context::from_waker(&waker);
+
+    loop {
+        match unsafe { Pin::new_unchecked(&mut fut).poll(ctx) } {
+            Poll::Pending => {}
+            Poll::Ready(()) => break,
+        }
+    }
+}
diff --git a/tests/ui/async-await/in-trait/async-default-fn-overridden.stderr b/tests/ui/async-await/in-trait/async-default-fn-overridden.stderr
new file mode 100644
index 00000000000..61a826258d0
--- /dev/null
+++ b/tests/ui/async-await/in-trait/async-default-fn-overridden.stderr
@@ -0,0 +1,11 @@
+warning: the feature `async_fn_in_trait` is incomplete and may not be safe to use and/or cause compiler crashes
+  --> $DIR/async-default-fn-overridden.rs:4:12
+   |
+LL | #![feature(async_fn_in_trait)]
+   |            ^^^^^^^^^^^^^^^^^
+   |
+   = note: see issue #91611 <https://github.com/rust-lang/rust/issues/91611> for more information
+   = note: `#[warn(incomplete_features)]` on by default
+
+warning: 1 warning emitted
+