about summary refs log tree commit diff
diff options
context:
space:
mode:
authorohno418 <yutaro.ono.418@gmail.com>2022-04-04 00:08:54 +0900
committerohno418 <yutaro.ono.418@gmail.com>2022-04-05 11:40:25 +0900
commitde237823e01347a48819d34fe8e0cf130e3e54cb (patch)
tree5375d51c2995fe9ae4715f5a725b95a1adc60f47
parent0ff2f58330a590fcc967b890731d2ebedf6ecb0c (diff)
downloadrust-de237823e01347a48819d34fe8e0cf130e3e54cb.tar.gz
rust-de237823e01347a48819d34fe8e0cf130e3e54cb.zip
Suggest only when all fields impl the trait
-rw-r--r--compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs2
-rw-r--r--compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs66
-rw-r--r--src/test/ui/consts/const-blocks/trait-error.stderr4
-rw-r--r--src/test/ui/kindck/kindck-copy.stderr4
4 files changed, 50 insertions, 26 deletions
diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs
index e196362de7a..0f5f4f3c60f 100644
--- a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs
@@ -536,7 +536,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
                         );
                         self.note_version_mismatch(&mut err, &trait_ref);
                         self.suggest_remove_await(&obligation, &mut err);
-                        self.suggest_derive(&mut err, trait_predicate);
+                        self.suggest_derive(&obligation, &mut err, trait_predicate);
 
                         if Some(trait_ref.def_id()) == tcx.lang_items().try_trait() {
                             self.suggest_await_before_try(
diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
index 64fb352ac45..c2193bbeec4 100644
--- a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
+++ b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
@@ -190,7 +190,12 @@ pub trait InferCtxtExt<'tcx> {
         trait_ref: &ty::PolyTraitRef<'tcx>,
     );
 
-    fn suggest_derive(&self, err: &mut Diagnostic, trait_pred: ty::PolyTraitPredicate<'tcx>);
+    fn suggest_derive(
+        &self,
+        obligation: &PredicateObligation<'tcx>,
+        err: &mut Diagnostic,
+        trait_pred: ty::PolyTraitPredicate<'tcx>,
+    );
 }
 
 fn predicate_constraint(generics: &hir::Generics<'_>, pred: String) -> (Span, String) {
@@ -2592,33 +2597,60 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
         }
     }
 
-    fn suggest_derive(&self, err: &mut Diagnostic, trait_pred: ty::PolyTraitPredicate<'tcx>) {
+    fn suggest_derive(
+        &self,
+        obligation: &PredicateObligation<'tcx>,
+        err: &mut Diagnostic,
+        trait_pred: ty::PolyTraitPredicate<'tcx>,
+    ) {
         let Some(diagnostic_name) = self.tcx.get_diagnostic_name(trait_pred.def_id()) else {
             return;
         };
-        let Some(self_ty) = trait_pred.self_ty().no_bound_vars() else {
-            return;
-        };
-
-        let adt = match self_ty.ty_adt_def() {
-            Some(adt) if adt.did().is_local() => adt,
+        let (adt, substs) = match trait_pred.skip_binder().self_ty().kind() {
+            ty::Adt(adt, substs) if adt.did().is_local() => (adt, substs),
             _ => return,
         };
-        let can_derive = match diagnostic_name {
-            sym::Default => !adt.is_enum(),
-            sym::PartialEq | sym::PartialOrd => {
-                let rhs_ty = trait_pred.skip_binder().trait_ref.substs.type_at(1);
-                self_ty == rhs_ty
-            }
-            sym::Eq | sym::Ord | sym::Clone | sym::Copy | sym::Hash | sym::Debug => true,
-            _ => false,
+        let can_derive = {
+            let is_derivable_trait = match diagnostic_name {
+                sym::Default => !adt.is_enum(),
+                sym::PartialEq | sym::PartialOrd => {
+                    let rhs_ty = trait_pred.skip_binder().trait_ref.substs.type_at(1);
+                    trait_pred.skip_binder().self_ty() == rhs_ty
+                }
+                sym::Eq | sym::Ord | sym::Clone | sym::Copy | sym::Hash | sym::Debug => true,
+                _ => false,
+            };
+            is_derivable_trait &&
+                // Ensure all fields impl the trait.
+                adt.all_fields().all(|field| {
+                    let field_ty = field.ty(self.tcx, substs);
+                    let trait_substs = match diagnostic_name {
+                        sym::PartialEq | sym::PartialOrd => {
+                            self.tcx.mk_substs_trait(field_ty, &[field_ty.into()])
+                        }
+                        _ => self.tcx.mk_substs_trait(field_ty, &[]),
+                    };
+                    let trait_pred = trait_pred.map_bound_ref(|tr| ty::TraitPredicate {
+                        trait_ref: ty::TraitRef {
+                            substs: trait_substs,
+                            ..trait_pred.skip_binder().trait_ref
+                        },
+                        ..*tr
+                    });
+                    let field_obl = Obligation::new(
+                        obligation.cause.clone(),
+                        obligation.param_env,
+                        trait_pred.to_predicate(self.tcx),
+                    );
+                    self.predicate_must_hold_modulo_regions(&field_obl)
+                })
         };
         if can_derive {
             err.span_suggestion_verbose(
                 self.tcx.def_span(adt.did()).shrink_to_lo(),
                 &format!(
                     "consider annotating `{}` with `#[derive({})]`",
-                    trait_pred.skip_binder().self_ty().to_string(),
+                    trait_pred.skip_binder().self_ty(),
                     diagnostic_name.to_string(),
                 ),
                 format!("#[derive({})]\n", diagnostic_name.to_string()),
diff --git a/src/test/ui/consts/const-blocks/trait-error.stderr b/src/test/ui/consts/const-blocks/trait-error.stderr
index b6afbe1b532..26e2848e7f7 100644
--- a/src/test/ui/consts/const-blocks/trait-error.stderr
+++ b/src/test/ui/consts/const-blocks/trait-error.stderr
@@ -7,10 +7,6 @@ LL |     [Foo(String::new()); 4];
    = help: the following implementations were found:
              <Foo<T> as Copy>
    = note: the `Copy` trait is required because the repeated element will be copied
-help: consider annotating `Foo<String>` with `#[derive(Copy)]`
-   |
-LL | #[derive(Copy)]
-   |
 
 error: aborting due to previous error
 
diff --git a/src/test/ui/kindck/kindck-copy.stderr b/src/test/ui/kindck/kindck-copy.stderr
index f909eb6b14e..e147366a224 100644
--- a/src/test/ui/kindck/kindck-copy.stderr
+++ b/src/test/ui/kindck/kindck-copy.stderr
@@ -129,10 +129,6 @@ note: required by a bound in `assert_copy`
    |
 LL | fn assert_copy<T:Copy>() { }
    |                  ^^^^ required by this bound in `assert_copy`
-help: consider annotating `MyNoncopyStruct` with `#[derive(Copy)]`
-   |
-LL | #[derive(Copy)]
-   |
 
 error[E0277]: the trait bound `Rc<isize>: Copy` is not satisfied
   --> $DIR/kindck-copy.rs:67:19