about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/tools/rust-analyzer/crates/ide-assists/src/handlers/bool_to_enum.rs6
-rw-r--r--src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_variant.rs2
-rw-r--r--src/tools/rust-analyzer/crates/syntax/src/ast/edit_in_place.rs8
-rw-r--r--src/tools/rust-analyzer/crates/syntax/src/ast/make.rs28
-rw-r--r--src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs207
5 files changed, 242 insertions, 9 deletions
diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/bool_to_enum.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/bool_to_enum.rs
index 605fd140523..f699899066b 100644
--- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/bool_to_enum.rs
+++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/bool_to_enum.rs
@@ -512,9 +512,11 @@ fn make_bool_enum(make_pub: bool) -> ast::Enum {
     let enum_def = make::enum_(
         if make_pub { Some(make::visibility_pub()) } else { None },
         make::name("Bool"),
+        None,
+        None,
         make::variant_list(vec![
-            make::variant(make::name("True"), None),
-            make::variant(make::name("False"), None),
+            make::variant(None, make::name("True"), None, None),
+            make::variant(None, make::name("False"), None, None),
         ]),
     )
     .clone_for_update();
diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_variant.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_variant.rs
index 5d584591210..985d14d22af 100644
--- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_variant.rs
+++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_variant.rs
@@ -137,7 +137,7 @@ fn make_variant(
     parent: PathParent,
 ) -> ast::Variant {
     let field_list = parent.make_field_list(ctx);
-    make::variant(make::name(&name_ref.text()), field_list)
+    make::variant(None, make::name(&name_ref.text()), field_list, None)
 }
 
 fn make_record_field_list(
diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/edit_in_place.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/edit_in_place.rs
index f1286e7aa21..8ec794bfa45 100644
--- a/src/tools/rust-analyzer/crates/syntax/src/ast/edit_in_place.rs
+++ b/src/tools/rust-analyzer/crates/syntax/src/ast/edit_in_place.rs
@@ -1173,7 +1173,7 @@ mod tests {
 
     #[test]
     fn add_variant_to_empty_enum() {
-        let variant = make::variant(make::name("Bar"), None).clone_for_update();
+        let variant = make::variant(None, make::name("Bar"), None, None).clone_for_update();
 
         check_add_variant(
             r#"
@@ -1190,7 +1190,7 @@ enum Foo {
 
     #[test]
     fn add_variant_to_non_empty_enum() {
-        let variant = make::variant(make::name("Baz"), None).clone_for_update();
+        let variant = make::variant(None, make::name("Baz"), None, None).clone_for_update();
 
         check_add_variant(
             r#"
@@ -1211,10 +1211,12 @@ enum Foo {
     #[test]
     fn add_variant_with_tuple_field_list() {
         let variant = make::variant(
+            None,
             make::name("Baz"),
             Some(ast::FieldList::TupleFieldList(make::tuple_field_list(std::iter::once(
                 make::tuple_field(None, make::ty("bool")),
             )))),
+            None,
         )
         .clone_for_update();
 
@@ -1237,10 +1239,12 @@ enum Foo {
     #[test]
     fn add_variant_with_record_field_list() {
         let variant = make::variant(
+            None,
             make::name("Baz"),
             Some(ast::FieldList::RecordFieldList(make::record_field_list(std::iter::once(
                 make::record_field(None, make::name("x"), make::ty("bool")),
             )))),
+            None,
         )
         .clone_for_update();
 
diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs
index 7eb3e08f541..05c2a8354da 100644
--- a/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs
+++ b/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs
@@ -1053,7 +1053,17 @@ pub fn variant_list(variants: impl IntoIterator<Item = ast::Variant>) -> ast::Va
     ast_from_text(&format!("enum f {{ {variants} }}"))
 }
 
-pub fn variant(name: ast::Name, field_list: Option<ast::FieldList>) -> ast::Variant {
+pub fn variant(
+    visibility: Option<ast::Visibility>,
+    name: ast::Name,
+    field_list: Option<ast::FieldList>,
+    discriminant: Option<ast::Expr>,
+) -> ast::Variant {
+    let visibility = match visibility {
+        None => String::new(),
+        Some(it) => format!("{it} "),
+    };
+
     let field_list = match field_list {
         None => String::new(),
         Some(it) => match it {
@@ -1061,7 +1071,12 @@ pub fn variant(name: ast::Name, field_list: Option<ast::FieldList>) -> ast::Vari
             ast::FieldList::TupleFieldList(tuple) => format!("{tuple}"),
         },
     };
-    ast_from_text(&format!("enum f {{ {name}{field_list} }}"))
+
+    let discriminant = match discriminant {
+        Some(it) => format!(" = {it}"),
+        None => String::new(),
+    };
+    ast_from_text(&format!("enum f {{ {visibility}{name}{field_list}{discriminant} }}"))
 }
 
 pub fn fn_(
@@ -1122,6 +1137,8 @@ pub fn struct_(
 pub fn enum_(
     visibility: Option<ast::Visibility>,
     enum_name: ast::Name,
+    generic_param_list: Option<ast::GenericParamList>,
+    where_clause: Option<ast::WhereClause>,
     variant_list: ast::VariantList,
 ) -> ast::Enum {
     let visibility = match visibility {
@@ -1129,7 +1146,12 @@ pub fn enum_(
         Some(it) => format!("{it} "),
     };
 
-    ast_from_text(&format!("{visibility}enum {enum_name} {variant_list}"))
+    let generic_params = generic_param_list.map(|it| it.to_string()).unwrap_or_default();
+    let where_clause = where_clause.map(|it| format!(" {it}")).unwrap_or_default();
+
+    ast_from_text(&format!(
+        "{visibility}enum {enum_name}{generic_params}{where_clause} {variant_list}"
+    ))
 }
 
 pub fn attr_outer(meta: ast::Meta) -> ast::Attr {
diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs
index 54f17bd721d..88e9a93cd28 100644
--- a/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs
+++ b/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs
@@ -2,7 +2,7 @@
 use itertools::Itertools;
 
 use crate::{
-    ast::{self, make, HasName, HasTypeBounds},
+    ast::{self, make, HasGenericParams, HasName, HasTypeBounds, HasVisibility},
     syntax_editor::SyntaxMappingBuilder,
     AstNode, NodeOrToken, SyntaxKind, SyntaxNode, SyntaxToken,
 };
@@ -169,6 +169,211 @@ impl SyntaxFactory {
         ast
     }
 
+    pub fn record_field_list(
+        &self,
+        fields: impl IntoIterator<Item = ast::RecordField>,
+    ) -> ast::RecordFieldList {
+        let fields: Vec<ast::RecordField> = fields.into_iter().collect();
+        let input: Vec<_> = fields.iter().map(|it| it.syntax().clone()).collect();
+        let ast = make::record_field_list(fields).clone_for_update();
+
+        if let Some(mut mapping) = self.mappings() {
+            let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
+
+            builder.map_children(input.into_iter(), ast.fields().map(|it| it.syntax().clone()));
+
+            builder.finish(&mut mapping);
+        }
+
+        ast
+    }
+
+    pub fn record_field(
+        &self,
+        visibility: Option<ast::Visibility>,
+        name: ast::Name,
+        ty: ast::Type,
+    ) -> ast::RecordField {
+        let ast =
+            make::record_field(visibility.clone(), name.clone(), ty.clone()).clone_for_update();
+
+        if let Some(mut mapping) = self.mappings() {
+            let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
+            if let Some(visibility) = visibility {
+                builder.map_node(
+                    visibility.syntax().clone(),
+                    ast.visibility().unwrap().syntax().clone(),
+                );
+            }
+
+            builder.map_node(name.syntax().clone(), ast.name().unwrap().syntax().clone());
+            builder.map_node(ty.syntax().clone(), ast.ty().unwrap().syntax().clone());
+
+            builder.finish(&mut mapping);
+        }
+
+        ast
+    }
+
+    pub fn tuple_field_list(
+        &self,
+        fields: impl IntoIterator<Item = ast::TupleField>,
+    ) -> ast::TupleFieldList {
+        let fields: Vec<ast::TupleField> = fields.into_iter().collect();
+        let input: Vec<_> = fields.iter().map(|it| it.syntax().clone()).collect();
+        let ast = make::tuple_field_list(fields).clone_for_update();
+
+        if let Some(mut mapping) = self.mappings() {
+            let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
+
+            builder.map_children(input.into_iter(), ast.fields().map(|it| it.syntax().clone()));
+
+            builder.finish(&mut mapping);
+        }
+
+        ast
+    }
+
+    pub fn tuple_field(
+        &self,
+        visibility: Option<ast::Visibility>,
+        ty: ast::Type,
+    ) -> ast::TupleField {
+        let ast = make::tuple_field(visibility.clone(), ty.clone()).clone_for_update();
+
+        if let Some(mut mapping) = self.mappings() {
+            let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
+            if let Some(visibility) = visibility {
+                builder.map_node(
+                    visibility.syntax().clone(),
+                    ast.visibility().unwrap().syntax().clone(),
+                );
+            }
+
+            builder.map_node(ty.syntax().clone(), ast.ty().unwrap().syntax().clone());
+
+            builder.finish(&mut mapping);
+        }
+
+        ast
+    }
+
+    pub fn item_enum(
+        &self,
+        visibility: Option<ast::Visibility>,
+        name: ast::Name,
+        generic_param_list: Option<ast::GenericParamList>,
+        where_clause: Option<ast::WhereClause>,
+        variant_list: ast::VariantList,
+    ) -> ast::Enum {
+        let ast = make::enum_(
+            visibility.clone(),
+            name.clone(),
+            generic_param_list.clone(),
+            where_clause.clone(),
+            variant_list.clone(),
+        )
+        .clone_for_update();
+
+        if let Some(mut mapping) = self.mappings() {
+            let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
+            if let Some(visibility) = visibility {
+                builder.map_node(
+                    visibility.syntax().clone(),
+                    ast.visibility().unwrap().syntax().clone(),
+                );
+            }
+
+            builder.map_node(name.syntax().clone(), ast.name().unwrap().syntax().clone());
+
+            if let Some(generic_param_list) = generic_param_list {
+                builder.map_node(
+                    generic_param_list.syntax().clone(),
+                    ast.generic_param_list().unwrap().syntax().clone(),
+                );
+            }
+
+            if let Some(where_clause) = where_clause {
+                builder.map_node(
+                    where_clause.syntax().clone(),
+                    ast.where_clause().unwrap().syntax().clone(),
+                );
+            }
+
+            builder.map_node(
+                variant_list.syntax().clone(),
+                ast.variant_list().unwrap().syntax().clone(),
+            );
+
+            builder.finish(&mut mapping);
+        }
+
+        ast
+    }
+
+    pub fn variant_list(
+        &self,
+        variants: impl IntoIterator<Item = ast::Variant>,
+    ) -> ast::VariantList {
+        let variants: Vec<ast::Variant> = variants.into_iter().collect();
+        let input: Vec<_> = variants.iter().map(|it| it.syntax().clone()).collect();
+        let ast = make::variant_list(variants).clone_for_update();
+
+        if let Some(mut mapping) = self.mappings() {
+            let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
+
+            builder.map_children(input.into_iter(), ast.variants().map(|it| it.syntax().clone()));
+
+            builder.finish(&mut mapping);
+        }
+
+        ast
+    }
+
+    pub fn variant(
+        &self,
+        visibility: Option<ast::Visibility>,
+        name: ast::Name,
+        field_list: Option<ast::FieldList>,
+        discriminant: Option<ast::Expr>,
+    ) -> ast::Variant {
+        let ast = make::variant(
+            visibility.clone(),
+            name.clone(),
+            field_list.clone(),
+            discriminant.clone(),
+        )
+        .clone_for_update();
+
+        if let Some(mut mapping) = self.mappings() {
+            let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
+            if let Some(visibility) = visibility {
+                builder.map_node(
+                    visibility.syntax().clone(),
+                    ast.visibility().unwrap().syntax().clone(),
+                );
+            }
+
+            builder.map_node(name.syntax().clone(), ast.name().unwrap().syntax().clone());
+
+            if let Some(field_list) = field_list {
+                builder.map_node(
+                    field_list.syntax().clone(),
+                    ast.field_list().unwrap().syntax().clone(),
+                );
+            }
+
+            if let Some(discriminant) = discriminant {
+                builder
+                    .map_node(discriminant.syntax().clone(), ast.expr().unwrap().syntax().clone());
+            }
+
+            builder.finish(&mut mapping);
+        }
+
+        ast
+    }
+
     pub fn token_tree(
         &self,
         delimiter: SyntaxKind,