about summary refs log tree commit diff
diff options
context:
space:
mode:
authorDorian Scheidt <dorian.scheidt@gmail.com>2022-07-20 12:03:18 -0500
committerDorian Scheidt <dorian.scheidt@gmail.com>2022-08-02 14:37:12 -0400
commit405dd77d30382f417f033e4feba2b2bd02ebe00e (patch)
tree15de5cc10a88013cc5cd551578463ce3a87b0a02
parent113f1dbc9102d8eb693fefc1f369868c2a497910 (diff)
downloadrust-405dd77d30382f417f033e4feba2b2bd02ebe00e.tar.gz
rust-405dd77d30382f417f033e4feba2b2bd02ebe00e.zip
Support adding variants via structural editing
-rw-r--r--crates/syntax/src/ast/edit_in_place.rs171
-rw-r--r--crates/syntax/src/ast/make.rs5
2 files changed, 148 insertions, 28 deletions
diff --git a/crates/syntax/src/ast/edit_in_place.rs b/crates/syntax/src/ast/edit_in_place.rs
index e3e928aecd4..8efd58e2c39 100644
--- a/crates/syntax/src/ast/edit_in_place.rs
+++ b/crates/syntax/src/ast/edit_in_place.rs
@@ -11,7 +11,7 @@ use crate::{
     ted::{self, Position},
     AstNode, AstToken, Direction,
     SyntaxKind::{ATTR, COMMENT, WHITESPACE},
-    SyntaxNode,
+    SyntaxNode, SyntaxToken,
 };
 
 use super::HasName;
@@ -506,19 +506,7 @@ impl ast::RecordExprFieldList {
 
         let position = match self.fields().last() {
             Some(last_field) => {
-                let comma = match last_field
-                    .syntax()
-                    .siblings_with_tokens(Direction::Next)
-                    .filter_map(|it| it.into_token())
-                    .find(|it| it.kind() == T![,])
-                {
-                    Some(it) => it,
-                    None => {
-                        let comma = ast::make::token(T![,]);
-                        ted::insert(Position::after(last_field.syntax()), &comma);
-                        comma
-                    }
-                };
+                let comma = get_or_insert_comma_after(last_field.syntax());
                 Position::after(comma)
             }
             None => match self.l_curly_token() {
@@ -579,19 +567,8 @@ impl ast::RecordPatFieldList {
 
         let position = match self.fields().last() {
             Some(last_field) => {
-                let comma = match last_field
-                    .syntax()
-                    .siblings_with_tokens(Direction::Next)
-                    .filter_map(|it| it.into_token())
-                    .find(|it| it.kind() == T![,])
-                {
-                    Some(it) => it,
-                    None => {
-                        let comma = ast::make::token(T![,]);
-                        ted::insert(Position::after(last_field.syntax()), &comma);
-                        comma
-                    }
-                };
+                let syntax = last_field.syntax();
+                let comma = get_or_insert_comma_after(syntax);
                 Position::after(comma)
             }
             None => match self.l_curly_token() {
@@ -606,12 +583,53 @@ impl ast::RecordPatFieldList {
         }
     }
 }
+
+fn get_or_insert_comma_after(syntax: &SyntaxNode) -> SyntaxToken {
+    let comma = match syntax
+        .siblings_with_tokens(Direction::Next)
+        .filter_map(|it| it.into_token())
+        .find(|it| it.kind() == T![,])
+    {
+        Some(it) => it,
+        None => {
+            let comma = ast::make::token(T![,]);
+            ted::insert(Position::after(syntax), &comma);
+            comma
+        }
+    };
+    comma
+}
+
 impl ast::StmtList {
     pub fn push_front(&self, statement: ast::Stmt) {
         ted::insert(Position::after(self.l_curly_token().unwrap()), statement.syntax());
     }
 }
 
+impl ast::VariantList {
+    pub fn add_variant(&self, variant: ast::Variant) {
+        let (indent, position) = match self.variants().last() {
+            Some(last_item) => (
+                IndentLevel::from_node(last_item.syntax()),
+                Position::after(get_or_insert_comma_after(last_item.syntax())),
+            ),
+            None => match self.l_curly_token() {
+                Some(l_curly) => {
+                    normalize_ws_between_braces(self.syntax());
+                    (IndentLevel::from_token(&l_curly) + 1, Position::after(&l_curly))
+                }
+                None => (IndentLevel::single(), Position::last_child_of(self.syntax())),
+            },
+        };
+        let elements: Vec<SyntaxElement<_>> = vec![
+            make::tokens::whitespace(&format!("{}{}", "\n", indent)).into(),
+            variant.syntax().clone().into(),
+            ast::make::token(T![,]).into(),
+        ];
+        ted::insert_all(position, elements);
+    }
+}
+
 fn normalize_ws_between_braces(node: &SyntaxNode) -> Option<()> {
     let l = node
         .children_with_tokens()
@@ -661,6 +679,9 @@ impl<N: AstNode + Clone> Indent for N {}
 mod tests {
     use std::fmt;
 
+    use stdx::trim_indent;
+    use test_utils::assert_eq_text;
+
     use crate::SourceFile;
 
     use super::*;
@@ -714,4 +735,100 @@ mod tests {
         }",
         );
     }
+
+    #[test]
+    fn add_variant_to_empty_enum() {
+        let variant = make::variant(make::name("Bar"), None).clone_for_update();
+
+        check_add_variant(
+            r#"
+enum Foo {}
+"#,
+            r#"
+enum Foo {
+    Bar,
+}
+"#,
+            variant,
+        );
+    }
+
+    #[test]
+    fn add_variant_to_non_empty_enum() {
+        let variant = make::variant(make::name("Baz"), None).clone_for_update();
+
+        check_add_variant(
+            r#"
+enum Foo {
+    Bar,
+}
+"#,
+            r#"
+enum Foo {
+    Bar,
+    Baz,
+}
+"#,
+            variant,
+        );
+    }
+
+    #[test]
+    fn add_variant_with_tuple_field_list() {
+        let variant = make::variant(
+            make::name("Baz"),
+            Some(ast::FieldList::TupleFieldList(make::tuple_field_list(std::iter::once(
+                make::tuple_field(None, make::ty("bool")),
+            )))),
+        )
+        .clone_for_update();
+
+        check_add_variant(
+            r#"
+enum Foo {
+    Bar,
+}
+"#,
+            r#"
+enum Foo {
+    Bar,
+    Baz(bool),
+}
+"#,
+            variant,
+        );
+    }
+
+    #[test]
+    fn add_variant_with_record_field_list() {
+        let variant = make::variant(
+            make::name("Baz"),
+            Some(ast::FieldList::RecordFieldList(make::record_field_list(std::iter::once(
+                make::record_field(None, make::name("x"), make::ty("bool")),
+            )))),
+        )
+        .clone_for_update();
+
+        check_add_variant(
+            r#"
+enum Foo {
+    Bar,
+}
+"#,
+            r#"
+enum Foo {
+    Bar,
+    Baz { x: bool },
+}
+"#,
+            variant,
+        );
+    }
+
+    fn check_add_variant(before: &str, expected: &str, variant: ast::Variant) {
+        let enum_ = ast_mut_from_text::<ast::Enum>(before);
+        enum_.variant_list().map(|it| it.add_variant(variant));
+        let after = enum_.to_string();
+        assert_eq_text!(&trim_indent(expected.trim()), &trim_indent(&after.trim()));
+    }
 }
diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs
index 5908dda8e63..037de876d45 100644
--- a/crates/syntax/src/ast/make.rs
+++ b/crates/syntax/src/ast/make.rs
@@ -745,7 +745,10 @@ pub fn tuple_field(visibility: Option<ast::Visibility>, ty: ast::Type) -> ast::T
 pub fn variant(name: ast::Name, field_list: Option<ast::FieldList>) -> ast::Variant {
     let field_list = match field_list {
         None => String::new(),
-        Some(it) => format!("{}", it),
+        Some(it) => match it {
+            ast::FieldList::RecordFieldList(record) => format!(" {}", record),
+            ast::FieldList::TupleFieldList(tuple) => format!("{}", tuple),
+        },
     };
     ast_from_text(&format!("enum f {{ {}{} }}", name, field_list))
 }