about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/hir/src/diagnostics.rs2
-rw-r--r--crates/hir_def/src/path.rs1
-rw-r--r--crates/hir_expand/src/name.rs3
-rw-r--r--crates/hir_ty/src/diagnostics.rs116
-rw-r--r--crates/hir_ty/src/diagnostics/expr.rs85
-rw-r--r--crates/ide/src/diagnostics.rs3
-rw-r--r--crates/ide/src/diagnostics/fixes.rs33
7 files changed, 224 insertions, 19 deletions
diff --git a/crates/hir/src/diagnostics.rs b/crates/hir/src/diagnostics.rs
index 447faa04f18..5343a036c01 100644
--- a/crates/hir/src/diagnostics.rs
+++ b/crates/hir/src/diagnostics.rs
@@ -5,5 +5,5 @@ pub use hir_expand::diagnostics::{
 };
 pub use hir_ty::diagnostics::{
     IncorrectCase, MismatchedArgCount, MissingFields, MissingMatchArms, MissingOkOrSomeInTailExpr,
-    NoSuchField, RemoveThisSemicolon,
+    NoSuchField, RemoveThisSemicolon, ReplaceFilterMapNextWithFindMap,
 };
diff --git a/crates/hir_def/src/path.rs b/crates/hir_def/src/path.rs
index e34cd7f2f28..84ea09b5387 100644
--- a/crates/hir_def/src/path.rs
+++ b/crates/hir_def/src/path.rs
@@ -304,6 +304,7 @@ pub use hir_expand::name as __name;
 #[macro_export]
 macro_rules! __known_path {
     (core::iter::IntoIterator) => {};
+    (core::iter::Iterator) => {};
     (core::result::Result) => {};
     (core::option::Option) => {};
     (core::ops::Range) => {};
diff --git a/crates/hir_expand/src/name.rs b/crates/hir_expand/src/name.rs
index d692cec145a..c7609e90d98 100644
--- a/crates/hir_expand/src/name.rs
+++ b/crates/hir_expand/src/name.rs
@@ -186,6 +186,9 @@ pub mod known {
         Neg,
         Not,
         Index,
+        // Components of known path (function name)
+        filter_map,
+        next,
         // Builtin macros
         file,
         column,
diff --git a/crates/hir_ty/src/diagnostics.rs b/crates/hir_ty/src/diagnostics.rs
index 247da43f22d..323c5f96308 100644
--- a/crates/hir_ty/src/diagnostics.rs
+++ b/crates/hir_ty/src/diagnostics.rs
@@ -247,7 +247,7 @@ impl Diagnostic for RemoveThisSemicolon {
 
 // Diagnostic: break-outside-of-loop
 //
-// This diagnostic is triggered if `break` keyword is used outside of a loop.
+// This diagnostic is triggered if the `break` keyword is used outside of a loop.
 #[derive(Debug)]
 pub struct BreakOutsideOfLoop {
     pub file: HirFileId,
@@ -271,7 +271,7 @@ impl Diagnostic for BreakOutsideOfLoop {
 
 // Diagnostic: missing-unsafe
 //
-// This diagnostic is triggered if operation marked as `unsafe` is used outside of `unsafe` function or block.
+// This diagnostic is triggered if an operation marked as `unsafe` is used outside of an `unsafe` function or block.
 #[derive(Debug)]
 pub struct MissingUnsafe {
     pub file: HirFileId,
@@ -295,7 +295,7 @@ impl Diagnostic for MissingUnsafe {
 
 // Diagnostic: mismatched-arg-count
 //
-// This diagnostic is triggered if function is invoked with an incorrect amount of arguments.
+// This diagnostic is triggered if a function is invoked with an incorrect amount of arguments.
 #[derive(Debug)]
 pub struct MismatchedArgCount {
     pub file: HirFileId,
@@ -347,7 +347,7 @@ impl fmt::Display for CaseType {
 
 // Diagnostic: incorrect-ident-case
 //
-// This diagnostic is triggered if item name doesn't follow https://doc.rust-lang.org/1.0.0/style/style/naming/README.html[Rust naming convention].
+// This diagnostic is triggered if an item name doesn't follow https://doc.rust-lang.org/1.0.0/style/style/naming/README.html[Rust naming convention].
 #[derive(Debug)]
 pub struct IncorrectCase {
     pub file: HirFileId,
@@ -386,6 +386,31 @@ impl Diagnostic for IncorrectCase {
     }
 }
 
+// Diagnostic: replace-filter-map-next-with-find-map
+//
+// This diagnostic is triggered when `.filter_map(..).next()` is used, rather than the more concise `.find_map(..)`.
+#[derive(Debug)]
+pub struct ReplaceFilterMapNextWithFindMap {
+    pub file: HirFileId,
+    /// This expression is the whole method chain up to and including `.filter_map(..).next()`.
+    pub next_expr: AstPtr<ast::Expr>,
+}
+
+impl Diagnostic for ReplaceFilterMapNextWithFindMap {
+    fn code(&self) -> DiagnosticCode {
+        DiagnosticCode("replace-filter-map-next-with-find-map")
+    }
+    fn message(&self) -> String {
+        "replace filter_map(..).next() with find_map(..)".to_string()
+    }
+    fn display_source(&self) -> InFile<SyntaxNodePtr> {
+        InFile { file_id: self.file, value: self.next_expr.clone().into() }
+    }
+    fn as_any(&self) -> &(dyn Any + Send + 'static) {
+        self
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use base_db::{fixture::WithFixture, FileId, SourceDatabase, SourceDatabaseExt};
@@ -644,4 +669,87 @@ fn foo() { break; }
             "#,
         );
     }
+
+    // Register the required standard library types to make the tests work
+    fn add_filter_map_with_find_next_boilerplate(body: &str) -> String {
+        let prefix = r#"
+        //- /main.rs crate:main deps:core
+        use core::iter::Iterator;
+        use core::option::Option::{self, Some, None};
+        "#;
+        let suffix = r#"
+        //- /core/lib.rs crate:core
+        pub mod option {
+            pub enum Option<T> { Some(T), None }
+        }
+        pub mod iter {
+            pub trait Iterator {
+                type Item;
+                fn filter_map<B, F>(self, f: F) -> FilterMap where F: FnMut(Self::Item) -> Option<B> { FilterMap }
+                fn next(&mut self) -> Option<Self::Item>;
+            }
+            pub struct FilterMap {}
+            impl Iterator for FilterMap {
+                type Item = i32;
+                fn next(&mut self) -> i32 { 7 }
+            }
+        }
+        "#;
+        format!("{}{}{}", prefix, body, suffix)
+    }
+
+    #[test]
+    fn replace_filter_map_next_with_find_map2() {
+        check_diagnostics(&add_filter_map_with_find_next_boilerplate(
+            r#"
+            fn foo() {
+                let m = [1, 2, 3].iter().filter_map(|x| if *x == 2 { Some (4) } else { None }).next();
+                      //^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ replace filter_map(..).next() with find_map(..)
+            }
+        "#,
+        ));
+    }
+
+    #[test]
+    fn replace_filter_map_next_with_find_map_no_diagnostic_without_next() {
+        check_diagnostics(&add_filter_map_with_find_next_boilerplate(
+            r#"
+            fn foo() {
+                let m = [1, 2, 3]
+                    .iter()
+                    .filter_map(|x| if *x == 2 { Some (4) } else { None })
+                    .len();
+            }
+            "#,
+        ));
+    }
+
+    #[test]
+    fn replace_filter_map_next_with_find_map_no_diagnostic_with_intervening_methods() {
+        check_diagnostics(&add_filter_map_with_find_next_boilerplate(
+            r#"
+            fn foo() {
+                let m = [1, 2, 3]
+                    .iter()
+                    .filter_map(|x| if *x == 2 { Some (4) } else { None })
+                    .map(|x| x + 2)
+                    .len();
+            }
+            "#,
+        ));
+    }
+
+    #[test]
+    fn replace_filter_map_next_with_find_map_no_diagnostic_if_not_in_chain() {
+        check_diagnostics(&add_filter_map_with_find_next_boilerplate(
+            r#"
+            fn foo() {
+                let m = [1, 2, 3]
+                    .iter()
+                    .filter_map(|x| if *x == 2 { Some (4) } else { None });
+                let n = m.next();
+            }
+            "#,
+        ));
+    }
 }
diff --git a/crates/hir_ty/src/diagnostics/expr.rs b/crates/hir_ty/src/diagnostics/expr.rs
index 107417c2780..d740b726555 100644
--- a/crates/hir_ty/src/diagnostics/expr.rs
+++ b/crates/hir_ty/src/diagnostics/expr.rs
@@ -2,8 +2,10 @@
 
 use std::sync::Arc;
 
-use hir_def::{expr::Statement, path::path, resolver::HasResolver, AdtId, DefWithBodyId};
-use hir_expand::diagnostics::DiagnosticSink;
+use hir_def::{
+    expr::Statement, path::path, resolver::HasResolver, AdtId, AssocItemId, DefWithBodyId,
+};
+use hir_expand::{diagnostics::DiagnosticSink, name};
 use rustc_hash::FxHashSet;
 use syntax::{ast, AstPtr};
 
@@ -24,6 +26,8 @@ pub(crate) use hir_def::{
     LocalFieldId, VariantId,
 };
 
+use super::ReplaceFilterMapNextWithFindMap;
+
 pub(super) struct ExprValidator<'a, 'b: 'a> {
     owner: DefWithBodyId,
     infer: Arc<InferenceResult>,
@@ -40,6 +44,8 @@ impl<'a, 'b> ExprValidator<'a, 'b> {
     }
 
     pub(super) fn validate_body(&mut self, db: &dyn HirDatabase) {
+        self.check_for_filter_map_next(db);
+
         let body = db.body(self.owner.into());
 
         for (id, expr) in body.exprs.iter() {
@@ -150,20 +156,76 @@ impl<'a, 'b> ExprValidator<'a, 'b> {
         }
     }
 
-    fn validate_call(&mut self, db: &dyn HirDatabase, call_id: ExprId, expr: &Expr) -> Option<()> {
+    fn check_for_filter_map_next(&mut self, db: &dyn HirDatabase) {
+        // Find the FunctionIds for Iterator::filter_map and Iterator::next
+        let iterator_path = path![core::iter::Iterator];
+        let resolver = self.owner.resolver(db.upcast());
+        let iterator_trait_id = match resolver.resolve_known_trait(db.upcast(), &iterator_path) {
+            Some(id) => id,
+            None => return,
+        };
+        let iterator_trait_items = &db.trait_data(iterator_trait_id).items;
+        let filter_map_function_id =
+            match iterator_trait_items.iter().find(|item| item.0 == name![filter_map]) {
+                Some((_, AssocItemId::FunctionId(id))) => id,
+                _ => return,
+            };
+        let next_function_id = match iterator_trait_items.iter().find(|item| item.0 == name![next])
+        {
+            Some((_, AssocItemId::FunctionId(id))) => id,
+            _ => return,
+        };
+
+        // Search function body for instances of .filter_map(..).next()
+        let body = db.body(self.owner.into());
+        let mut prev = None;
+        for (id, expr) in body.exprs.iter() {
+            if let Expr::MethodCall { receiver, .. } = expr {
+                let function_id = match self.infer.method_resolution(id) {
+                    Some(id) => id,
+                    None => continue,
+                };
+
+                if function_id == *filter_map_function_id {
+                    prev = Some(id);
+                    continue;
+                }
+
+                if function_id == *next_function_id {
+                    if let Some(filter_map_id) = prev {
+                        if *receiver == filter_map_id {
+                            let (_, source_map) = db.body_with_source_map(self.owner.into());
+                            if let Ok(next_source_ptr) = source_map.expr_syntax(id) {
+                                self.sink.push(ReplaceFilterMapNextWithFindMap {
+                                    file: next_source_ptr.file_id,
+                                    next_expr: next_source_ptr.value,
+                                });
+                            }
+                        }
+                    }
+                }
+            }
+            prev = None;
+        }
+    }
+
+    fn validate_call(&mut self, db: &dyn HirDatabase, call_id: ExprId, expr: &Expr) {
         // Check that the number of arguments matches the number of parameters.
 
         // FIXME: Due to shortcomings in the current type system implementation, only emit this
         // diagnostic if there are no type mismatches in the containing function.
         if self.infer.type_mismatches.iter().next().is_some() {
-            return None;
+            return;
         }
 
         let is_method_call = matches!(expr, Expr::MethodCall { .. });
         let (sig, args) = match expr {
             Expr::Call { callee, args } => {
                 let callee = &self.infer.type_of_expr[*callee];
-                let sig = callee.callable_sig(db)?;
+                let sig = match callee.callable_sig(db) {
+                    Some(sig) => sig,
+                    None => return,
+                };
                 (sig, args.clone())
             }
             Expr::MethodCall { receiver, args, .. } => {
@@ -175,22 +237,25 @@ impl<'a, 'b> ExprValidator<'a, 'b> {
                     // if the receiver is of unknown type, it's very likely we
                     // don't know enough to correctly resolve the method call.
                     // This is kind of a band-aid for #6975.
-                    return None;
+                    return;
                 }
 
                 // FIXME: note that we erase information about substs here. This
                 // is not right, but, luckily, doesn't matter as we care only
                 // about the number of params
-                let callee = self.infer.method_resolution(call_id)?;
+                let callee = match self.infer.method_resolution(call_id) {
+                    Some(callee) => callee,
+                    None => return,
+                };
                 let sig = db.callable_item_signature(callee.into()).value;
 
                 (sig, args)
             }
-            _ => return None,
+            _ => return,
         };
 
         if sig.is_varargs {
-            return None;
+            return;
         }
 
         let params = sig.params();
@@ -213,8 +278,6 @@ impl<'a, 'b> ExprValidator<'a, 'b> {
                 });
             }
         }
-
-        None
     }
 
     fn validate_match(
diff --git a/crates/ide/src/diagnostics.rs b/crates/ide/src/diagnostics.rs
index b35bc2bae22..8607139ba3c 100644
--- a/crates/ide/src/diagnostics.rs
+++ b/crates/ide/src/diagnostics.rs
@@ -136,6 +136,9 @@ pub(crate) fn diagnostics(
         .on::<hir::diagnostics::IncorrectCase, _>(|d| {
             res.borrow_mut().push(warning_with_fix(d, &sema));
         })
+        .on::<hir::diagnostics::ReplaceFilterMapNextWithFindMap, _>(|d| {
+            res.borrow_mut().push(warning_with_fix(d, &sema));
+        })
         .on::<hir::diagnostics::InactiveCode, _>(|d| {
             // If there's inactive code somewhere in a macro, don't propagate to the call-site.
             if d.display_source().file_id.expansion_info(db).is_some() {
diff --git a/crates/ide/src/diagnostics/fixes.rs b/crates/ide/src/diagnostics/fixes.rs
index 579d5a30889..cbfc66ab3c0 100644
--- a/crates/ide/src/diagnostics/fixes.rs
+++ b/crates/ide/src/diagnostics/fixes.rs
@@ -4,7 +4,7 @@ use hir::{
     db::AstDatabase,
     diagnostics::{
         Diagnostic, IncorrectCase, MissingFields, MissingOkOrSomeInTailExpr, NoSuchField,
-        RemoveThisSemicolon, UnresolvedModule,
+        RemoveThisSemicolon, ReplaceFilterMapNextWithFindMap, UnresolvedModule,
     },
     HasSource, HirDisplay, InFile, Semantics, VariantDef,
 };
@@ -15,8 +15,8 @@ use ide_db::{
 };
 use syntax::{
     algo,
-    ast::{self, edit::IndentLevel, make},
-    AstNode,
+    ast::{self, edit::IndentLevel, make, ArgListOwner},
+    AstNode, TextRange,
 };
 use text_edit::TextEdit;
 
@@ -144,6 +144,33 @@ impl DiagnosticWithFix for IncorrectCase {
     }
 }
 
+impl DiagnosticWithFix for ReplaceFilterMapNextWithFindMap {
+    fn fix(&self, sema: &Semantics<RootDatabase>) -> Option<Fix> {
+        let root = sema.db.parse_or_expand(self.file)?;
+        let next_expr = self.next_expr.to_node(&root);
+        let next_call = ast::MethodCallExpr::cast(next_expr.syntax().clone())?;
+
+        let filter_map_call = ast::MethodCallExpr::cast(next_call.receiver()?.syntax().clone())?;
+        let filter_map_name_range = filter_map_call.name_ref()?.ident_token()?.text_range();
+        let filter_map_args = filter_map_call.arg_list()?;
+
+        let range_to_replace =
+            TextRange::new(filter_map_name_range.start(), next_expr.syntax().text_range().end());
+        let replacement = format!("find_map{}", filter_map_args.syntax().text());
+        let trigger_range = next_expr.syntax().text_range();
+
+        let edit = TextEdit::replace(range_to_replace, replacement);
+
+        let source_change = SourceChange::from_text_edit(self.file.original_file(sema.db), edit);
+
+        Some(Fix::new(
+            "Replace filter_map(..).next() with find_map()",
+            source_change,
+            trigger_range,
+        ))
+    }
+}
+
 fn missing_record_expr_field_fix(
     sema: &Semantics<RootDatabase>,
     usage_file_id: FileId,