about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_hir/src/lang_items.rs1
-rw-r--r--compiler/rustc_middle/src/traits/mod.rs1
-rw-r--r--compiler/rustc_middle/src/traits/util.rs49
-rw-r--r--compiler/rustc_middle/src/ty/print/pretty.rs303
-rw-r--r--compiler/rustc_middle/src/ty/sty.rs2
-rw-r--r--compiler/rustc_span/src/symbol.rs1
6 files changed, 282 insertions, 75 deletions
diff --git a/compiler/rustc_hir/src/lang_items.rs b/compiler/rustc_hir/src/lang_items.rs
index 3037996d48b..05659e976dd 100644
--- a/compiler/rustc_hir/src/lang_items.rs
+++ b/compiler/rustc_hir/src/lang_items.rs
@@ -268,6 +268,7 @@ language_item_table! {
     Future,                  sym::future_trait,        future_trait,               Target::Trait,          GenericRequirement::Exact(0);
     GeneratorState,          sym::generator_state,     gen_state,                  Target::Enum,           GenericRequirement::None;
     Generator,               sym::generator,           gen_trait,                  Target::Trait,          GenericRequirement::Minimum(1);
+    GeneratorReturn,         sym::generator_return,    generator_return,           Target::AssocTy,        GenericRequirement::None;
     Unpin,                   sym::unpin,               unpin_trait,                Target::Trait,          GenericRequirement::None;
     Pin,                     sym::pin,                 pin_type,                   Target::Struct,         GenericRequirement::None;
 
diff --git a/compiler/rustc_middle/src/traits/mod.rs b/compiler/rustc_middle/src/traits/mod.rs
index 49071e7995b..49a64cb246a 100644
--- a/compiler/rustc_middle/src/traits/mod.rs
+++ b/compiler/rustc_middle/src/traits/mod.rs
@@ -7,6 +7,7 @@ pub mod query;
 pub mod select;
 pub mod specialization_graph;
 mod structural_impls;
+pub mod util;
 
 use crate::infer::canonical::Canonical;
 use crate::thir::abstract_const::NotConstEvaluatable;
diff --git a/compiler/rustc_middle/src/traits/util.rs b/compiler/rustc_middle/src/traits/util.rs
new file mode 100644
index 00000000000..3490c688170
--- /dev/null
+++ b/compiler/rustc_middle/src/traits/util.rs
@@ -0,0 +1,49 @@
+use rustc_data_structures::stable_set::FxHashSet;
+
+use crate::ty::{PolyTraitRef, TyCtxt};
+
+/// Given a PolyTraitRef, get the PolyTraitRefs of the trait's (transitive) supertraits.
+///
+/// A simplfied version of the same function at `rustc_infer::traits::util::supertraits`.
+pub fn supertraits<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    trait_ref: PolyTraitRef<'tcx>,
+) -> impl Iterator<Item = PolyTraitRef<'tcx>> {
+    Elaborator { tcx, visited: FxHashSet::from_iter([trait_ref]), stack: vec![trait_ref] }
+}
+
+struct Elaborator<'tcx> {
+    tcx: TyCtxt<'tcx>,
+    visited: FxHashSet<PolyTraitRef<'tcx>>,
+    stack: Vec<PolyTraitRef<'tcx>>,
+}
+
+impl<'tcx> Elaborator<'tcx> {
+    fn elaborate(&mut self, trait_ref: PolyTraitRef<'tcx>) {
+        let supertrait_refs = self
+            .tcx
+            .super_predicates_of(trait_ref.def_id())
+            .predicates
+            .into_iter()
+            .flat_map(|(pred, _)| {
+                pred.subst_supertrait(self.tcx, &trait_ref).to_opt_poly_trait_ref()
+            })
+            .map(|t| t.value)
+            .filter(|supertrait_ref| self.visited.insert(*supertrait_ref));
+
+        self.stack.extend(supertrait_refs);
+    }
+}
+
+impl<'tcx> Iterator for Elaborator<'tcx> {
+    type Item = PolyTraitRef<'tcx>;
+
+    fn next(&mut self) -> Option<PolyTraitRef<'tcx>> {
+        if let Some(trait_ref) = self.stack.pop() {
+            self.elaborate(trait_ref);
+            Some(trait_ref)
+        } else {
+            None
+        }
+    }
+}
diff --git a/compiler/rustc_middle/src/ty/print/pretty.rs b/compiler/rustc_middle/src/ty/print/pretty.rs
index 3846cf19d91..175295b3199 100644
--- a/compiler/rustc_middle/src/ty/print/pretty.rs
+++ b/compiler/rustc_middle/src/ty/print/pretty.rs
@@ -643,81 +643,8 @@ pub trait PrettyPrinter<'tcx>:
                         }
                         return Ok(self);
                     }
-                    // Grab the "TraitA + TraitB" from `impl TraitA + TraitB`,
-                    // by looking up the projections associated with the def_id.
-                    let bounds = self.tcx().explicit_item_bounds(def_id);
-
-                    let mut first = true;
-                    let mut is_sized = false;
-                    let mut is_future = false;
-                    let mut future_output_ty = None;
-
-                    p!("impl");
-                    for (predicate, _) in bounds {
-                        let predicate = predicate.subst(self.tcx(), substs);
-                        let bound_predicate = predicate.kind();
-
-                        match bound_predicate.skip_binder() {
-                            ty::PredicateKind::Projection(projection_predicate) => {
-                                let Some(future_trait) = self.tcx().lang_items().future_trait() else { continue };
-                                let future_output_def_id =
-                                    self.tcx().associated_item_def_ids(future_trait)[0];
-
-                                if projection_predicate.projection_ty.item_def_id
-                                    == future_output_def_id
-                                {
-                                    // We don't account for multiple `Future::Output = Ty` contraints.
-                                    is_future = true;
-                                    future_output_ty = Some(projection_predicate.ty);
-                                }
-                            }
-                            ty::PredicateKind::Trait(pred) => {
-                                let trait_ref = bound_predicate.rebind(pred.trait_ref);
-                                // Don't print +Sized, but rather +?Sized if absent.
-                                if Some(trait_ref.def_id()) == self.tcx().lang_items().sized_trait()
-                                {
-                                    is_sized = true;
-                                    continue;
-                                }
-
-                                if Some(trait_ref.def_id())
-                                    == self.tcx().lang_items().future_trait()
-                                {
-                                    is_future = true;
-                                    continue;
-                                }
-
-                                p!(
-                                    write("{}", if first { " " } else { " + " }),
-                                    print(trait_ref.print_only_trait_path())
-                                );
-
-                                first = false;
-                            }
-                            _ => {}
-                        }
-                    }
-
-                    if is_future {
-                        p!(write("{}Future", if first { " " } else { " + " }));
-                        first = false;
 
-                        if let Some(future_output_ty) = future_output_ty {
-                            // Don't print projection types, which we (unfortunately) see often
-                            // in the error outputs involving async blocks.
-                            if !matches!(future_output_ty.kind(), ty::Projection(_)) {
-                                p!("<Output = ", print(future_output_ty), ">");
-                            }
-                        }
-                    }
-
-                    if !is_sized {
-                        p!(write("{}?Sized", if first { " " } else { " + " }));
-                    } else if first {
-                        p!(" Sized");
-                    }
-
-                    Ok(self)
+                    self.pretty_print_opaque_impl_type(def_id, substs)
                 });
             }
             ty::Str => p!("str"),
@@ -826,6 +753,225 @@ pub trait PrettyPrinter<'tcx>:
         Ok(self)
     }
 
+    fn pretty_print_opaque_impl_type(
+        mut self,
+        def_id: DefId,
+        substs: &'tcx ty::List<ty::GenericArg<'tcx>>,
+    ) -> Result<Self::Type, Self::Error> {
+        define_scoped_cx!(self);
+
+        // Grab the "TraitA + TraitB" from `impl TraitA + TraitB`,
+        // by looking up the projections associated with the def_id.
+        let bounds = self.tcx().explicit_item_bounds(def_id);
+
+        let mut traits = BTreeMap::new();
+        let mut fn_traits = BTreeMap::new();
+        let mut is_sized = false;
+
+        for (predicate, _) in bounds {
+            let predicate = predicate.subst(self.tcx(), substs);
+            let bound_predicate = predicate.kind();
+
+            match bound_predicate.skip_binder() {
+                ty::PredicateKind::Trait(pred) => {
+                    let trait_ref = bound_predicate.rebind(pred.trait_ref);
+
+                    // Don't print + Sized, but rather + ?Sized if absent.
+                    if Some(trait_ref.def_id()) == self.tcx().lang_items().sized_trait() {
+                        is_sized = true;
+                        continue;
+                    }
+
+                    self.insert_trait_and_projection(trait_ref, None, &mut traits, &mut fn_traits);
+                }
+                ty::PredicateKind::Projection(pred) => {
+                    let proj_ref = bound_predicate.rebind(pred);
+                    let trait_ref = proj_ref.required_poly_trait_ref(self.tcx());
+
+                    // Projection type entry -- the def-id for naming, and the ty.
+                    let proj_ty = (proj_ref.projection_def_id(), proj_ref.ty());
+
+                    self.insert_trait_and_projection(
+                        trait_ref,
+                        Some(proj_ty),
+                        &mut traits,
+                        &mut fn_traits,
+                    );
+                }
+                _ => {}
+            }
+        }
+
+        let mut first = true;
+        // Insert parenthesis around (Fn(A, B) -> C) if the opaque ty has more than one other trait
+        let paren_needed = fn_traits.len() > 1 || traits.len() > 0 || !is_sized;
+
+        p!("impl");
+
+        for (fn_once_trait_ref, entry) in fn_traits {
+            // Get the (single) generic ty (the args) of this FnOnce trait ref.
+            let generics = self.generic_args_to_print(
+                self.tcx().generics_of(fn_once_trait_ref.def_id()),
+                fn_once_trait_ref.skip_binder().substs,
+            );
+
+            match (entry.return_ty, generics[0].expect_ty()) {
+                // We can only print `impl Fn() -> ()` if we have a tuple of args and we recorded
+                // a return type.
+                (Some(return_ty), arg_tys) if matches!(arg_tys.kind(), ty::Tuple(_)) => {
+                    let name = if entry.fn_trait_ref.is_some() {
+                        "Fn"
+                    } else if entry.fn_mut_trait_ref.is_some() {
+                        "FnMut"
+                    } else {
+                        "FnOnce"
+                    };
+
+                    p!(
+                        write("{}", if first { " " } else { " + " }),
+                        write("{}{}(", if paren_needed { "(" } else { "" }, name)
+                    );
+
+                    for (idx, ty) in arg_tys.tuple_fields().enumerate() {
+                        if idx > 0 {
+                            p!(", ");
+                        }
+                        p!(print(ty));
+                    }
+
+                    p!(")");
+                    if !return_ty.skip_binder().is_unit() {
+                        p!("-> ", print(return_ty));
+                    }
+                    p!(write("{}", if paren_needed { ")" } else { "" }));
+
+                    first = false;
+                }
+                // If we got here, we can't print as a `impl Fn(A, B) -> C`. Just record the
+                // trait_refs we collected in the OpaqueFnEntry as normal trait refs.
+                _ => {
+                    if entry.has_fn_once {
+                        traits.entry(fn_once_trait_ref).or_default().extend(
+                            // Group the return ty with its def id, if we had one.
+                            entry
+                                .return_ty
+                                .map(|ty| (self.tcx().lang_items().fn_once_output().unwrap(), ty)),
+                        );
+                    }
+                    if let Some(trait_ref) = entry.fn_mut_trait_ref {
+                        traits.entry(trait_ref).or_default();
+                    }
+                    if let Some(trait_ref) = entry.fn_trait_ref {
+                        traits.entry(trait_ref).or_default();
+                    }
+                }
+            }
+        }
+
+        // Print the rest of the trait types (that aren't Fn* family of traits)
+        for (trait_ref, assoc_items) in traits {
+            p!(
+                write("{}", if first { " " } else { " + " }),
+                print(trait_ref.skip_binder().print_only_trait_name())
+            );
+
+            let generics = self.generic_args_to_print(
+                self.tcx().generics_of(trait_ref.def_id()),
+                trait_ref.skip_binder().substs,
+            );
+
+            if !generics.is_empty() || !assoc_items.is_empty() {
+                p!("<");
+                let mut first = true;
+
+                for ty in generics {
+                    if !first {
+                        p!(", ");
+                    }
+                    p!(print(trait_ref.rebind(*ty)));
+                    first = false;
+                }
+
+                for (assoc_item_def_id, ty) in assoc_items {
+                    if !first {
+                        p!(", ");
+                    }
+                    p!(write("{} = ", self.tcx().associated_item(assoc_item_def_id).ident));
+
+                    // Skip printing `<[generator@] as Generator<_>>::Return` from async blocks
+                    match ty.skip_binder().kind() {
+                        ty::Projection(ty::ProjectionTy { item_def_id, .. })
+                            if Some(*item_def_id) == self.tcx().lang_items().generator_return() =>
+                        {
+                            p!("[async output]")
+                        }
+                        _ => {
+                            p!(print(ty))
+                        }
+                    }
+
+                    first = false;
+                }
+
+                p!(">");
+            }
+
+            first = false;
+        }
+
+        if !is_sized {
+            p!(write("{}?Sized", if first { " " } else { " + " }));
+        } else if first {
+            p!(" Sized");
+        }
+
+        Ok(self)
+    }
+
+    /// Insert the trait ref and optionally a projection type associated with it into either the
+    /// traits map or fn_traits map, depending on if the trait is in the Fn* family of traits.
+    fn insert_trait_and_projection(
+        &mut self,
+        trait_ref: ty::PolyTraitRef<'tcx>,
+        proj_ty: Option<(DefId, ty::Binder<'tcx, Ty<'tcx>>)>,
+        traits: &mut BTreeMap<ty::PolyTraitRef<'tcx>, BTreeMap<DefId, ty::Binder<'tcx, Ty<'tcx>>>>,
+        fn_traits: &mut BTreeMap<ty::PolyTraitRef<'tcx>, OpaqueFnEntry<'tcx>>,
+    ) {
+        let trait_def_id = trait_ref.def_id();
+
+        // If our trait_ref is FnOnce or any of its children, project it onto the parent FnOnce
+        // super-trait ref and record it there.
+        if let Some(fn_once_trait) = self.tcx().lang_items().fn_once_trait() {
+            // If we have a FnOnce, then insert it into
+            if trait_def_id == fn_once_trait {
+                let entry = fn_traits.entry(trait_ref).or_default();
+                // Optionally insert the return_ty as well.
+                if let Some((_, ty)) = proj_ty {
+                    entry.return_ty = Some(ty);
+                }
+                entry.has_fn_once = true;
+                return;
+            } else if Some(trait_def_id) == self.tcx().lang_items().fn_mut_trait() {
+                let super_trait_ref = crate::traits::util::supertraits(self.tcx(), trait_ref)
+                    .find(|super_trait_ref| super_trait_ref.def_id() == fn_once_trait)
+                    .unwrap();
+
+                fn_traits.entry(super_trait_ref).or_default().fn_mut_trait_ref = Some(trait_ref);
+                return;
+            } else if Some(trait_def_id) == self.tcx().lang_items().fn_trait() {
+                let super_trait_ref = crate::traits::util::supertraits(self.tcx(), trait_ref)
+                    .find(|super_trait_ref| super_trait_ref.def_id() == fn_once_trait)
+                    .unwrap();
+
+                fn_traits.entry(super_trait_ref).or_default().fn_trait_ref = Some(trait_ref);
+                return;
+            }
+        }
+
+        // Otherwise, just group our traits and projection types.
+        traits.entry(trait_ref).or_default().extend(proj_ty);
+    }
+
     fn pretty_print_bound_var(
         &mut self,
         debruijn: ty::DebruijnIndex,
@@ -2553,3 +2699,12 @@ fn trimmed_def_paths(tcx: TyCtxt<'_>, (): ()) -> FxHashMap<DefId, Symbol> {
 pub fn provide(providers: &mut ty::query::Providers) {
     *providers = ty::query::Providers { trimmed_def_paths, ..*providers };
 }
+
+#[derive(Default)]
+pub struct OpaqueFnEntry<'tcx> {
+    // The trait ref is already stored as a key, so just track if we have it as a real predicate
+    has_fn_once: bool,
+    fn_mut_trait_ref: Option<ty::PolyTraitRef<'tcx>>,
+    fn_trait_ref: Option<ty::PolyTraitRef<'tcx>>,
+    return_ty: Option<ty::Binder<'tcx, Ty<'tcx>>>,
+}
diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs
index 7e054d1e17f..c2b32cd06ea 100644
--- a/compiler/rustc_middle/src/ty/sty.rs
+++ b/compiler/rustc_middle/src/ty/sty.rs
@@ -890,7 +890,7 @@ impl<'tcx> List<ty::Binder<'tcx, ExistentialPredicate<'tcx>>> {
 ///
 /// Trait references also appear in object types like `Foo<U>`, but in
 /// that case the `Self` parameter is absent from the substitutions.
-#[derive(Copy, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
+#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, TyEncodable, TyDecodable)]
 #[derive(HashStable, TypeFoldable)]
 pub struct TraitRef<'tcx> {
     pub def_id: DefId,
diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs
index 97b155d2377..247d69d6ee9 100644
--- a/compiler/rustc_span/src/symbol.rs
+++ b/compiler/rustc_span/src/symbol.rs
@@ -679,6 +679,7 @@ symbols! {
         gen_future,
         gen_kill,
         generator,
+        generator_return,
         generator_state,
         generators,
         generic_arg_infer,