about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_ast/src/ast.rs10
-rw-r--r--compiler/rustc_resolve/src/late/diagnostics.rs24
2 files changed, 26 insertions, 8 deletions
diff --git a/compiler/rustc_ast/src/ast.rs b/compiler/rustc_ast/src/ast.rs
index 8d9d4123c79..245353c2e07 100644
--- a/compiler/rustc_ast/src/ast.rs
+++ b/compiler/rustc_ast/src/ast.rs
@@ -1884,6 +1884,16 @@ impl Clone for Ty {
     }
 }
 
+impl Ty {
+    pub fn peel_refs(&self) -> &Self {
+        let mut final_ty = self;
+        while let TyKind::Rptr(_, MutTy { ty, .. }) = &final_ty.kind {
+            final_ty = &ty;
+        }
+        final_ty
+    }
+}
+
 #[derive(Clone, Encodable, Decodable, Debug)]
 pub struct BareFnTy {
     pub unsafety: Unsafe,
diff --git a/compiler/rustc_resolve/src/late/diagnostics.rs b/compiler/rustc_resolve/src/late/diagnostics.rs
index b983656c423..bee05e77382 100644
--- a/compiler/rustc_resolve/src/late/diagnostics.rs
+++ b/compiler/rustc_resolve/src/late/diagnostics.rs
@@ -442,7 +442,11 @@ impl<'a: 'ast, 'ast> LateResolutionVisitor<'a, '_, 'ast> {
 
         if !self.type_ascription_suggestion(&mut err, base_span) {
             let mut fallback = false;
-            if let PathSource::Trait(AliasPossibility::Maybe) = source {
+            if let (
+                PathSource::Trait(AliasPossibility::Maybe),
+                Some(Res::Def(DefKind::Struct | DefKind::Enum | DefKind::Union, _)),
+            ) = (source, res)
+            {
                 if let Some(bounds @ [_, .., _]) = self.diagnostic_metadata.current_trait_object {
                     fallback = true;
                     let spans: Vec<Span> = bounds
@@ -580,7 +584,7 @@ impl<'a: 'ast, 'ast> LateResolutionVisitor<'a, '_, 'ast> {
             return false;
         };
 
-        if let ast::TyKind::Path(None, type_param_path) = &ty.kind {
+        if let ast::TyKind::Path(None, type_param_path) = &ty.peel_refs().kind {
             // Confirm that the `SelfTy` is a type parameter.
             let partial_res = if let Ok(Some(partial_res)) = self.resolve_qpath_anywhere(
                 bounded_ty.id,
@@ -603,20 +607,24 @@ impl<'a: 'ast, 'ast> LateResolutionVisitor<'a, '_, 'ast> {
                 return false;
             }
             if let (
-                [ast::PathSegment { ident, args: None, .. }],
+                [ast::PathSegment { ident: constrain_ident, args: None, .. }],
                 [ast::GenericBound::Trait(poly_trait_ref, ast::TraitBoundModifier::None)],
             ) = (&type_param_path.segments[..], &bounds[..])
             {
-                if let [ast::PathSegment { ident: bound_ident, args: None, .. }] =
+                if let [ast::PathSegment { ident, args: None, .. }] =
                     &poly_trait_ref.trait_ref.path.segments[..]
                 {
-                    if bound_ident.span == span {
+                    if ident.span == span {
                         err.span_suggestion_verbose(
                             *where_span,
-                            &format!("constrain the associated type to `{}`", bound_ident),
+                            &format!("constrain the associated type to `{}`", ident),
                             format!(
                                 "{}: {}<{} = {}>",
-                                ident,
+                                self.r
+                                    .session
+                                    .source_map()
+                                    .span_to_snippet(ty.span) // Account for `<&'a T as Foo>::Bar`.
+                                    .unwrap_or_else(|_| constrain_ident.to_string()),
                                 path.segments[..*position]
                                     .iter()
                                     .map(|segment| path_segment_to_string(segment))
@@ -627,7 +635,7 @@ impl<'a: 'ast, 'ast> LateResolutionVisitor<'a, '_, 'ast> {
                                     .map(|segment| path_segment_to_string(segment))
                                     .collect::<Vec<_>>()
                                     .join("::"),
-                                bound_ident,
+                                ident,
                             ),
                             Applicability::MaybeIncorrect,
                         );