about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMikhail Modin <mikhailm1@gmail.com>2020-02-10 22:45:38 +0000
committerMikhail Modin <mikhailm1@gmail.com>2020-02-14 21:45:42 +0000
commitf8f454ab5c19c6e7d91b3a4e6bb63fb9bf5f2673 (patch)
treee80fbf31a2f69916c86b5569da4f673e7818d8ec
parent6fb36dfdcb91f67c28f51e51514ebe420ec3aa22 (diff)
downloadrust-f8f454ab5c19c6e7d91b3a4e6bb63fb9bf5f2673.tar.gz
rust-f8f454ab5c19c6e7d91b3a4e6bb63fb9bf5f2673.zip
Init implementation of structural search replace
-rw-r--r--crates/ra_ide/src/lib.rs12
-rw-r--r--crates/ra_ide/src/ssr.rs324
-rw-r--r--crates/ra_lsp_server/src/main_loop.rs1
-rw-r--r--crates/ra_lsp_server/src/main_loop/handlers.rs5
-rw-r--r--crates/ra_lsp_server/src/req.rs13
-rw-r--r--crates/ra_syntax/src/ast/make.rs2
-rw-r--r--editors/code/package.json5
-rw-r--r--editors/code/src/commands/index.ts1
-rw-r--r--editors/code/src/commands/ssr.ts36
-rw-r--r--editors/code/src/main.ts1
10 files changed, 399 insertions, 1 deletions
diff --git a/crates/ra_ide/src/lib.rs b/crates/ra_ide/src/lib.rs
index 689921f3f23..dfd191e4267 100644
--- a/crates/ra_ide/src/lib.rs
+++ b/crates/ra_ide/src/lib.rs
@@ -37,6 +37,7 @@ mod display;
 mod inlay_hints;
 mod expand;
 mod expand_macro;
+mod ssr;
 
 #[cfg(test)]
 mod marks;
@@ -73,6 +74,7 @@ pub use crate::{
     },
     runnables::{Runnable, RunnableKind},
     source_change::{FileSystemEdit, SourceChange, SourceFileEdit},
+    ssr::SsrError,
     syntax_highlighting::HighlightedRange,
 };
 
@@ -464,6 +466,16 @@ impl Analysis {
         self.with_db(|db| references::rename(db, position, new_name))
     }
 
+    pub fn structural_search_replace(
+        &self,
+        query: &str,
+    ) -> Cancelable<Result<SourceChange, SsrError>> {
+        self.with_db(|db| {
+            let edits = ssr::parse_search_replace(query, db)?;
+            Ok(SourceChange::source_file_edits("ssr", edits))
+        })
+    }
+
     /// Performs an operation on that may be Canceled.
     fn with_db<F: FnOnce(&RootDatabase) -> T + std::panic::UnwindSafe, T>(
         &self,
diff --git a/crates/ra_ide/src/ssr.rs b/crates/ra_ide/src/ssr.rs
new file mode 100644
index 00000000000..14eb0b8b259
--- /dev/null
+++ b/crates/ra_ide/src/ssr.rs
@@ -0,0 +1,324 @@
+//!  structural search replace
+
+use crate::source_change::SourceFileEdit;
+use ra_ide_db::RootDatabase;
+use ra_syntax::ast::make::expr_from_text;
+use ra_syntax::AstNode;
+use ra_syntax::SyntaxElement;
+use ra_syntax::SyntaxNode;
+use ra_text_edit::{TextEdit, TextEditBuilder};
+use rustc_hash::FxHashMap;
+use std::collections::HashMap;
+use std::str::FromStr;
+
+pub use ra_db::{SourceDatabase, SourceDatabaseExt};
+use ra_ide_db::symbol_index::SymbolsDatabase;
+
+#[derive(Debug, PartialEq)]
+pub struct SsrError(String);
+
+impl std::fmt::Display for SsrError {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        write!(f, "Parse error: {}", self.0)
+    }
+}
+
+impl std::error::Error for SsrError {}
+
+pub fn parse_search_replace(
+    query: &str,
+    db: &RootDatabase,
+) -> Result<Vec<SourceFileEdit>, SsrError> {
+    let mut edits = vec![];
+    let query: SsrQuery = query.parse()?;
+    for &root in db.local_roots().iter() {
+        let sr = db.source_root(root);
+        for file_id in sr.walk() {
+            dbg!(db.file_relative_path(file_id));
+            let matches = find(&query.pattern, db.parse(file_id).tree().syntax());
+            if !matches.matches.is_empty() {
+                edits.push(SourceFileEdit { file_id, edit: replace(&matches, &query.template) });
+            }
+        }
+    }
+    Ok(edits)
+}
+
+#[derive(Debug)]
+struct SsrQuery {
+    pattern: SsrPattern,
+    template: SsrTemplate,
+}
+
+#[derive(Debug)]
+struct SsrPattern {
+    pattern: SyntaxNode,
+    vars: Vec<Var>,
+}
+
+/// represents an `$var` in an SSR query
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+struct Var(String);
+
+#[derive(Debug)]
+struct SsrTemplate {
+    template: SyntaxNode,
+    placeholders: FxHashMap<SyntaxNode, Var>,
+}
+
+type Binding = HashMap<Var, SyntaxNode>;
+
+#[derive(Debug)]
+struct Match {
+    place: SyntaxNode,
+    binding: Binding,
+}
+
+#[derive(Debug)]
+struct SsrMatches {
+    matches: Vec<Match>,
+}
+
+impl FromStr for SsrQuery {
+    type Err = SsrError;
+
+    fn from_str(query: &str) -> Result<SsrQuery, SsrError> {
+        let mut it = query.split("==>>");
+        let pattern = it.next().expect("at least empty string").trim();
+        let mut template =
+            it.next().ok_or(SsrError("Cannot find delemiter `==>>`".into()))?.trim().to_string();
+        if it.next().is_some() {
+            return Err(SsrError("More than one delimiter found".into()));
+        }
+        let mut vars = vec![];
+        let mut it = pattern.split('$');
+        let mut pattern = it.next().expect("something").to_string();
+
+        for part in it.map(split_by_var) {
+            let (var, var_type, remainder) = part?;
+            is_expr(var_type)?;
+            let new_var = create_name(var, &mut vars)?;
+            pattern.push_str(new_var);
+            pattern.push_str(remainder);
+            template = replace_in_template(template, var, new_var);
+        }
+
+        let template = expr_from_text(&template).syntax().clone();
+        let mut placeholders = FxHashMap::default();
+
+        traverse(&template, &mut |n| {
+            if let Some(v) = vars.iter().find(|v| v.0.as_str() == n.text()) {
+                placeholders.insert(n.clone(), v.clone());
+                false
+            } else {
+                true
+            }
+        });
+
+        let pattern = SsrPattern { pattern: expr_from_text(&pattern).syntax().clone(), vars };
+        let template = SsrTemplate { template, placeholders };
+        Ok(SsrQuery { pattern, template })
+    }
+}
+
+fn traverse(node: &SyntaxNode, go: &mut impl FnMut(&SyntaxNode) -> bool) {
+    if !go(node) {
+        return;
+    }
+    for ref child in node.children() {
+        traverse(child, go);
+    }
+}
+
+fn split_by_var(s: &str) -> Result<(&str, &str, &str), SsrError> {
+    let end_of_name = s.find(":").ok_or(SsrError("Use $<name>:expr".into()))?;
+    let name = &s[0..end_of_name];
+    is_name(name)?;
+    let type_begin = end_of_name + 1;
+    let type_length = s[type_begin..].find(|c| !char::is_ascii_alphanumeric(&c)).unwrap_or(s.len());
+    let type_name = &s[type_begin..type_begin + type_length];
+    Ok((name, type_name, &s[type_begin + type_length..]))
+}
+
+fn is_name(s: &str) -> Result<(), SsrError> {
+    if s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
+        Ok(())
+    } else {
+        Err(SsrError("Name can contain only alphanumerics and _".into()))
+    }
+}
+
+fn is_expr(s: &str) -> Result<(), SsrError> {
+    if s == "expr" {
+        Ok(())
+    } else {
+        Err(SsrError("Only $<name>:expr is supported".into()))
+    }
+}
+
+fn replace_in_template(template: String, var: &str, new_var: &str) -> String {
+    let name = format!("${}", var);
+    template.replace(&name, new_var)
+}
+
+fn create_name<'a>(name: &str, vars: &'a mut Vec<Var>) -> Result<&'a str, SsrError> {
+    let sanitized_name = format!("__search_pattern_{}", name);
+    if vars.iter().any(|a| a.0 == sanitized_name) {
+        return Err(SsrError(format!("Name `{}` repeats more than once", name)));
+    }
+    vars.push(Var(sanitized_name));
+    Ok(&vars.last().unwrap().0)
+}
+
+fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
+    fn check(
+        pattern: &SyntaxElement,
+        code: &SyntaxElement,
+        placeholders: &[Var],
+        match_: &mut Match,
+    ) -> bool {
+        match (pattern, code) {
+            (SyntaxElement::Token(ref pattern), SyntaxElement::Token(ref code)) => {
+                pattern.text() == code.text()
+            }
+            (SyntaxElement::Node(ref pattern), SyntaxElement::Node(ref code)) => {
+                if placeholders.iter().find(|&n| n.0.as_str() == pattern.text()).is_some() {
+                    match_.binding.insert(Var(pattern.text().to_string()), code.clone());
+                    true
+                } else {
+                    pattern.green().children().count() == code.green().children().count()
+                        && pattern
+                            .children_with_tokens()
+                            .zip(code.children_with_tokens())
+                            .all(|(a, b)| check(&a, &b, placeholders, match_))
+                }
+            }
+            _ => false,
+        }
+    }
+    let kind = pattern.pattern.kind();
+    let matches = code
+        .descendants_with_tokens()
+        .filter(|n| n.kind() == kind)
+        .filter_map(|code| {
+            let mut match_ =
+                Match { place: code.as_node().unwrap().clone(), binding: HashMap::new() };
+            if check(
+                &SyntaxElement::from(pattern.pattern.clone()),
+                &code,
+                &pattern.vars,
+                &mut match_,
+            ) {
+                Some(match_)
+            } else {
+                None
+            }
+        })
+        .collect();
+    SsrMatches { matches }
+}
+
+fn replace(matches: &SsrMatches, template: &SsrTemplate) -> TextEdit {
+    let mut builder = TextEditBuilder::default();
+    for match_ in &matches.matches {
+        builder.replace(match_.place.text_range(), render_replace(&match_.binding, template));
+    }
+    builder.finish()
+}
+
+fn render_replace(binding: &Binding, template: &SsrTemplate) -> String {
+    let mut builder = TextEditBuilder::default();
+    for element in template.template.descendants() {
+        if let Some(var) = template.placeholders.get(&element) {
+            builder.replace(element.text_range(), binding[var].to_string())
+        }
+    }
+    builder.finish().apply(&template.template.text().to_string())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use ra_syntax::SourceFile;
+
+    fn parse_error_text(query: &str) -> String {
+        format!("{}", query.parse::<SsrQuery>().unwrap_err())
+    }
+
+    #[test]
+    fn parser_happy_case() {
+        let result: SsrQuery = "foo($a:expr, $b:expr) ==>> bar($b, $a)".parse().unwrap();
+        assert_eq!(&result.pattern.pattern.text(), "foo(__search_pattern_a, __search_pattern_b)");
+        assert_eq!(result.pattern.vars.len(), 2);
+        assert_eq!(result.pattern.vars[0].0, "__search_pattern_a");
+        assert_eq!(result.pattern.vars[1].0, "__search_pattern_b");
+        assert_eq!(&result.template.template.text(), "bar(__search_pattern_b, __search_pattern_a)");
+        dbg!(result.template.placeholders);
+    }
+
+    #[test]
+    fn parser_empty_query() {
+        assert_eq!(parse_error_text(""), "Parse error: Cannot find delemiter `==>>`");
+    }
+
+    #[test]
+    fn parser_no_delimiter() {
+        assert_eq!(parse_error_text("foo()"), "Parse error: Cannot find delemiter `==>>`");
+    }
+
+    #[test]
+    fn parser_two_delimiters() {
+        assert_eq!(
+            parse_error_text("foo() ==>> a ==>> b "),
+            "Parse error: More than one delimiter found"
+        );
+    }
+
+    #[test]
+    fn parser_no_pattern_type() {
+        assert_eq!(parse_error_text("foo($a) ==>>"), "Parse error: Use $<name>:expr");
+    }
+
+    #[test]
+    fn parser_invalid_name() {
+        assert_eq!(
+            parse_error_text("foo($a+:expr) ==>>"),
+            "Parse error: Name can contain only alphanumerics and _"
+        );
+    }
+
+    #[test]
+    fn parser_invalid_type() {
+        assert_eq!(
+            parse_error_text("foo($a:ident) ==>>"),
+            "Parse error: Only $<name>:expr is supported"
+        );
+    }
+
+    #[test]
+    fn parser_repeated_name() {
+        assert_eq!(
+            parse_error_text("foo($a:expr, $a:expr) ==>>"),
+            "Parse error: Name `a` repeats more than once"
+        );
+    }
+
+    #[test]
+    fn parse_match_replace() {
+        let query: SsrQuery = "foo($x:expr) ==>> bar($x)".parse().unwrap();
+        let input = "fn main() { foo(1+2); }";
+
+        let code = SourceFile::parse(input).tree();
+        let matches = find(&query.pattern, code.syntax());
+        assert_eq!(matches.matches.len(), 1);
+        assert_eq!(matches.matches[0].place.text(), "foo(1+2)");
+        assert_eq!(matches.matches[0].binding.len(), 1);
+        assert_eq!(
+            matches.matches[0].binding[&Var("__search_pattern_x".to_string())].text(),
+            "1+2"
+        );
+
+        let edit = replace(&matches, &query.template);
+        assert_eq!(edit.apply(input), "fn main() { bar(1+2); }");
+    }
+}
diff --git a/crates/ra_lsp_server/src/main_loop.rs b/crates/ra_lsp_server/src/main_loop.rs
index ceff82fda9e..061383e28b8 100644
--- a/crates/ra_lsp_server/src/main_loop.rs
+++ b/crates/ra_lsp_server/src/main_loop.rs
@@ -526,6 +526,7 @@ fn on_request(
         .on::<req::CallHierarchyPrepare>(handlers::handle_call_hierarchy_prepare)?
         .on::<req::CallHierarchyIncomingCalls>(handlers::handle_call_hierarchy_incoming)?
         .on::<req::CallHierarchyOutgoingCalls>(handlers::handle_call_hierarchy_outgoing)?
+        .on::<req::Ssr>(handlers::handle_ssr)?
         .finish();
     Ok(())
 }
diff --git a/crates/ra_lsp_server/src/main_loop/handlers.rs b/crates/ra_lsp_server/src/main_loop/handlers.rs
index 2e598fdcdf1..72bb4861922 100644
--- a/crates/ra_lsp_server/src/main_loop/handlers.rs
+++ b/crates/ra_lsp_server/src/main_loop/handlers.rs
@@ -881,6 +881,11 @@ pub fn handle_document_highlight(
     ))
 }
 
+pub fn handle_ssr(world: WorldSnapshot, params: req::SsrParams) -> Result<req::SourceChange> {
+    let _p = profile("handle_ssr");
+    world.analysis().structural_search_replace(&params.arg)??.try_conv_with(&world)
+}
+
 pub fn publish_diagnostics(world: &WorldSnapshot, file_id: FileId) -> Result<DiagnosticTask> {
     let _p = profile("publish_diagnostics");
     let line_index = world.analysis().file_line_index(file_id)?;
diff --git a/crates/ra_lsp_server/src/req.rs b/crates/ra_lsp_server/src/req.rs
index dc327f53d2e..7ff7f60b31f 100644
--- a/crates/ra_lsp_server/src/req.rs
+++ b/crates/ra_lsp_server/src/req.rs
@@ -206,3 +206,16 @@ pub struct InlayHint {
     pub kind: InlayKind,
     pub label: String,
 }
+
+pub enum Ssr {}
+
+impl Request for Ssr {
+    type Params = SsrParams;
+    type Result = SourceChange;
+    const METHOD: &'static str = "rust-analyzer/ssr";
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct SsrParams {
+    pub arg: String,
+}
diff --git a/crates/ra_syntax/src/ast/make.rs b/crates/ra_syntax/src/ast/make.rs
index 862eb11728c..89d1403e784 100644
--- a/crates/ra_syntax/src/ast/make.rs
+++ b/crates/ra_syntax/src/ast/make.rs
@@ -84,7 +84,7 @@ pub fn expr_prefix(op: SyntaxKind, expr: ast::Expr) -> ast::Expr {
     let token = token(op);
     expr_from_text(&format!("{}{}", token, expr.syntax()))
 }
-fn expr_from_text(text: &str) -> ast::Expr {
+pub fn expr_from_text(text: &str) -> ast::Expr {
     ast_from_text(&format!("const C: () = {};", text))
 }
 
diff --git a/editors/code/package.json b/editors/code/package.json
index db1fe51893c..e1a70f05cc6 100644
--- a/editors/code/package.json
+++ b/editors/code/package.json
@@ -124,6 +124,11 @@
                 "command": "rust-analyzer.onEnter",
                 "title": "Enhanced enter key",
                 "category": "Rust Analyzer"
+            },
+            {
+                "command": "rust-analyzer.ssr",
+                "title": "Structural Search Replace",
+                "category": "Rust Analyzer"
             }
         ],
         "keybindings": [
diff --git a/editors/code/src/commands/index.ts b/editors/code/src/commands/index.ts
index aee96943201..b5ebec117f3 100644
--- a/editors/code/src/commands/index.ts
+++ b/editors/code/src/commands/index.ts
@@ -12,6 +12,7 @@ export * from './parent_module';
 export * from './syntax_tree';
 export * from './expand_macro';
 export * from './runnables';
+export * from './ssr';
 
 export function collectGarbage(ctx: Ctx): Cmd {
     return async () => {
diff --git a/editors/code/src/commands/ssr.ts b/editors/code/src/commands/ssr.ts
new file mode 100644
index 00000000000..6287bf47b42
--- /dev/null
+++ b/editors/code/src/commands/ssr.ts
@@ -0,0 +1,36 @@
+import { Ctx, Cmd } from '../ctx';
+import { applySourceChange, SourceChange } from '../source_change';
+import * as vscode from 'vscode';
+
+export function ssr(ctx: Ctx): Cmd {
+    return async () => {
+        const client = ctx.client;
+        if (!client) return;
+
+        const options: vscode.InputBoxOptions = {
+            placeHolder: "foo($a:expr, $b:expr) ==>> bar($a, foo($b))",
+            prompt: "Enter request",
+            validateInput: (x: string) => {
+                if (x.includes('==>>')) {
+                    return null;
+                }
+                return "Enter request: pattern ==>> template"
+            }
+        }
+        const request = await vscode.window.showInputBox(options);
+
+        if (!request) return;
+
+        const ssrRequest: SsrRequest = { arg: request };
+        const change = await client.sendRequest<SourceChange>(
+            'rust-analyzer/ssr',
+            ssrRequest,
+        );
+
+        await applySourceChange(ctx, change);
+    };
+}
+
+interface SsrRequest {
+    arg: string;
+}
diff --git a/editors/code/src/main.ts b/editors/code/src/main.ts
index 5efce41f404..5a99e96f0e5 100644
--- a/editors/code/src/main.ts
+++ b/editors/code/src/main.ts
@@ -22,6 +22,7 @@ export async function activate(context: vscode.ExtensionContext) {
     ctx.registerCommand('run', commands.run);
     ctx.registerCommand('reload', commands.reload);
     ctx.registerCommand('onEnter', commands.onEnter);
+    ctx.registerCommand('ssr', commands.ssr)
 
     // Internal commands which are invoked by the server.
     ctx.registerCommand('runSingle', commands.runSingle);