about summary refs log tree commit diff
path: root/compiler/rustc_lint/src/map_unit_fn.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_lint/src/map_unit_fn.rs')
-rw-r--r--compiler/rustc_lint/src/map_unit_fn.rs103
1 files changed, 37 insertions, 66 deletions
diff --git a/compiler/rustc_lint/src/map_unit_fn.rs b/compiler/rustc_lint/src/map_unit_fn.rs
index 34210137bde..18a947dc1ee 100644
--- a/compiler/rustc_lint/src/map_unit_fn.rs
+++ b/compiler/rustc_lint/src/map_unit_fn.rs
@@ -1,6 +1,7 @@
-use rustc_hir::{Expr, ExprKind, HirId, Stmt, StmtKind};
-use rustc_middle::ty::{self, Ty};
+use rustc_hir::{Expr, ExprKind, Stmt, StmtKind};
+use rustc_middle::ty::{self};
 use rustc_session::{declare_lint, declare_lint_pass};
+use rustc_span::sym;
 
 use crate::lints::MappingToUnit;
 use crate::{LateContext, LateLintPass, LintContext};
@@ -39,58 +40,43 @@ declare_lint_pass!(MapUnitFn => [MAP_UNIT_FN]);
 
 impl<'tcx> LateLintPass<'tcx> for MapUnitFn {
     fn check_stmt(&mut self, cx: &LateContext<'tcx>, stmt: &Stmt<'_>) {
-        if stmt.span.from_expansion() {
+        let StmtKind::Semi(expr) = stmt.kind else {
             return;
-        }
-
-        if let StmtKind::Semi(expr) = stmt.kind
-            && let ExprKind::MethodCall(path, receiver, args, span) = expr.kind
+        };
+        let ExprKind::MethodCall(path, receiver, [arg], span) = expr.kind else {
+            return;
+        };
+        if path.ident.name != sym::map
+            || stmt.span.from_expansion()
+            || receiver.span.from_expansion()
+            || arg.span.from_expansion()
+            || !is_impl_slice(cx, receiver)
+            || !cx
+                .typeck_results()
+                .type_dependent_def_id(expr.hir_id)
+                .is_some_and(|id| cx.tcx.is_diagnostic_item(sym::IteratorMap, id))
         {
-            if path.ident.name.as_str() == "map" {
-                if receiver.span.from_expansion()
-                    || args.iter().any(|e| e.span.from_expansion())
-                    || !is_impl_slice(cx, receiver)
-                    || !is_diagnostic_name(cx, expr.hir_id, "IteratorMap")
-                {
-                    return;
-                }
-                let arg_ty = cx.typeck_results().expr_ty(&args[0]);
-                let default_span = args[0].span;
-                if let ty::FnDef(id, _) = arg_ty.kind() {
-                    let fn_ty = cx.tcx.fn_sig(id).skip_binder();
-                    let ret_ty = fn_ty.output().skip_binder();
-                    if is_unit_type(ret_ty) {
-                        cx.emit_span_lint(
-                            MAP_UNIT_FN,
-                            span,
-                            MappingToUnit {
-                                function_label: cx.tcx.span_of_impl(*id).unwrap_or(default_span),
-                                argument_label: args[0].span,
-                                map_label: span,
-                                suggestion: path.ident.span,
-                                replace: "for_each".to_string(),
-                            },
-                        )
-                    }
-                } else if let ty::Closure(id, subs) = arg_ty.kind() {
-                    let cl_ty = subs.as_closure().sig();
-                    let ret_ty = cl_ty.output().skip_binder();
-                    if is_unit_type(ret_ty) {
-                        cx.emit_span_lint(
-                            MAP_UNIT_FN,
-                            span,
-                            MappingToUnit {
-                                function_label: cx.tcx.span_of_impl(*id).unwrap_or(default_span),
-                                argument_label: args[0].span,
-                                map_label: span,
-                                suggestion: path.ident.span,
-                                replace: "for_each".to_string(),
-                            },
-                        )
-                    }
-                }
-            }
+            return;
         }
+        let (id, sig) = match *cx.typeck_results().expr_ty(arg).kind() {
+            ty::Closure(id, subs) => (id, subs.as_closure().sig()),
+            ty::FnDef(id, _) => (id, cx.tcx.fn_sig(id).skip_binder()),
+            _ => return,
+        };
+        let ret_ty = sig.output().skip_binder();
+        if !(ret_ty.is_unit() || ret_ty.is_never()) {
+            return;
+        }
+        cx.emit_span_lint(
+            MAP_UNIT_FN,
+            span,
+            MappingToUnit {
+                function_label: cx.tcx.span_of_impl(id).unwrap_or(arg.span),
+                argument_label: arg.span,
+                map_label: span,
+                suggestion: path.ident.span,
+            },
+        );
     }
 }
 
@@ -102,18 +88,3 @@ fn is_impl_slice(cx: &LateContext<'_>, expr: &Expr<'_>) -> bool {
     }
     false
 }
-
-fn is_unit_type(ty: Ty<'_>) -> bool {
-    ty.is_unit() || ty.is_never()
-}
-
-fn is_diagnostic_name(cx: &LateContext<'_>, id: HirId, name: &str) -> bool {
-    if let Some(def_id) = cx.typeck_results().type_dependent_def_id(id)
-        && let Some(item) = cx.tcx.get_diagnostic_name(def_id)
-    {
-        if item.as_str() == name {
-            return true;
-        }
-    }
-    false
-}