about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_parse/src/parser/item.rs2
-rw-r--r--compiler/rustc_span/src/lev_distance.rs67
-rw-r--r--compiler/rustc_span/src/lev_distance/tests.rs22
-rw-r--r--compiler/rustc_span/src/lib.rs1
-rw-r--r--compiler/rustc_typeck/src/check/method/probe.rs6
5 files changed, 56 insertions, 42 deletions
diff --git a/compiler/rustc_parse/src/parser/item.rs b/compiler/rustc_parse/src/parser/item.rs
index ade441b0e7d..06849b31256 100644
--- a/compiler/rustc_parse/src/parser/item.rs
+++ b/compiler/rustc_parse/src/parser/item.rs
@@ -423,7 +423,7 @@ impl<'a> Parser<'a> {
                 // Maybe the user misspelled `macro_rules` (issue #91227)
                 if self.token.is_ident()
                     && path.segments.len() == 1
-                    && lev_distance("macro_rules", &path.segments[0].ident.to_string()) <= 3
+                    && lev_distance("macro_rules", &path.segments[0].ident.to_string(), 3).is_some()
                 {
                     err.span_suggestion(
                         path.span,
diff --git a/compiler/rustc_span/src/lev_distance.rs b/compiler/rustc_span/src/lev_distance.rs
index 6554312b8b9..93cf965f105 100644
--- a/compiler/rustc_span/src/lev_distance.rs
+++ b/compiler/rustc_span/src/lev_distance.rs
@@ -11,16 +11,21 @@ use std::cmp;
 mod tests;
 
 /// Finds the Levenshtein distance between two strings.
-pub fn lev_distance(a: &str, b: &str) -> usize {
-    // cases which don't require further computation
-    if a.is_empty() {
-        return b.chars().count();
-    } else if b.is_empty() {
-        return a.chars().count();
+///
+/// Returns None if the distance exceeds the limit.
+pub fn lev_distance(a: &str, b: &str, limit: usize) -> Option<usize> {
+    let n = a.chars().count();
+    let m = b.chars().count();
+    let min_dist = if n < m { m - n } else { n - m };
+
+    if min_dist > limit {
+        return None;
+    }
+    if n == 0 || m == 0 {
+        return (min_dist <= limit).then_some(min_dist);
     }
 
-    let mut dcol: Vec<_> = (0..=b.len()).collect();
-    let mut t_last = 0;
+    let mut dcol: Vec<_> = (0..=m).collect();
 
     for (i, sc) in a.chars().enumerate() {
         let mut current = i;
@@ -35,10 +40,10 @@ pub fn lev_distance(a: &str, b: &str) -> usize {
                 dcol[j + 1] = cmp::min(dcol[j + 1], dcol[j]) + 1;
             }
             current = next;
-            t_last = j;
         }
     }
-    dcol[t_last + 1]
+
+    (dcol[m] <= limit).then_some(dcol[m])
 }
 
 /// Finds the best match for a given word in the given iterator.
@@ -51,40 +56,38 @@ pub fn lev_distance(a: &str, b: &str) -> usize {
 /// on an edge case with a lower(upper)case letters mismatch.
 #[cold]
 pub fn find_best_match_for_name(
-    name_vec: &[Symbol],
+    candidates: &[Symbol],
     lookup: Symbol,
     dist: Option<usize>,
 ) -> Option<Symbol> {
     let lookup = lookup.as_str();
     let lookup_uppercase = lookup.to_uppercase();
-    let max_dist = dist.unwrap_or_else(|| cmp::max(lookup.len(), 3) / 3);
 
     // Priority of matches:
     // 1. Exact case insensitive match
     // 2. Levenshtein distance match
     // 3. Sorted word match
-    if let Some(case_insensitive_match) =
-        name_vec.iter().find(|candidate| candidate.as_str().to_uppercase() == lookup_uppercase)
-    {
-        return Some(*case_insensitive_match);
+    if let Some(c) = candidates.iter().find(|c| c.as_str().to_uppercase() == lookup_uppercase) {
+        return Some(*c);
     }
-    let levenshtein_match = name_vec
-        .iter()
-        .filter_map(|&name| {
-            let dist = lev_distance(lookup, name.as_str());
-            if dist <= max_dist { Some((name, dist)) } else { None }
-        })
-        // Here we are collecting the next structure:
-        // (levenshtein_match, levenshtein_distance)
-        .fold(None, |result, (candidate, dist)| match result {
-            None => Some((candidate, dist)),
-            Some((c, d)) => Some(if dist < d { (candidate, dist) } else { (c, d) }),
-        });
-    if levenshtein_match.is_some() {
-        levenshtein_match.map(|(candidate, _)| candidate)
-    } else {
-        find_match_by_sorted_words(name_vec, lookup)
+
+    let mut dist = dist.unwrap_or_else(|| cmp::max(lookup.len(), 3) / 3);
+    let mut best = None;
+    for c in candidates {
+        match lev_distance(lookup, c.as_str(), dist) {
+            Some(0) => return Some(*c),
+            Some(d) => {
+                dist = d - 1;
+                best = Some(*c);
+            }
+            None => {}
+        }
     }
+    if best.is_some() {
+        return best;
+    }
+
+    find_match_by_sorted_words(candidates, lookup)
 }
 
 fn find_match_by_sorted_words(iter_names: &[Symbol], lookup: &str) -> Option<Symbol> {
diff --git a/compiler/rustc_span/src/lev_distance/tests.rs b/compiler/rustc_span/src/lev_distance/tests.rs
index b32f8d32c13..4e34219248d 100644
--- a/compiler/rustc_span/src/lev_distance/tests.rs
+++ b/compiler/rustc_span/src/lev_distance/tests.rs
@@ -5,18 +5,26 @@ fn test_lev_distance() {
     use std::char::{from_u32, MAX};
     // Test bytelength agnosticity
     for c in (0..MAX as u32).filter_map(from_u32).map(|i| i.to_string()) {
-        assert_eq!(lev_distance(&c[..], &c[..]), 0);
+        assert_eq!(lev_distance(&c[..], &c[..], usize::MAX), Some(0));
     }
 
     let a = "\nMäry häd ä little lämb\n\nLittle lämb\n";
     let b = "\nMary häd ä little lämb\n\nLittle lämb\n";
     let c = "Mary häd ä little lämb\n\nLittle lämb\n";
-    assert_eq!(lev_distance(a, b), 1);
-    assert_eq!(lev_distance(b, a), 1);
-    assert_eq!(lev_distance(a, c), 2);
-    assert_eq!(lev_distance(c, a), 2);
-    assert_eq!(lev_distance(b, c), 1);
-    assert_eq!(lev_distance(c, b), 1);
+    assert_eq!(lev_distance(a, b, usize::MAX), Some(1));
+    assert_eq!(lev_distance(b, a, usize::MAX), Some(1));
+    assert_eq!(lev_distance(a, c, usize::MAX), Some(2));
+    assert_eq!(lev_distance(c, a, usize::MAX), Some(2));
+    assert_eq!(lev_distance(b, c, usize::MAX), Some(1));
+    assert_eq!(lev_distance(c, b, usize::MAX), Some(1));
+}
+
+#[test]
+fn test_lev_distance_limit() {
+    assert_eq!(lev_distance("abc", "abcd", 1), Some(1));
+    assert_eq!(lev_distance("abc", "abcd", 0), None);
+    assert_eq!(lev_distance("abc", "xyz", 3), Some(3));
+    assert_eq!(lev_distance("abc", "xyz", 2), None);
 }
 
 #[test]
diff --git a/compiler/rustc_span/src/lib.rs b/compiler/rustc_span/src/lib.rs
index 92360164a01..29c76027c15 100644
--- a/compiler/rustc_span/src/lib.rs
+++ b/compiler/rustc_span/src/lib.rs
@@ -15,6 +15,7 @@
 
 #![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")]
 #![feature(array_windows)]
+#![feature(bool_to_option)]
 #![feature(crate_visibility_modifier)]
 #![feature(if_let_guard)]
 #![feature(negative_impls)]
diff --git a/compiler/rustc_typeck/src/check/method/probe.rs b/compiler/rustc_typeck/src/check/method/probe.rs
index d082b4aac48..3815fd1992b 100644
--- a/compiler/rustc_typeck/src/check/method/probe.rs
+++ b/compiler/rustc_typeck/src/check/method/probe.rs
@@ -1907,8 +1907,10 @@ impl<'a, 'tcx> ProbeContext<'a, 'tcx> {
                         if x.kind.namespace() != Namespace::ValueNS {
                             return false;
                         }
-                        let dist = lev_distance(name.as_str(), x.name.as_str());
-                        dist > 0 && dist <= max_dist
+                        match lev_distance(name.as_str(), x.name.as_str(), max_dist) {
+                            Some(d) => d > 0,
+                            None => false,
+                        }
                     })
                     .copied()
                     .collect()