about summary refs log tree commit diff
path: root/compiler/rustc_span/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_span/src')
-rw-r--r-- compiler/rustc_span/src/lev_distance.rs 104
-rw-r--r-- compiler/rustc_span/src/lev_distance/tests.rs 56
-rw-r--r-- compiler/rustc_span/src/lib.rs 1
3 files changed, 161 insertions, 0 deletions
diff --git a/compiler/rustc_span/src/lev_distance.rs b/compiler/rustc_span/src/lev_distance.rs
new file mode 100644
index 00000000000..edc6625a6ea
--- /dev/null
+++ b/compiler/rustc_span/src/lev_distance.rs
@@ -0,0 +1,104 @@
+use crate::symbol::Symbol;
+use std::cmp;
+
+#[cfg(test)]
+mod tests;
+
+/// Computes the Levenshtein (edit) distance between `a` and `b`, measured in `char`s.
+///
+/// Only a single row of the dynamic-programming table is kept and updated in place.
+pub fn lev_distance(a: &str, b: &str) -> usize {
+    // When either side is empty, the distance is the `char` count of the other side.
+    match (a.is_empty(), b.is_empty()) {
+        (true, _) => return b.chars().count(),
+        (_, true) => return a.chars().count(),
+        _ => {}
+    }
+
+    // `row[k]` is the distance from the consumed prefix of `a` to the first `k` chars of `b`.
+    // The byte length of `b` is used for sizing, which is never less than its `char` count.
+    let mut row: Vec<usize> = (0..=b.len()).collect();
+    let mut last_col = 0;
+
+    for (consumed, ch_a) in a.chars().enumerate() {
+        // `diag` is the previous row's value one column to the left of the cell being filled.
+        let mut diag = consumed;
+        row[0] = consumed + 1;
+        for (col, ch_b) in b.chars().enumerate() {
+            let above = row[col + 1];
+            row[col + 1] =
+                if ch_a == ch_b { diag } else { cmp::min(cmp::min(diag, above), row[col]) + 1 };
+            diag = above;
+            last_col = col;
+        }
+    }
+    row[last_col + 1]
+}
+
+/// Finds the best match for `lookup` among the candidate names in `name_vec`.
+///
+/// As a loose rule to avoid the obviously incorrect suggestions, it takes an optional limit
+/// (`dist`) for the maximum allowable edit distance, which defaults to one-third of the byte
+/// length of `lookup`, but never less than 1.
+///
+/// Besides Levenshtein, we use case insensitive comparison to improve accuracy on an edge case with
+/// a lower(upper)case letters mismatch; it only applies to candidates within the distance limit.
+/// If no candidate is within the limit, a candidate made of the same `_`-separated words, in
+/// any order, is returned if there is one.
+#[cold]
+pub fn find_best_match_for_name(
+    name_vec: &[Symbol],
+    lookup: Symbol,
+    dist: Option<usize>,
+) -> Option<Symbol> {
+    let lookup = &lookup.as_str();
+    let max_dist = dist.unwrap_or_else(|| cmp::max(lookup.len(), 3) / 3);
+    // Uppercase the lookup once up front rather than once per candidate.
+    let lookup_uppercase = lookup.to_uppercase();
+
+    let (case_insensitive_match, levenshtein_match) = name_vec
+        .iter()
+        .filter_map(|&name| {
+            let dist = lev_distance(lookup, &name.as_str());
+            if dist <= max_dist { Some((name, dist)) } else { None }
+        })
+        // Here we are collecting the next structure:
+        // (case_insensitive_match, (levenshtein_match, levenshtein_distance))
+        .fold((None, None), |result, (candidate, dist)| {
+            (
+                if candidate.as_str().to_uppercase() == lookup_uppercase {
+                    Some(candidate)
+                } else {
+                    result.0
+                },
+                match result.1 {
+                    None => Some((candidate, dist)),
+                    Some((c, d)) => Some(if dist < d { (candidate, dist) } else { (c, d) }),
+                },
+            )
+        });
+    // Priority of matches:
+    // 1. Exact case insensitive match
+    // 2. Levenshtein distance match
+    // 3. Sorted word match
+    case_insensitive_match
+        .or_else(|| levenshtein_match.map(|(candidate, _)| candidate))
+        .or_else(|| find_match_by_sorted_words(name_vec, lookup))
+}
+
+/// Finds the last name in `iter_names` made of the same `_`-separated words as `lookup`,
+/// in any order.
+fn find_match_by_sorted_words(iter_names: &[Symbol], lookup: &str) -> Option<Symbol> {
+    // The sorted form of `lookup` is the same for every candidate, so it is built only once.
+    let lookup_sorted_by_words = sort_by_words(lookup);
+    // Scanning from the back returns the last match and stops as soon as it is found.
+    let mut names = iter_names.iter().rev();
+    names.find(|name| sort_by_words(&name.as_str()) == lookup_sorted_by_words).copied()
+}
+
+/// Rebuilds `name` with its `_`-separated words in sorted order, e.g. `b_a` becomes `a_b`.
+fn sort_by_words(name: &str) -> String {
+    let mut words = name.split('_').collect::<Vec<&str>>();
+    words.sort_unstable(); // equal `&str`s are interchangeable, so stability is not needed
+    words.join("_")
+}
diff --git a/compiler/rustc_span/src/lev_distance/tests.rs b/compiler/rustc_span/src/lev_distance/tests.rs
new file mode 100644
index 00000000000..7aa01cb8efe
--- /dev/null
+++ b/compiler/rustc_span/src/lev_distance/tests.rs
@@ -0,0 +1,56 @@
+use super::*;
+
+#[test]
+fn test_lev_distance() {
+    use std::char::{from_u32, MAX};
+    // Every single-`char` string must be at distance 0 from itself, whatever its byte length.
+    for text in (0..MAX as u32).filter_map(from_u32).map(|ch| ch.to_string()) {
+        assert_eq!(lev_distance(&text, &text), 0);
+    }
+
+    let a = "\nMäry häd ä little lämb\n\nLittle lämb\n";
+    let b = "\nMary häd ä little lämb\n\nLittle lämb\n";
+    let c = "Mary häd ä little lämb\n\nLittle lämb\n";
+    // `a` and `b` differ in one letter, and `c` is `b` without its leading newline.
+    // Each pair is checked in both argument orders.
+    for &(x, y, expected) in &[(a, b, 1), (a, c, 2), (b, c, 1)] {
+        assert_eq!(lev_distance(x, y), expected);
+        assert_eq!(lev_distance(y, x), expected);
+    }
+}
+
+#[test]
+fn test_find_best_match_for_name() {
+    use crate::with_default_session_globals;
+    with_default_session_globals(|| {
+        // Interns `names` and `lookup`, then searches `names` for the best match.
+        let best_match = |names: &[&str], lookup: &str, dist: Option<usize>| {
+            let names: Vec<Symbol> = names.iter().map(|name| Symbol::intern(name)).collect();
+            find_best_match_for_name(&names, Symbol::intern(lookup), dist)
+        };
+        let two_names = ["aaab", "aaabc"];
+
+        // "aaab" is one substitution away, within the default limit of `max(4, 3) / 3 == 1`.
+        assert_eq!(best_match(&two_names, "aaaa", None), Some(Symbol::intern("aaab")));
+
+        // Neither name shares a single character with the lookup.
+        assert_eq!(best_match(&two_names, "1111111111", None), None);
+
+        // The only difference is the case of the first letter.
+        assert_eq!(best_match(&["aAAA"], "AAAA", None), Some(Symbol::intern("aAAA")));
+
+        // Returns None because the distance is 4, while the default limit is
+        // `max(4, 3) / 3 == 1`.
+        assert_eq!(best_match(&["AAAA"], "aaaa", None), None);
+
+        // The same candidate is accepted once the limit is raised to 4.
+        assert_eq!(best_match(&["AAAA"], "aaaa", Some(4)), Some(Symbol::intern("AAAA")));
+
+        // The same `_`-separated words in a different order still produce a match, even though
+        // the edit distance is above the default limit.
+        assert_eq!(
+            best_match(&["a_longer_variable_name"], "a_variable_longer_name", None),
+            Some(Symbol::intern("a_longer_variable_name"))
+        );
+    })
+}
diff --git a/compiler/rustc_span/src/lib.rs b/compiler/rustc_span/src/lib.rs
index 0926561f4c5..11a49d1ab88 100644
--- a/compiler/rustc_span/src/lib.rs
+++ b/compiler/rustc_span/src/lib.rs
@@ -34,6 +34,7 @@ use hygiene::Transparency;
 pub use hygiene::{DesugaringKind, ExpnData, ExpnId, ExpnKind, ForLoopLoc, MacroKind};
 pub mod def_id;
 use def_id::{CrateNum, DefId, LOCAL_CRATE};
+pub mod lev_distance;
 mod span_encoding;
 pub use span_encoding::{Span, DUMMY_SP};