about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--clippy_lints/src/manual_let_else.rs71
-rw-r--r--tests/ui/manual_let_else.rs32
-rw-r--r--tests/ui/manual_let_else.stderr22
3 files changed, 119 insertions, 6 deletions
diff --git a/clippy_lints/src/manual_let_else.rs b/clippy_lints/src/manual_let_else.rs
index d6ac6e106b4..0b3bec714c0 100644
--- a/clippy_lints/src/manual_let_else.rs
+++ b/clippy_lints/src/manual_let_else.rs
@@ -4,11 +4,14 @@ use clippy_utils::diagnostics::span_lint_and_then;
 use clippy_utils::higher::IfLetOrMatch;
 use clippy_utils::source::snippet_with_context;
 use clippy_utils::ty::is_type_diagnostic_item;
-use clippy_utils::{is_lint_allowed, is_never_expr, msrvs, pat_and_expr_can_be_question_mark, peel_blocks};
+use clippy_utils::{
+    MaybePath, is_lint_allowed, is_never_expr, is_wild, msrvs, pat_and_expr_can_be_question_mark, path_res, peel_blocks,
+};
 use rustc_ast::BindingMode;
 use rustc_data_structures::fx::FxHashMap;
 use rustc_errors::Applicability;
-use rustc_hir::{Expr, ExprKind, MatchSource, Pat, PatExpr, PatExprKind, PatKind, QPath, Stmt, StmtKind};
+use rustc_hir::def::{CtorOf, DefKind, Res};
+use rustc_hir::{Arm, Expr, ExprKind, HirId, MatchSource, Pat, PatExpr, PatExprKind, PatKind, QPath, Stmt, StmtKind};
 use rustc_lint::{LateContext, LintContext};
 
 use rustc_span::Span;
@@ -91,14 +94,15 @@ impl<'tcx> QuestionMark {
                     let Some((idx, diverging_arm)) = diverging_arm_opt else {
                         return;
                     };
+
+                    let pat_arm = &arms[1 - idx];
                     // If the non-diverging arm is the first one, its pattern can be reused in a let/else statement.
                     // However, if it arrives in second position, its pattern may cover some cases already covered
                     // by the diverging one.
-                    // TODO: accept the non-diverging arm as a second position if patterns are disjointed.
-                    if idx == 0 {
+                    if idx == 0 && !is_arms_disjointed(cx, diverging_arm, pat_arm) {
                         return;
                     }
-                    let pat_arm = &arms[1 - idx];
+
                     let Some(ident_map) = expr_simple_identity_map(local.pat, pat_arm.pat, pat_arm.body) else {
                         return;
                     };
@@ -110,6 +114,63 @@ impl<'tcx> QuestionMark {
     }
 }
 
+/// Checks if the patterns of the arms are disjointed. Currently, we only support patterns of simple
+/// enum variants without nested patterns or bindings.
+///
+/// TODO: Support more complex patterns.
+fn is_arms_disjointed(cx: &LateContext<'_>, arm1: &Arm<'_>, arm2: &Arm<'_>) -> bool {
+    if arm1.guard.is_some() || arm2.guard.is_some() {
+        return false;
+    }
+
+    if !is_enum_variant(cx, arm1.pat) || !is_enum_variant(cx, arm2.pat) {
+        return false;
+    }
+
+    true
+}
+
+/// Returns `true` if the given pattern is a variant of an enum.
+pub fn is_enum_variant(cx: &LateContext<'_>, pat: &Pat<'_>) -> bool {
+    struct Pat<'hir>(&'hir rustc_hir::Pat<'hir>);
+
+    impl<'hir> MaybePath<'hir> for Pat<'hir> {
+        fn qpath_opt(&self) -> Option<&QPath<'hir>> {
+            match self.0.kind {
+                PatKind::Struct(ref qpath, fields, _)
+                    if fields
+                        .iter()
+                        .all(|field| is_wild(field.pat) || matches!(field.pat.kind, PatKind::Binding(..))) =>
+                {
+                    Some(qpath)
+                },
+                PatKind::TupleStruct(ref qpath, pats, _)
+                    if pats
+                        .iter()
+                        .all(|pat| is_wild(pat) || matches!(pat.kind, PatKind::Binding(..))) =>
+                {
+                    Some(qpath)
+                },
+                PatKind::Expr(&PatExpr {
+                    kind: PatExprKind::Path(ref qpath),
+                    ..
+                }) => Some(qpath),
+                _ => None,
+            }
+        }
+
+        fn hir_id(&self) -> HirId {
+            self.0.hir_id
+        }
+    }
+
+    let res = path_res(cx, &Pat(pat));
+    matches!(
+        res,
+        Res::Def(DefKind::Variant, ..) | Res::Def(DefKind::Ctor(CtorOf::Variant, _), _)
+    )
+}
+
 fn emit_manual_let_else(
     cx: &LateContext<'_>,
     span: Span,
diff --git a/tests/ui/manual_let_else.rs b/tests/ui/manual_let_else.rs
index a753566b34c..3781ba1676f 100644
--- a/tests/ui/manual_let_else.rs
+++ b/tests/ui/manual_let_else.rs
@@ -514,3 +514,35 @@ mod issue13768 {
         };
     }
 }
+
+mod issue14598 {
+    fn bar() -> Result<bool, &'static str> {
+        let value = match foo() {
+            //~^ manual_let_else
+            Err(_) => return Err("abc"),
+            Ok(value) => value,
+        };
+
+        let w = Some(0);
+        let v = match w {
+            //~^ manual_let_else
+            None => return Err("abc"),
+            Some(x) => x,
+        };
+
+        enum Foo<T> {
+            Foo(T),
+        }
+
+        let v = match Foo::Foo(Some(())) {
+            Foo::Foo(Some(_)) => return Err("abc"),
+            Foo::Foo(v) => v,
+        };
+
+        Ok(value == 42)
+    }
+
+    fn foo() -> Result<u32, &'static str> {
+        todo!()
+    }
+}
diff --git a/tests/ui/manual_let_else.stderr b/tests/ui/manual_let_else.stderr
index ef042192114..a1eea041929 100644
--- a/tests/ui/manual_let_else.stderr
+++ b/tests/ui/manual_let_else.stderr
@@ -529,5 +529,25 @@ LL +                 return;
 LL +             };
    |
 
-error: aborting due to 33 previous errors
+error: this could be rewritten as `let...else`
+  --> tests/ui/manual_let_else.rs:520:9
+   |
+LL | /         let value = match foo() {
+LL | |
+LL | |             Err(_) => return Err("abc"),
+LL | |             Ok(value) => value,
+LL | |         };
+   | |__________^ help: consider writing: `let Ok(value) = foo() else { return Err("abc") };`
+
+error: this could be rewritten as `let...else`
+  --> tests/ui/manual_let_else.rs:527:9
+   |
+LL | /         let v = match w {
+LL | |
+LL | |             None => return Err("abc"),
+LL | |             Some(x) => x,
+LL | |         };
+   | |__________^ help: consider writing: `let Some(v) = w else { return Err("abc") };`
+
+error: aborting due to 35 previous errors