about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--clippy_lints/src/unit_return_expecting_ord.rs126
1 files changed, 79 insertions, 47 deletions
diff --git a/clippy_lints/src/unit_return_expecting_ord.rs b/clippy_lints/src/unit_return_expecting_ord.rs
index 67ceac92dbc..39f4130afcf 100644
--- a/clippy_lints/src/unit_return_expecting_ord.rs
+++ b/clippy_lints/src/unit_return_expecting_ord.rs
@@ -5,7 +5,7 @@ use rustc_lint::{LateContext, LateLintPass};
 use rustc_middle::ty;
 use rustc_middle::ty::{ClauseKind, GenericPredicates, ProjectionPredicate, TraitPredicate};
 use rustc_session::declare_lint_pass;
-use rustc_span::{BytePos, Span, sym};
+use rustc_span::{BytePos, Span, Symbol, sym};
 
 declare_clippy_lint! {
     /// ### What it does
@@ -36,21 +36,26 @@ declare_clippy_lint! {
 
 declare_lint_pass!(UnitReturnExpectingOrd => [UNIT_RETURN_EXPECTING_ORD]);
 
-fn get_trait_predicates_for_trait_id<'tcx>(
+// For each
+fn get_trait_predicates_for_trait_ids<'tcx>(
     cx: &LateContext<'tcx>,
     generics: GenericPredicates<'tcx>,
-    trait_id: Option<DefId>,
-) -> Vec<TraitPredicate<'tcx>> {
-    let mut preds = Vec::new();
+    trait_ids: &[Option<DefId>], // At least 2 ids
+) -> [Vec<TraitPredicate<'tcx>>; 3] {
+    debug_assert!(trait_ids.len() >= 2);
+    let mut preds = [Vec::new(), Vec::new(), Vec::new()];
     for (pred, _) in generics.predicates {
-        if let ClauseKind::Trait(poly_trait_pred) = pred.kind().skip_binder()
-            && let trait_pred = cx
+        if let ClauseKind::Trait(poly_trait_pred) = pred.kind().skip_binder() {
+            let trait_pred = cx
                 .tcx
-                .instantiate_bound_regions_with_erased(pred.kind().rebind(poly_trait_pred))
-            && let Some(trait_def_id) = trait_id
-            && trait_def_id == trait_pred.trait_ref.def_id
-        {
-            preds.push(trait_pred);
+                .instantiate_bound_regions_with_erased(pred.kind().rebind(poly_trait_pred));
+            for (i, tid) in trait_ids.iter().enumerate() {
+                if let Some(tid) = tid
+                    && *tid == trait_pred.trait_ref.def_id
+                {
+                    preds[i].push(trait_pred);
+                }
+            }
         }
     }
     preds
@@ -74,15 +79,24 @@ fn get_projection_pred<'tcx>(
     })
 }
 
-fn get_args_to_check<'tcx>(cx: &LateContext<'tcx>, expr: &'tcx Expr<'tcx>) -> Vec<(usize, String)> {
+fn get_args_to_check<'tcx>(
+    cx: &LateContext<'tcx>,
+    expr: &'tcx Expr<'tcx>,
+    args_len: usize,
+    fn_mut_trait: DefId,
+    ord_trait: Option<DefId>,
+    partial_ord_trait: Option<DefId>,
+) -> Vec<(usize, Symbol)> {
     let mut args_to_check = Vec::new();
     if let Some(def_id) = cx.typeck_results().type_dependent_def_id(expr.hir_id) {
         let fn_sig = cx.tcx.fn_sig(def_id).instantiate_identity();
         let generics = cx.tcx.predicates_of(def_id);
-        let fn_mut_preds = get_trait_predicates_for_trait_id(cx, generics, cx.tcx.lang_items().fn_mut_trait());
-        let ord_preds = get_trait_predicates_for_trait_id(cx, generics, cx.tcx.get_diagnostic_item(sym::Ord));
-        let partial_ord_preds =
-            get_trait_predicates_for_trait_id(cx, generics, cx.tcx.lang_items().partial_ord_trait());
+        let [fn_mut_preds, ord_preds, partial_ord_preds] =
+            get_trait_predicates_for_trait_ids(cx, generics, &[Some(fn_mut_trait), ord_trait, partial_ord_trait]);
+        if fn_mut_preds.is_empty() {
+            return vec![];
+        }
+
         // Trying to call instantiate_bound_regions_with_erased on fn_sig.inputs() gives the following error
         // The trait `rustc::ty::TypeFoldable<'_>` is not implemented for
         // `&[rustc_middle::ty::Ty<'_>]`
@@ -102,12 +116,18 @@ fn get_args_to_check<'tcx>(cx: &LateContext<'tcx>, expr: &'tcx Expr<'tcx>) -> Ve
                             .iter()
                             .any(|ord| Some(ord.self_ty()) == return_ty_pred.term.as_type())
                         {
-                            args_to_check.push((i, "Ord".to_string()));
+                            args_to_check.push((i, sym::Ord));
+                            if args_to_check.len() == args_len - 1 {
+                                break;
+                            }
                         } else if partial_ord_preds
                             .iter()
                             .any(|pord| pord.self_ty() == return_ty_pred.term.expect_type())
                         {
-                            args_to_check.push((i, "PartialOrd".to_string()));
+                            args_to_check.push((i, sym::PartialOrd));
+                            if args_to_check.len() == args_len - 1 {
+                                break;
+                            }
                         }
                     }
                 }
@@ -142,38 +162,50 @@ fn check_arg<'tcx>(cx: &LateContext<'tcx>, arg: &'tcx Expr<'tcx>) -> Option<(Spa
 
 impl<'tcx> LateLintPass<'tcx> for UnitReturnExpectingOrd {
     fn check_expr(&mut self, cx: &LateContext<'tcx>, expr: &'tcx Expr<'tcx>) {
-        if let ExprKind::MethodCall(_, receiver, args, _) = expr.kind {
-            let arg_indices = get_args_to_check(cx, expr);
+        if let ExprKind::MethodCall(_, receiver, args, _) = expr.kind
+            && args.iter().any(|arg| {
+                matches!(
+                    arg.peel_blocks().peel_borrows().peel_drop_temps().kind,
+                    ExprKind::Path(_) | ExprKind::Closure(_)
+                )
+            })
+            && let Some(fn_mut_trait) = cx.tcx.lang_items().fn_mut_trait()
+        {
+            let ord_trait = cx.tcx.get_diagnostic_item(sym::Ord);
+            let partial_ord_trait = cx.tcx.lang_items().partial_ord_trait();
+            if (ord_trait, partial_ord_trait) == (None, None) {
+                return;
+            }
+
             let args = std::iter::once(receiver).chain(args.iter()).collect::<Vec<_>>();
+            let arg_indices = get_args_to_check(cx, expr, args.len(), fn_mut_trait, ord_trait, partial_ord_trait);
             for (i, trait_name) in arg_indices {
-                if i < args.len() {
-                    match check_arg(cx, args[i]) {
-                        Some((span, None)) => {
-                            span_lint(
-                                cx,
-                                UNIT_RETURN_EXPECTING_ORD,
-                                span,
-                                format!(
-                                    "this closure returns \
+                match check_arg(cx, args[i]) {
+                    Some((span, None)) => {
+                        span_lint(
+                            cx,
+                            UNIT_RETURN_EXPECTING_ORD,
+                            span,
+                            format!(
+                                "this closure returns \
                                    the unit type which also implements {trait_name}"
-                                ),
-                            );
-                        },
-                        Some((span, Some(last_semi))) => {
-                            span_lint_and_help(
-                                cx,
-                                UNIT_RETURN_EXPECTING_ORD,
-                                span,
-                                format!(
-                                    "this closure returns \
+                            ),
+                        );
+                    },
+                    Some((span, Some(last_semi))) => {
+                        span_lint_and_help(
+                            cx,
+                            UNIT_RETURN_EXPECTING_ORD,
+                            span,
+                            format!(
+                                "this closure returns \
                                    the unit type which also implements {trait_name}"
-                                ),
-                                Some(last_semi),
-                                "probably caused by this trailing semicolon",
-                            );
-                        },
-                        None => {},
-                    }
+                            ),
+                            Some(last_semi),
+                            "probably caused by this trailing semicolon",
+                        );
+                    },
+                    None => {},
                 }
             }
         }