about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_trait_selection/src/solve/trait_goals.rs23
-rw-r--r--compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs95
-rw-r--r--compiler/rustc_trait_selection/src/traits/select/mod.rs13
-rw-r--r--tests/ui/traits/trait-upcasting/add-supertrait-auto-traits.rs14
4 files changed, 93 insertions, 52 deletions
diff --git a/compiler/rustc_trait_selection/src/solve/trait_goals.rs b/compiler/rustc_trait_selection/src/solve/trait_goals.rs
index eacdd9fde51..73bf66f6689 100644
--- a/compiler/rustc_trait_selection/src/solve/trait_goals.rs
+++ b/compiler/rustc_trait_selection/src/solve/trait_goals.rs
@@ -1,7 +1,10 @@
 //! Dealing with trait goals, i.e. `T: Trait<'a, U>`.
 
+use crate::traits::supertrait_def_ids;
+
 use super::assembly::{self, structural_traits, Candidate};
 use super::{EvalCtxt, GoalSource, SolverMode};
+use rustc_data_structures::fx::FxIndexSet;
 use rustc_hir::def_id::DefId;
 use rustc_hir::{LangItem, Movability};
 use rustc_infer::traits::query::NoSolution;
@@ -663,13 +666,6 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
         let tcx = self.tcx();
         let Goal { predicate: (a_ty, _b_ty), .. } = goal;
 
-        // All of a's auto traits need to be in b's auto traits.
-        let auto_traits_compatible =
-            b_data.auto_traits().all(|b| a_data.auto_traits().any(|a| a == b));
-        if !auto_traits_compatible {
-            return vec![];
-        }
-
         let mut responses = vec![];
         // If the principal def ids match (or are both none), then we're not doing
         // trait upcasting. We're just removing auto traits (or shortening the lifetime).
@@ -757,6 +753,17 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
     ) -> QueryResult<'tcx> {
         let param_env = goal.param_env;
 
+        // We may upcast to auto traits that are either explicitly listed in
+        // the object type's bounds, or implied by the principal trait ref's
+        // supertraits.
+        let a_auto_traits: FxIndexSet<DefId> = a_data
+            .auto_traits()
+            .chain(a_data.principal_def_id().into_iter().flat_map(|principal_def_id| {
+                supertrait_def_ids(self.tcx(), principal_def_id)
+                    .filter(|def_id| self.tcx().trait_is_auto(*def_id))
+            }))
+            .collect();
+
         // More than one projection in a_ty's bounds may match the projection
         // in b_ty's bound. Use this to first determine *which* apply without
         // having any inference side-effects. We process obligations because
@@ -806,7 +813,7 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
                 }
                 // Check that b_ty's auto traits are present in a_ty's bounds.
                 ty::ExistentialPredicate::AutoTrait(def_id) => {
-                    if !a_data.auto_traits().any(|source_def_id| source_def_id == def_id) {
+                    if !a_auto_traits.contains(&def_id) {
                         return Err(NoSolution);
                     }
                 }
diff --git a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs
index 149dc4c75a7..39f4ceda9f1 100644
--- a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs
@@ -10,7 +10,7 @@ use std::ops::ControlFlow;
 
 use hir::def_id::DefId;
 use hir::LangItem;
-use rustc_data_structures::fx::FxHashSet;
+use rustc_data_structures::fx::{FxHashSet, FxIndexSet};
 use rustc_hir as hir;
 use rustc_infer::traits::ObligationCause;
 use rustc_infer::traits::{Obligation, PolyTraitObligation, SelectionError};
@@ -968,52 +968,61 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
                 //
                 // We always perform upcasting coercions when we can because of reason
                 // #2 (region bounds).
-                let auto_traits_compatible = b_data
-                    .auto_traits()
-                    // All of a's auto traits need to be in b's auto traits.
-                    .all(|b| a_data.auto_traits().any(|a| a == b));
-                if auto_traits_compatible {
-                    let principal_def_id_a = a_data.principal_def_id();
-                    let principal_def_id_b = b_data.principal_def_id();
-                    if principal_def_id_a == principal_def_id_b {
-                        // no cyclic
+                let principal_def_id_a = a_data.principal_def_id();
+                let principal_def_id_b = b_data.principal_def_id();
+                if principal_def_id_a == principal_def_id_b {
+                    // We may upcast to auto traits that are either explicitly listed in
+                    // the object type's bounds, or implied by the principal trait ref's
+                    // supertraits.
+                    let a_auto_traits: FxIndexSet<DefId> = a_data
+                        .auto_traits()
+                        .chain(principal_def_id_a.into_iter().flat_map(|principal_def_id| {
+                            util::supertrait_def_ids(self.tcx(), principal_def_id)
+                                .filter(|def_id| self.tcx().trait_is_auto(*def_id))
+                        }))
+                        .collect();
+                    let auto_traits_compatible = b_data
+                        .auto_traits()
+                        // All of a's auto traits need to be in b's auto traits.
+                        .all(|b| a_auto_traits.contains(&b));
+                    if auto_traits_compatible {
                         candidates.vec.push(BuiltinUnsizeCandidate);
-                    } else if principal_def_id_a.is_some() && principal_def_id_b.is_some() {
-                        // not casual unsizing, now check whether this is trait upcasting coercion.
-                        let principal_a = a_data.principal().unwrap();
-                        let target_trait_did = principal_def_id_b.unwrap();
-                        let source_trait_ref = principal_a.with_self_ty(self.tcx(), source);
-                        if let Some(deref_trait_ref) = self.need_migrate_deref_output_trait_object(
-                            source,
-                            obligation.param_env,
-                            &obligation.cause,
-                        ) {
-                            if deref_trait_ref.def_id() == target_trait_did {
-                                return;
-                            }
+                    }
+                } else if principal_def_id_a.is_some() && principal_def_id_b.is_some() {
+                    // not casual unsizing, now check whether this is trait upcasting coercion.
+                    let principal_a = a_data.principal().unwrap();
+                    let target_trait_did = principal_def_id_b.unwrap();
+                    let source_trait_ref = principal_a.with_self_ty(self.tcx(), source);
+                    if let Some(deref_trait_ref) = self.need_migrate_deref_output_trait_object(
+                        source,
+                        obligation.param_env,
+                        &obligation.cause,
+                    ) {
+                        if deref_trait_ref.def_id() == target_trait_did {
+                            return;
                         }
+                    }
 
-                        for (idx, upcast_trait_ref) in
-                            util::supertraits(self.tcx(), source_trait_ref).enumerate()
-                        {
-                            self.infcx.probe(|_| {
-                                if upcast_trait_ref.def_id() == target_trait_did
-                                    && let Ok(nested) = self.match_upcast_principal(
-                                        obligation,
-                                        upcast_trait_ref,
-                                        a_data,
-                                        b_data,
-                                        a_region,
-                                        b_region,
-                                    )
-                                {
-                                    if nested.is_none() {
-                                        candidates.ambiguous = true;
-                                    }
-                                    candidates.vec.push(TraitUpcastingUnsizeCandidate(idx));
+                    for (idx, upcast_trait_ref) in
+                        util::supertraits(self.tcx(), source_trait_ref).enumerate()
+                    {
+                        self.infcx.probe(|_| {
+                            if upcast_trait_ref.def_id() == target_trait_did
+                                && let Ok(nested) = self.match_upcast_principal(
+                                    obligation,
+                                    upcast_trait_ref,
+                                    a_data,
+                                    b_data,
+                                    a_region,
+                                    b_region,
+                                )
+                            {
+                                if nested.is_none() {
+                                    candidates.ambiguous = true;
                                 }
-                            })
-                        }
+                                candidates.vec.push(TraitUpcastingUnsizeCandidate(idx));
+                            }
+                        })
                     }
                 }
             }
diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs
index ab0c53e6a9b..5bcf46a96ed 100644
--- a/compiler/rustc_trait_selection/src/traits/select/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs
@@ -2526,6 +2526,17 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
         let tcx = self.tcx();
         let mut nested = vec![];
 
+        // We may upcast to auto traits that are either explicitly listed in
+        // the object type's bounds, or implied by the principal trait ref's
+        // supertraits.
+        let a_auto_traits: FxIndexSet<DefId> = a_data
+            .auto_traits()
+            .chain(a_data.principal_def_id().into_iter().flat_map(|principal_def_id| {
+                util::supertrait_def_ids(tcx, principal_def_id)
+                    .filter(|def_id| tcx.trait_is_auto(*def_id))
+            }))
+            .collect();
+
         let upcast_principal = normalize_with_depth_to(
             self,
             obligation.param_env,
@@ -2588,7 +2599,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
                 }
                 // Check that b_ty's auto traits are present in a_ty's bounds.
                 ty::ExistentialPredicate::AutoTrait(def_id) => {
-                    if !a_data.auto_traits().any(|source_def_id| source_def_id == def_id) {
+                    if !a_auto_traits.contains(&def_id) {
                         return Err(SelectionError::Unimplemented);
                     }
                 }
diff --git a/tests/ui/traits/trait-upcasting/add-supertrait-auto-traits.rs b/tests/ui/traits/trait-upcasting/add-supertrait-auto-traits.rs
new file mode 100644
index 00000000000..7e242ed9126
--- /dev/null
+++ b/tests/ui/traits/trait-upcasting/add-supertrait-auto-traits.rs
@@ -0,0 +1,14 @@
+// check-pass
+// revisions: current next
+//[next] compile-flags: -Znext-solver
+
+#![feature(trait_upcasting)]
+
+trait Target {}
+trait Source: Send + Target {}
+
+fn upcast(x: &dyn Source) -> &(dyn Target + Send) { x }
+
+fn same(x: &dyn Source) -> &(dyn Source + Send) { x }
+
+fn main() {}