about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--clippy_lints/src/derivable_impls.rs64
-rw-r--r--tests/ui/derivable_impls.fixed12
-rw-r--r--tests/ui/derivable_impls.rs15
-rw-r--r--tests/ui/derivable_impls.stderr25
-rw-r--r--tests/ui/derivable_impls_derive_const.fixed2
-rw-r--r--tests/ui/derivable_impls_derive_const.stderr2
6 files changed, 95 insertions, 25 deletions
diff --git a/clippy_lints/src/derivable_impls.rs b/clippy_lints/src/derivable_impls.rs
index c58aca6a52b..06c2393e0a3 100644
--- a/clippy_lints/src/derivable_impls.rs
+++ b/clippy_lints/src/derivable_impls.rs
@@ -10,7 +10,7 @@ use rustc_hir::{
 };
 use rustc_lint::{LateContext, LateLintPass};
 use rustc_middle::ty::adjustment::{Adjust, PointerCoercion};
-use rustc_middle::ty::{self, AdtDef, GenericArgsRef, Ty, TypeckResults};
+use rustc_middle::ty::{self, AdtDef, GenericArgsRef, Ty, TypeckResults, VariantDef};
 use rustc_session::impl_lint_pass;
 use rustc_span::sym;
 
@@ -86,13 +86,9 @@ fn contains_trait_object(ty: Ty<'_>) -> bool {
 }
 
 fn determine_derive_macro(cx: &LateContext<'_>, is_const: bool) -> Option<&'static str> {
-    if is_const {
-        if !cx.tcx.features().enabled(sym::derive_const) {
-            return None;
-        }
-        return Some("derive_const");
-    }
-    Some("derive")
+    (!is_const)
+        .then_some("derive")
+        .or_else(|| cx.tcx.features().enabled(sym::derive_const).then_some("derive_const"))
 }
 
 #[expect(clippy::too_many_arguments)]
@@ -137,18 +133,18 @@ fn check_struct<'tcx>(
         ExprKind::Tup(fields) => fields.iter().all(is_default_without_adjusts),
         ExprKind::Call(callee, args) if is_path_self(callee) => args.iter().all(is_default_without_adjusts),
         ExprKind::Struct(_, fields, _) => fields.iter().all(|ef| is_default_without_adjusts(ef.expr)),
-        _ => false,
-    };
-
-    let Some(derive_snippet) = determine_derive_macro(cx, is_const) else {
-        return;
+        _ => return,
     };
 
-    if should_emit {
+    if should_emit && let Some(derive_snippet) = determine_derive_macro(cx, is_const) {
         let struct_span = cx.tcx.def_span(adt_def.did());
+        let indent_enum = indent_of(cx, struct_span).unwrap_or(0);
         let suggestions = vec![
             (item.span, String::new()), // Remove the manual implementation
-            (struct_span.shrink_to_lo(), format!("#[{derive_snippet}(Default)]\n")), // Add the derive attribute
+            (
+                struct_span.shrink_to_lo(),
+                format!("#[{derive_snippet}(Default)]\n{}", " ".repeat(indent_enum)),
+            ), // Add the derive attribute
         ];
 
         span_lint_and_then(cx, DERIVABLE_IMPLS, item.span, "this `impl` can be derived", |diag| {
@@ -161,17 +157,41 @@ fn check_struct<'tcx>(
     }
 }
 
+fn extract_enum_variant<'tcx>(
+    cx: &LateContext<'tcx>,
+    func_expr: &'tcx Expr<'tcx>,
+    adt_def: AdtDef<'tcx>,
+) -> Option<&'tcx VariantDef> {
+    match &peel_blocks(func_expr).kind {
+        ExprKind::Path(QPath::Resolved(None, p))
+            if let Res::Def(DefKind::Ctor(CtorOf::Variant, CtorKind::Const), id) = p.res
+                && let variant_id = cx.tcx.parent(id)
+                && let Some(variant_def) = adt_def.variants().iter().find(|v| v.def_id == variant_id) =>
+        {
+            Some(variant_def)
+        },
+        ExprKind::Path(QPath::TypeRelative(ty, segment))
+            if let TyKind::Path(QPath::Resolved(None, p)) = &ty.kind
+                && let Res::SelfTyAlias {
+                    is_trait_impl: true, ..
+                } = p.res
+                && let variant_ident = segment.ident
+                && let Some(variant_def) = adt_def.variants().iter().find(|v| v.ident(cx.tcx) == variant_ident) =>
+        {
+            Some(variant_def)
+        },
+        _ => None,
+    }
+}
+
 fn check_enum<'tcx>(
     cx: &LateContext<'tcx>,
-    item: &'tcx Item<'_>,
-    func_expr: &Expr<'_>,
-    adt_def: AdtDef<'_>,
+    item: &'tcx Item<'tcx>,
+    func_expr: &'tcx Expr<'tcx>,
+    adt_def: AdtDef<'tcx>,
     is_const: bool,
 ) {
-    if let ExprKind::Path(QPath::Resolved(None, p)) = &peel_blocks(func_expr).kind
-        && let Res::Def(DefKind::Ctor(CtorOf::Variant, CtorKind::Const), id) = p.res
-        && let variant_id = cx.tcx.parent(id)
-        && let Some(variant_def) = adt_def.variants().iter().find(|v| v.def_id == variant_id)
+    if let Some(variant_def) = extract_enum_variant(cx, func_expr, adt_def)
         && variant_def.fields.is_empty()
         && !variant_def.is_field_list_non_exhaustive()
     {
diff --git a/tests/ui/derivable_impls.fixed b/tests/ui/derivable_impls.fixed
index f549aee9eb1..9f9e4e253c3 100644
--- a/tests/ui/derivable_impls.fixed
+++ b/tests/ui/derivable_impls.fixed
@@ -352,4 +352,16 @@ mod issue15493 {
     }
 }
 
+mod issue15536 {
+    #[derive(Copy, Clone)]
+    #[derive(Default)]
+    enum Bar {
+        #[default]
+        A,
+        B,
+    }
+
+    
+}
+
 fn main() {}
diff --git a/tests/ui/derivable_impls.rs b/tests/ui/derivable_impls.rs
index 1e06ff6120b..74a793b9a70 100644
--- a/tests/ui/derivable_impls.rs
+++ b/tests/ui/derivable_impls.rs
@@ -422,4 +422,19 @@ mod issue15493 {
     }
 }
 
+mod issue15536 {
+    #[derive(Copy, Clone)]
+    enum Bar {
+        A,
+        B,
+    }
+
+    impl Default for Bar {
+        //~^ derivable_impls
+        fn default() -> Self {
+            Self::A
+        }
+    }
+}
+
 fn main() {}
diff --git a/tests/ui/derivable_impls.stderr b/tests/ui/derivable_impls.stderr
index d473f2a379c..cd46414cb4a 100644
--- a/tests/ui/derivable_impls.stderr
+++ b/tests/ui/derivable_impls.stderr
@@ -187,5 +187,28 @@ LL ~     #[default]
 LL ~     Bar,
    |
 
-error: aborting due to 11 previous errors
+error: this `impl` can be derived
+  --> tests/ui/derivable_impls.rs:432:5
+   |
+LL | /     impl Default for Bar {
+LL | |
+LL | |         fn default() -> Self {
+LL | |             Self::A
+LL | |         }
+LL | |     }
+   | |_____^
+   |
+help: replace the manual implementation with a derive attribute and mark the default variant
+   |
+LL ~     #[derive(Default)]
+LL ~     enum Bar {
+LL ~         #[default]
+LL ~         A,
+LL |         B,
+LL |     }
+LL |
+LL ~     
+   |
+
+error: aborting due to 12 previous errors
 
diff --git a/tests/ui/derivable_impls_derive_const.fixed b/tests/ui/derivable_impls_derive_const.fixed
index 6df43f7fb76..f0d8d2d2409 100644
--- a/tests/ui/derivable_impls_derive_const.fixed
+++ b/tests/ui/derivable_impls_derive_const.fixed
@@ -7,7 +7,7 @@ mod issue15493 {
     #[derive(Copy, Clone)]
     #[repr(transparent)]
     #[derive_const(Default)]
-struct Foo(u64);
+    struct Foo(u64);
 
     
 
diff --git a/tests/ui/derivable_impls_derive_const.stderr b/tests/ui/derivable_impls_derive_const.stderr
index dd185676c77..196bac185dd 100644
--- a/tests/ui/derivable_impls_derive_const.stderr
+++ b/tests/ui/derivable_impls_derive_const.stderr
@@ -14,7 +14,7 @@ LL | |     }
 help: replace the manual implementation with a derive attribute
    |
 LL ~     #[derive_const(Default)]
-LL ~ struct Foo(u64);
+LL ~     struct Foo(u64);
 LL |
 LL ~     
    |