about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2022-09-04 22:21:15 +0000
committerMichael Goulet <michael@errs.io>2022-09-08 02:06:48 +0000
commit30e3673d437d7ca049d6080eee19e696c3d7429f (patch)
tree2be19a3d0efde4f4799201313c83153d9bcdcb95
parent0dbbf0f49398d6c74fd3337dd171fac6c7aa3d12 (diff)
downloadrust-30e3673d437d7ca049d6080eee19e696c3d7429f.tar.gz
rust-30e3673d437d7ca049d6080eee19e696c3d7429f.zip
Add associated item binding to non-param-ty where clause suggestions
-rw-r--r--compiler/rustc_middle/src/traits/mod.rs4
-rw-r--r--compiler/rustc_middle/src/ty/diagnostics.rs14
-rw-r--r--compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs24
-rw-r--r--compiler/rustc_typeck/src/check/fn_ctxt/_impl.rs2
-rw-r--r--compiler/rustc_typeck/src/check/method/mod.rs26
-rw-r--r--compiler/rustc_typeck/src/check/op.rs47
-rw-r--r--src/test/ui/traits/resolution-in-overloaded-op.stderr4
7 files changed, 65 insertions, 56 deletions
diff --git a/compiler/rustc_middle/src/traits/mod.rs b/compiler/rustc_middle/src/traits/mod.rs
index ab7e5ba3a10..a56fac7c4dd 100644
--- a/compiler/rustc_middle/src/traits/mod.rs
+++ b/compiler/rustc_middle/src/traits/mod.rs
@@ -12,7 +12,7 @@ pub mod util;
 use crate::infer::canonical::Canonical;
 use crate::ty::abstract_const::NotConstEvaluatable;
 use crate::ty::subst::SubstsRef;
-use crate::ty::{self, AdtKind, Predicate, Ty, TyCtxt};
+use crate::ty::{self, AdtKind, Ty, TyCtxt};
 
 use rustc_data_structures::sync::Lrc;
 use rustc_errors::{Applicability, Diagnostic};
@@ -416,7 +416,7 @@ pub enum ObligationCauseCode<'tcx> {
     BinOp {
         rhs_span: Option<Span>,
         is_lit: bool,
-        output_pred: Option<Predicate<'tcx>>,
+        output_ty: Option<Ty<'tcx>>,
     },
 }
 
diff --git a/compiler/rustc_middle/src/ty/diagnostics.rs b/compiler/rustc_middle/src/ty/diagnostics.rs
index dd2f4321060..fc6a77d400d 100644
--- a/compiler/rustc_middle/src/ty/diagnostics.rs
+++ b/compiler/rustc_middle/src/ty/diagnostics.rs
@@ -102,13 +102,25 @@ pub fn suggest_arbitrary_trait_bound<'tcx>(
     generics: &hir::Generics<'_>,
     err: &mut Diagnostic,
     trait_pred: PolyTraitPredicate<'tcx>,
+    associated_ty: Option<(&'static str, Ty<'tcx>)>,
 ) -> bool {
     if !trait_pred.is_suggestable(tcx, false) {
         return false;
     }
 
     let param_name = trait_pred.skip_binder().self_ty().to_string();
-    let constraint = trait_pred.print_modifiers_and_trait_path().to_string();
+    let mut constraint = trait_pred.print_modifiers_and_trait_path().to_string();
+
+    if let Some((name, term)) = associated_ty {
+        // FIXME: this case overlaps with code in TyCtxt::note_and_explain_type_err.
+        // That should be extracted into a helper function.
+        if constraint.ends_with('>') {
+            constraint = format!("{}, {}={}>", &constraint[..constraint.len() - 1], name, term);
+        } else {
+            constraint.push_str(&format!("<{}={}>", name, term));
+        }
+    }
+
     let param = generics.params.iter().find(|p| p.name.ident().as_str() == param_name);
 
     // Skip, there is a param named Self
diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
index b012073f771..595c68166bc 100644
--- a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
+++ b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
@@ -25,8 +25,7 @@ use rustc_middle::hir::map;
 use rustc_middle::ty::{
     self, suggest_arbitrary_trait_bound, suggest_constraining_type_param, AdtKind, DefIdTree,
     GeneratorDiagnosticData, GeneratorInteriorTypeCause, Infer, InferTy, IsSuggestable,
-    ProjectionPredicate, ToPredicate, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable,
-    TypeVisitable,
+    ToPredicate, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitable,
 };
 use rustc_middle::ty::{TypeAndMut, TypeckResults};
 use rustc_session::Limit;
@@ -174,7 +173,7 @@ pub trait InferCtxtExt<'tcx> {
         &self,
         err: &mut Diagnostic,
         trait_pred: ty::PolyTraitPredicate<'tcx>,
-        proj_pred: Option<ty::PolyProjectionPredicate<'tcx>>,
+        associated_item: Option<(&'static str, Ty<'tcx>)>,
         body_id: hir::HirId,
     );
 
@@ -467,7 +466,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
         &self,
         mut err: &mut Diagnostic,
         trait_pred: ty::PolyTraitPredicate<'tcx>,
-        proj_pred: Option<ty::PolyProjectionPredicate<'tcx>>,
+        associated_ty: Option<(&'static str, Ty<'tcx>)>,
         body_id: hir::HirId,
     ) {
         let trait_pred = self.resolve_numeric_literals_with_default(trait_pred);
@@ -604,21 +603,18 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
                         trait_pred.print_modifiers_and_trait_path().to_string()
                     );
 
-                    if let Some(proj_pred) = proj_pred {
-                        let ProjectionPredicate { projection_ty, term } = proj_pred.skip_binder();
-                        let item = self.tcx.associated_item(projection_ty.item_def_id);
-
+                    if let Some((name, term)) = associated_ty {
                         // FIXME: this case overlaps with code in TyCtxt::note_and_explain_type_err.
                         // That should be extracted into a helper function.
                         if constraint.ends_with('>') {
                             constraint = format!(
                                 "{}, {}={}>",
                                 &constraint[..constraint.len() - 1],
-                                item.name,
+                                name,
                                 term
                             );
                         } else {
-                            constraint.push_str(&format!("<{}={}>", item.name, term));
+                            constraint.push_str(&format!("<{}={}>", name, term));
                         }
                     }
 
@@ -648,7 +644,13 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
                     ..
                 }) if !param_ty => {
                     // Missing generic type parameter bound.
-                    if suggest_arbitrary_trait_bound(self.tcx, generics, &mut err, trait_pred) {
+                    if suggest_arbitrary_trait_bound(
+                        self.tcx,
+                        generics,
+                        &mut err,
+                        trait_pred,
+                        associated_ty,
+                    ) {
                         return;
                     }
                 }
diff --git a/compiler/rustc_typeck/src/check/fn_ctxt/_impl.rs b/compiler/rustc_typeck/src/check/fn_ctxt/_impl.rs
index c59638f5d6f..2196a799fd0 100644
--- a/compiler/rustc_typeck/src/check/fn_ctxt/_impl.rs
+++ b/compiler/rustc_typeck/src/check/fn_ctxt/_impl.rs
@@ -409,7 +409,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                     rhs_span: opt_input_expr.map(|expr| expr.span),
                     is_lit: opt_input_expr
                         .map_or(false, |expr| matches!(expr.kind, ExprKind::Lit(_))),
-                    output_pred: None,
+                    output_ty: None,
                 },
             ),
             self.param_env,
diff --git a/compiler/rustc_typeck/src/check/method/mod.rs b/compiler/rustc_typeck/src/check/method/mod.rs
index c597efbe746..249e9c66ba7 100644
--- a/compiler/rustc_typeck/src/check/method/mod.rs
+++ b/compiler/rustc_typeck/src/check/method/mod.rs
@@ -20,10 +20,7 @@ use rustc_hir::def_id::DefId;
 use rustc_infer::infer::{self, InferOk};
 use rustc_middle::ty::subst::Subst;
 use rustc_middle::ty::subst::{InternalSubsts, SubstsRef};
-use rustc_middle::ty::{
-    self, AssocKind, DefIdTree, GenericParamDefKind, ProjectionPredicate, ProjectionTy,
-    ToPredicate, Ty, TypeVisitable,
-};
+use rustc_middle::ty::{self, DefIdTree, GenericParamDefKind, ToPredicate, Ty, TypeVisitable};
 use rustc_span::symbol::Ident;
 use rustc_span::Span;
 use rustc_trait_selection::traits;
@@ -337,22 +334,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
 
         // Construct an obligation
         let poly_trait_ref = ty::Binder::dummy(trait_ref);
-        let opt_output_ty =
-            expected.only_has_type(self).and_then(|ty| (!ty.needs_infer()).then(|| ty));
-        let opt_output_assoc_item = self.tcx.associated_items(trait_def_id).find_by_name_and_kind(
-            self.tcx,
-            Ident::from_str("Output"),
-            AssocKind::Type,
-            trait_def_id,
-        );
-        let output_pred =
-            opt_output_ty.zip(opt_output_assoc_item).map(|(output_ty, output_assoc_item)| {
-                ty::Binder::dummy(ty::PredicateKind::Projection(ProjectionPredicate {
-                    projection_ty: ProjectionTy { substs, item_def_id: output_assoc_item.def_id },
-                    term: output_ty.into(),
-                }))
-                .to_predicate(self.tcx)
-            });
+        let output_ty = expected.only_has_type(self).and_then(|ty| (!ty.needs_infer()).then(|| ty));
 
         (
             traits::Obligation::new(
@@ -363,7 +345,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                         rhs_span: opt_input_expr.map(|expr| expr.span),
                         is_lit: opt_input_expr
                             .map_or(false, |expr| matches!(expr.kind, hir::ExprKind::Lit(_))),
-                        output_pred,
+                        output_ty,
                     },
                 ),
                 self.param_env,
@@ -518,7 +500,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                     rhs_span: opt_input_expr.map(|expr| expr.span),
                     is_lit: opt_input_expr
                         .map_or(false, |expr| matches!(expr.kind, hir::ExprKind::Lit(_))),
-                    output_pred: None,
+                    output_ty: None,
                 },
             )
         } else {
diff --git a/compiler/rustc_typeck/src/check/op.rs b/compiler/rustc_typeck/src/check/op.rs
index a7e080c13c7..4754717c29a 100644
--- a/compiler/rustc_typeck/src/check/op.rs
+++ b/compiler/rustc_typeck/src/check/op.rs
@@ -12,7 +12,7 @@ use rustc_middle::ty::adjustment::{
     Adjust, Adjustment, AllowTwoPhase, AutoBorrow, AutoBorrowMutability,
 };
 use rustc_middle::ty::print::with_no_trimmed_paths;
-use rustc_middle::ty::{self, Ty, TyCtxt, TypeFolder, TypeSuperFoldable, TypeVisitable};
+use rustc_middle::ty::{self, DefIdTree, Ty, TyCtxt, TypeFolder, TypeSuperFoldable, TypeVisitable};
 use rustc_span::source_map::Spanned;
 use rustc_span::symbol::{sym, Ident};
 use rustc_span::Span;
@@ -310,10 +310,11 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
             // error types are considered "builtin"
             Err(_) if lhs_ty.references_error() || rhs_ty.references_error() => self.tcx.ty_error(),
             Err(errors) => {
-                let (_, item) = lang_item_for_op(self.tcx, Op::Binary(op, is_assign), op.span);
-                let missing_trait =
-                    item.map(|def_id| with_no_trimmed_paths!(self.tcx.def_path_str(def_id)));
-                let (mut err, use_output) = match is_assign {
+                let (_, trait_def_id) =
+                    lang_item_for_op(self.tcx, Op::Binary(op, is_assign), op.span);
+                let missing_trait = trait_def_id
+                    .map(|def_id| with_no_trimmed_paths!(self.tcx.def_path_str(def_id)));
+                let (mut err, output_def_id) = match is_assign {
                     IsAssign::Yes => {
                         let mut err = struct_span_err!(
                             self.tcx.sess,
@@ -328,7 +329,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                             format!("cannot use `{}=` on type `{}`", op.node.as_str(), lhs_ty),
                         );
                         self.note_unmet_impls_on_type(&mut err, errors);
-                        (err, false)
+                        (err, None)
                     }
                     IsAssign::No => {
                         let message = match op.node {
@@ -368,11 +369,14 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                                 lhs_ty
                             ),
                         };
-                        let use_output = item.map_or(false, |def_id| {
-                            self.tcx.associated_item_def_ids(def_id).iter().any(|item_def_id| {
-                                self.tcx.opt_associated_item(*item_def_id).unwrap().name
-                                    == sym::Output
-                            })
+                        let output_def_id = trait_def_id.and_then(|def_id| {
+                            self.tcx
+                                .associated_item_def_ids(def_id)
+                                .iter()
+                                .find(|item_def_id| {
+                                    self.tcx.associated_item(*item_def_id).name == sym::Output
+                                })
+                                .cloned()
                         });
                         let mut err = struct_span_err!(self.tcx.sess, op.span, E0369, "{message}");
                         if !lhs_expr.span.eq(&rhs_expr.span) {
@@ -380,7 +384,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                             err.span_label(rhs_expr.span, rhs_ty.to_string());
                         }
                         self.note_unmet_impls_on_type(&mut err, errors);
-                        (err, use_output)
+                        (err, output_def_id)
                     }
                 };
 
@@ -488,12 +492,21 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                                 if let Some(trait_pred) =
                                     error.obligation.predicate.to_opt_poly_trait_pred()
                                 {
-                                    let proj_pred = match error.obligation.cause.code() {
+                                    let output_associated_item = match error.obligation.cause.code()
+                                    {
                                         ObligationCauseCode::BinOp {
-                                            output_pred: Some(output_pred),
+                                            output_ty: Some(output_ty),
                                             ..
-                                        } if use_output => {
-                                            output_pred.to_opt_poly_projection_pred()
+                                        } => {
+                                            // Make sure that we're attaching `Output = ..` to the right trait predicate
+                                            if let Some(output_def_id) = output_def_id
+                                                && let Some(trait_def_id) = trait_def_id
+                                                && self.tcx.parent(output_def_id) == trait_def_id
+                                            {
+                                                Some(("Output", *output_ty))
+                                            } else {
+                                                None
+                                            }
                                         }
                                         _ => None,
                                     };
@@ -501,7 +514,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                                     self.suggest_restricting_param_bound(
                                         &mut err,
                                         trait_pred,
-                                        proj_pred,
+                                        output_associated_item,
                                         self.body_id,
                                     );
                                 }
diff --git a/src/test/ui/traits/resolution-in-overloaded-op.stderr b/src/test/ui/traits/resolution-in-overloaded-op.stderr
index 34fae64e4d2..b67e334d40a 100644
--- a/src/test/ui/traits/resolution-in-overloaded-op.stderr
+++ b/src/test/ui/traits/resolution-in-overloaded-op.stderr
@@ -8,8 +8,8 @@ LL |     a * b
    |
 help: consider introducing a `where` clause, but there might be an alternative better way to express this requirement
    |
-LL | fn foo<T: MyMul<f64, f64>>(a: &T, b: f64) -> f64 where &T: Mul<f64> {
-   |                                                  ++++++++++++++++++
+LL | fn foo<T: MyMul<f64, f64>>(a: &T, b: f64) -> f64 where &T: Mul<f64, Output=f64> {
+   |                                                  ++++++++++++++++++++++++++++++
 
 error: aborting due to previous error