about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors[bot] <26634292+bors[bot]@users.noreply.github.com>2022-01-03 16:00:05 +0000
committerGitHub <noreply@github.com>2022-01-03 16:00:05 +0000
commitb14af5cc6fa9bd704b49091fc290892908e773c6 (patch)
treea74c11a4aa10b5c0307a91e0fcf9a8178f063e97
parent1ba9a924d7b161c52e605e157ee16d582e4a8684 (diff)
parentd77d3234ce861106cbf66738aafa4fafc6bf7db6 (diff)
downloadrust-b14af5cc6fa9bd704b49091fc290892908e773c6.tar.gz
rust-b14af5cc6fa9bd704b49091fc290892908e773c6.zip
Merge #11115
11115: internal: refactor: avoid separate traversal in replace filter map next with find map r=Veykril a=rainy-me

fix: #7428

Co-authored-by: rainy-me <github@yue.coffee>
-rw-r--r--crates/hir_ty/src/diagnostics/expr.rs143
1 files changed, 84 insertions, 59 deletions
diff --git a/crates/hir_ty/src/diagnostics/expr.rs b/crates/hir_ty/src/diagnostics/expr.rs
index a8c4026e31f..b7d765c59b4 100644
--- a/crates/hir_ty/src/diagnostics/expr.rs
+++ b/crates/hir_ty/src/diagnostics/expr.rs
@@ -81,9 +81,8 @@ impl ExprValidator {
     }
 
     fn validate_body(&mut self, db: &dyn HirDatabase) {
-        self.check_for_filter_map_next(db);
-
         let body = db.body(self.owner);
+        let mut filter_map_next_checker = None;
 
         for (id, expr) in body.exprs.iter() {
             if let Some((variant, missed_fields, true)) =
@@ -101,7 +100,7 @@ impl ExprValidator {
                     self.validate_match(id, *expr, arms, db, self.infer.clone());
                 }
                 Expr::Call { .. } | Expr::MethodCall { .. } => {
-                    self.validate_call(db, id, expr);
+                    self.validate_call(db, id, expr, &mut filter_map_next_checker);
                 }
                 _ => {}
             }
@@ -143,58 +142,13 @@ impl ExprValidator {
             });
     }
 
-    fn check_for_filter_map_next(&mut self, db: &dyn HirDatabase) {
-        // Find the FunctionIds for Iterator::filter_map and Iterator::next
-        let iterator_path = path![core::iter::Iterator];
-        let resolver = self.owner.resolver(db.upcast());
-        let iterator_trait_id = match resolver.resolve_known_trait(db.upcast(), &iterator_path) {
-            Some(id) => id,
-            None => return,
-        };
-        let iterator_trait_items = &db.trait_data(iterator_trait_id).items;
-        let filter_map_function_id =
-            match iterator_trait_items.iter().find(|item| item.0 == name![filter_map]) {
-                Some((_, AssocItemId::FunctionId(id))) => id,
-                _ => return,
-            };
-        let next_function_id = match iterator_trait_items.iter().find(|item| item.0 == name![next])
-        {
-            Some((_, AssocItemId::FunctionId(id))) => id,
-            _ => return,
-        };
-
-        // Search function body for instances of .filter_map(..).next()
-        let body = db.body(self.owner);
-        let mut prev = None;
-        for (id, expr) in body.exprs.iter() {
-            if let Expr::MethodCall { receiver, .. } = expr {
-                let function_id = match self.infer.method_resolution(id) {
-                    Some((id, _)) => id,
-                    None => continue,
-                };
-
-                if function_id == *filter_map_function_id {
-                    prev = Some(id);
-                    continue;
-                }
-
-                if function_id == *next_function_id {
-                    if let Some(filter_map_id) = prev {
-                        if *receiver == filter_map_id {
-                            self.diagnostics.push(
-                                BodyValidationDiagnostic::ReplaceFilterMapNextWithFindMap {
-                                    method_call_expr: id,
-                                },
-                            );
-                        }
-                    }
-                }
-            }
-            prev = None;
-        }
-    }
-
-    fn validate_call(&mut self, db: &dyn HirDatabase, call_id: ExprId, expr: &Expr) {
+    fn validate_call(
+        &mut self,
+        db: &dyn HirDatabase,
+        call_id: ExprId,
+        expr: &Expr,
+        filter_map_next_checker: &mut Option<FilterMapNextChecker>,
+    ) {
         // Check that the number of arguments matches the number of parameters.
 
         // FIXME: Due to shortcomings in the current type system implementation, only emit this
@@ -214,6 +168,24 @@ impl ExprValidator {
                 (sig, args.len())
             }
             Expr::MethodCall { receiver, args, .. } => {
+                let (callee, subst) = match self.infer.method_resolution(call_id) {
+                    Some(it) => it,
+                    None => return,
+                };
+
+                if filter_map_next_checker
+                    .get_or_insert_with(|| {
+                        FilterMapNextChecker::new(&self.owner.resolver(db.upcast()), db)
+                    })
+                    .check(call_id, receiver, &callee)
+                    .is_some()
+                {
+                    self.diagnostics.push(
+                        BodyValidationDiagnostic::ReplaceFilterMapNextWithFindMap {
+                            method_call_expr: call_id,
+                        },
+                    );
+                }
                 let receiver = &self.infer.type_of_expr[*receiver];
                 if receiver.strip_references().is_unknown() {
                     // if the receiver is of unknown type, it's very likely we
@@ -222,10 +194,6 @@ impl ExprValidator {
                     return;
                 }
 
-                let (callee, subst) = match self.infer.method_resolution(call_id) {
-                    Some(it) => it,
-                    None => return,
-                };
                 let sig = db.callable_item_signature(callee.into()).substitute(Interner, &subst);
 
                 (sig, args.len() + 1)
@@ -424,6 +392,63 @@ impl ExprValidator {
     }
 }
 
+struct FilterMapNextChecker {
+    filter_map_function_id: Option<hir_def::FunctionId>,
+    next_function_id: Option<hir_def::FunctionId>,
+    prev_filter_map_expr_id: Option<ExprId>,
+}
+
+impl FilterMapNextChecker {
+    fn new(resolver: &hir_def::resolver::Resolver, db: &dyn HirDatabase) -> Self {
+        // Find and store the FunctionIds for Iterator::filter_map and Iterator::next
+        let iterator_path = path![core::iter::Iterator];
+        let mut filter_map_function_id = None;
+        let mut next_function_id = None;
+
+        if let Some(iterator_trait_id) = resolver.resolve_known_trait(db.upcast(), &iterator_path) {
+            let iterator_trait_items = &db.trait_data(iterator_trait_id).items;
+            for item in iterator_trait_items.iter() {
+                if let (name, AssocItemId::FunctionId(id)) = item {
+                    if *name == name![filter_map] {
+                        filter_map_function_id = Some(*id);
+                    }
+                    if *name == name![next] {
+                        next_function_id = Some(*id);
+                    }
+                }
+                if filter_map_function_id.is_some() && next_function_id.is_some() {
+                    break;
+                }
+            }
+        }
+        Self { filter_map_function_id, next_function_id, prev_filter_map_expr_id: None }
+    }
+
+    // check for instances of .filter_map(..).next()
+    fn check(
+        &mut self,
+        current_expr_id: ExprId,
+        receiver_expr_id: &ExprId,
+        function_id: &hir_def::FunctionId,
+    ) -> Option<()> {
+        if *function_id == self.filter_map_function_id? {
+            self.prev_filter_map_expr_id = Some(current_expr_id);
+            return None;
+        }
+
+        if *function_id == self.next_function_id? {
+            if let Some(prev_filter_map_expr_id) = self.prev_filter_map_expr_id {
+                if *receiver_expr_id == prev_filter_map_expr_id {
+                    return Some(());
+                }
+            }
+        }
+
+        self.prev_filter_map_expr_id = None;
+        None
+    }
+}
+
 pub fn record_literal_missing_fields(
     db: &dyn HirDatabase,
     infer: &InferenceResult,