about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2022-08-02 19:50:01 +0000
committerbors <bors@rust-lang.org>2022-08-02 19:50:01 +0000
commit1c03f45c08145be5b1ecd40aa1f28a12c615d63a (patch)
treeb3f0239668b4e9b1aa520e15f04cdc25485f5240
parent113f1dbc9102d8eb693fefc1f369868c2a497910 (diff)
parent1980c1192c7bcfb10e1ed380ee90bcaa7153d58d (diff)
downloadrust-1c03f45c08145be5b1ecd40aa1f28a12c615d63a.tar.gz
rust-1c03f45c08145be5b1ecd40aa1f28a12c615d63a.zip
Auto merge of #12837 - DorianListens:dscheidt/generate-enum-data, r=Veykril
feat: support associated values in "Generate Enum Variant" assist

This change adds support for associated values to the "Generate Enum Variant" assist.

I've split the implementation out into 4 steps to make code review easier:
- Add "add_variant" support to the structural ast editing system in `edit_in_place`
- Migrate `generate_enum_variant` to use structural ast editing instead of string manipulation
- Support tuple fields
- Support record fields

Please let me know if I should leave the commits as-is, or squash before merging.

Fixes #12797
-rw-r--r--crates/ide-assists/src/handlers/generate_enum_variant.rs375
-rw-r--r--crates/syntax/src/ast/edit_in_place.rs171
-rw-r--r--crates/syntax/src/ast/make.rs5
3 files changed, 504 insertions, 47 deletions
diff --git a/crates/ide-assists/src/handlers/generate_enum_variant.rs b/crates/ide-assists/src/handlers/generate_enum_variant.rs
index 4461fbd5ac8..35cd42908af 100644
--- a/crates/ide-assists/src/handlers/generate_enum_variant.rs
+++ b/crates/ide-assists/src/handlers/generate_enum_variant.rs
@@ -1,8 +1,8 @@
-use hir::{HasSource, InFile};
+use hir::{HasSource, HirDisplay, InFile};
 use ide_db::assists::{AssistId, AssistKind};
 use syntax::{
-    ast::{self, edit::IndentLevel},
-    AstNode, TextSize,
+    ast::{self, make, HasArgList},
+    match_ast, AstNode, SyntaxNode,
 };
 
 use crate::assist_context::{AssistContext, Assists};
@@ -32,8 +32,8 @@ use crate::assist_context::{AssistContext, Assists};
 // }
 // ```
 pub(crate) fn generate_enum_variant(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
-    let path_expr: ast::PathExpr = ctx.find_node_at_offset()?;
-    let path = path_expr.path()?;
+    let path: ast::Path = ctx.find_node_at_offset()?;
+    let parent = path_parent(&path)?;
 
     if ctx.sema.resolve_path(&path).is_some() {
         // No need to generate anything if the path resolves
@@ -50,26 +50,71 @@ pub(crate) fn generate_enum_variant(acc: &mut Assists, ctx: &AssistContext<'_>)
         ctx.sema.resolve_path(&path.qualifier()?)
     {
         let target = path.syntax().text_range();
-        return add_variant_to_accumulator(acc, ctx, target, e, &name_ref);
+        return add_variant_to_accumulator(acc, ctx, target, e, &name_ref, parent);
     }
 
     None
 }
 
+#[derive(Debug)]
+enum PathParent {
+    PathExpr(ast::PathExpr),
+    RecordExpr(ast::RecordExpr),
+    PathPat(ast::PathPat),
+    UseTree(ast::UseTree),
+}
+
+impl PathParent {
+    fn syntax(&self) -> &SyntaxNode {
+        match self {
+            PathParent::PathExpr(it) => it.syntax(),
+            PathParent::RecordExpr(it) => it.syntax(),
+            PathParent::PathPat(it) => it.syntax(),
+            PathParent::UseTree(it) => it.syntax(),
+        }
+    }
+
+    fn make_field_list(&self, ctx: &AssistContext<'_>) -> Option<ast::FieldList> {
+        let scope = ctx.sema.scope(self.syntax())?;
+
+        match self {
+            PathParent::PathExpr(it) => {
+                if let Some(call_expr) = it.syntax().parent().and_then(ast::CallExpr::cast) {
+                    make_tuple_field_list(call_expr, ctx, &scope)
+                } else {
+                    None
+                }
+            }
+            PathParent::RecordExpr(it) => make_record_field_list(it, ctx, &scope),
+            PathParent::UseTree(_) | PathParent::PathPat(_) => None,
+        }
+    }
+}
+
+fn path_parent(path: &ast::Path) -> Option<PathParent> {
+    let parent = path.syntax().parent()?;
+
+    match_ast! {
+        match parent {
+            ast::PathExpr(it) => Some(PathParent::PathExpr(it)),
+            ast::RecordExpr(it) => Some(PathParent::RecordExpr(it)),
+            ast::PathPat(it) => Some(PathParent::PathPat(it)),
+            ast::UseTree(it) => Some(PathParent::UseTree(it)),
+            _ => None
+        }
+    }
+}
+
 fn add_variant_to_accumulator(
     acc: &mut Assists,
     ctx: &AssistContext<'_>,
     target: syntax::TextRange,
     adt: hir::Enum,
     name_ref: &ast::NameRef,
+    parent: PathParent,
 ) -> Option<()> {
     let db = ctx.db();
     let InFile { file_id, value: enum_node } = adt.source(db)?.original_ast_node(db)?;
-    let enum_indent = IndentLevel::from_node(&enum_node.syntax());
-
-    let variant_list = enum_node.variant_list()?;
-    let offset = variant_list.syntax().text_range().end() - TextSize::of('}');
-    let empty_enum = variant_list.variants().next().is_none();
 
     acc.add(
         AssistId("generate_enum_variant", AssistKind::Generate),
@@ -77,18 +122,80 @@ fn add_variant_to_accumulator(
         target,
         |builder| {
             builder.edit_file(file_id.original_file(db));
-            let text = format!(
-                "{maybe_newline}{indent_1}{name},\n{enum_indent}",
-                maybe_newline = if empty_enum { "\n" } else { "" },
-                indent_1 = IndentLevel(1),
-                name = name_ref,
-                enum_indent = enum_indent
-            );
-            builder.insert(offset, text)
+            let node = builder.make_mut(enum_node);
+            let variant = make_variant(ctx, name_ref, parent);
+            node.variant_list().map(|it| it.add_variant(variant.clone_for_update()));
         },
     )
 }
 
+fn make_variant(
+    ctx: &AssistContext<'_>,
+    name_ref: &ast::NameRef,
+    parent: PathParent,
+) -> ast::Variant {
+    let field_list = parent.make_field_list(ctx);
+    make::variant(make::name(&name_ref.text()), field_list)
+}
+
+fn make_record_field_list(
+    record: &ast::RecordExpr,
+    ctx: &AssistContext<'_>,
+    scope: &hir::SemanticsScope<'_>,
+) -> Option<ast::FieldList> {
+    let fields = record.record_expr_field_list()?.fields();
+    let record_fields = fields.map(|field| {
+        let name = name_from_field(&field);
+
+        let ty = field
+            .expr()
+            .and_then(|it| expr_ty(ctx, it, scope))
+            .unwrap_or_else(make::ty_placeholder);
+
+        make::record_field(None, name, ty)
+    });
+    Some(make::record_field_list(record_fields).into())
+}
+
+fn name_from_field(field: &ast::RecordExprField) -> ast::Name {
+    let text = match field.name_ref() {
+        Some(it) => it.to_string(),
+        None => name_from_field_shorthand(field).unwrap_or("unknown".to_string()),
+    };
+    make::name(&text)
+}
+
+fn name_from_field_shorthand(field: &ast::RecordExprField) -> Option<String> {
+    let path = match field.expr()? {
+        ast::Expr::PathExpr(path_expr) => path_expr.path(),
+        _ => None,
+    }?;
+    Some(path.as_single_name_ref()?.to_string())
+}
+
+fn make_tuple_field_list(
+    call_expr: ast::CallExpr,
+    ctx: &AssistContext<'_>,
+    scope: &hir::SemanticsScope<'_>,
+) -> Option<ast::FieldList> {
+    let args = call_expr.arg_list()?.args();
+    let tuple_fields = args.map(|arg| {
+        let ty = expr_ty(ctx, arg, &scope).unwrap_or_else(make::ty_placeholder);
+        make::tuple_field(None, ty)
+    });
+    Some(make::tuple_field_list(tuple_fields).into())
+}
+
+fn expr_ty(
+    ctx: &AssistContext<'_>,
+    arg: ast::Expr,
+    scope: &hir::SemanticsScope<'_>,
+) -> Option<ast::Type> {
+    let ty = ctx.sema.type_of_expr(&arg).map(|it| it.adjusted())?;
+    let text = ty.display_source_code(ctx.db(), scope.module().into()).ok()?;
+    Some(make::ty(&text))
+}
+
 #[cfg(test)]
 mod tests {
     use crate::tests::{check_assist, check_assist_not_applicable};
@@ -224,4 +331,234 @@ fn main() {
 ",
         )
     }
+
+    #[test]
+    fn associated_single_element_tuple() {
+        check_assist(
+            generate_enum_variant,
+            r"
+enum Foo {}
+fn main() {
+    Foo::Bar$0(true)
+}
+",
+            r"
+enum Foo {
+    Bar(bool),
+}
+fn main() {
+    Foo::Bar(true)
+}
+",
+        )
+    }
+
+    #[test]
+    fn associated_single_element_tuple_unknown_type() {
+        check_assist(
+            generate_enum_variant,
+            r"
+enum Foo {}
+fn main() {
+    Foo::Bar$0(x)
+}
+",
+            r"
+enum Foo {
+    Bar(_),
+}
+fn main() {
+    Foo::Bar(x)
+}
+",
+        )
+    }
+
+    #[test]
+    fn associated_multi_element_tuple() {
+        check_assist(
+            generate_enum_variant,
+            r"
+struct Struct {}
+enum Foo {}
+fn main() {
+    Foo::Bar$0(true, x, Struct {})
+}
+",
+            r"
+struct Struct {}
+enum Foo {
+    Bar(bool, _, Struct),
+}
+fn main() {
+    Foo::Bar(true, x, Struct {})
+}
+",
+        )
+    }
+
+    #[test]
+    fn associated_record() {
+        check_assist(
+            generate_enum_variant,
+            r"
+enum Foo {}
+fn main() {
+    Foo::$0Bar { x: true }
+}
+",
+            r"
+enum Foo {
+    Bar { x: bool },
+}
+fn main() {
+    Foo::Bar { x: true }
+}
+",
+        )
+    }
+
+    #[test]
+    fn associated_record_unknown_type() {
+        check_assist(
+            generate_enum_variant,
+            r"
+enum Foo {}
+fn main() {
+    Foo::$0Bar { x: y }
+}
+",
+            r"
+enum Foo {
+    Bar { x: _ },
+}
+fn main() {
+    Foo::Bar { x: y }
+}
+",
+        )
+    }
+
+    #[test]
+    fn associated_record_field_shorthand() {
+        check_assist(
+            generate_enum_variant,
+            r"
+enum Foo {}
+fn main() {
+    let x = true;
+    Foo::$0Bar { x }
+}
+",
+            r"
+enum Foo {
+    Bar { x: bool },
+}
+fn main() {
+    let x = true;
+    Foo::Bar { x }
+}
+",
+        )
+    }
+
+    #[test]
+    fn associated_record_field_shorthand_unknown_type() {
+        check_assist(
+            generate_enum_variant,
+            r"
+enum Foo {}
+fn main() {
+    Foo::$0Bar { x }
+}
+",
+            r"
+enum Foo {
+    Bar { x: _ },
+}
+fn main() {
+    Foo::Bar { x }
+}
+",
+        )
+    }
+
+    #[test]
+    fn associated_record_field_multiple_fields() {
+        check_assist(
+            generate_enum_variant,
+            r"
+struct Struct {}
+enum Foo {}
+fn main() {
+    Foo::$0Bar { x, y: x, s: Struct {} }
+}
+",
+            r"
+struct Struct {}
+enum Foo {
+    Bar { x: _, y: _, s: Struct },
+}
+fn main() {
+    Foo::Bar { x, y: x, s: Struct {} }
+}
+",
+        )
+    }
+
+    #[test]
+    fn use_tree() {
+        check_assist(
+            generate_enum_variant,
+            r"
+//- /main.rs
+mod foo;
+use foo::Foo::Bar$0;
+
+//- /foo.rs
+enum Foo {}
+",
+            r"
+enum Foo {
+    Bar,
+}
+",
+        )
+    }
+
+    #[test]
+    fn not_applicable_for_path_type() {
+        check_assist_not_applicable(
+            generate_enum_variant,
+            r"
+enum Foo {}
+impl Foo::Bar$0 {}
+",
+        )
+    }
+
+    #[test]
+    fn path_pat() {
+        check_assist(
+            generate_enum_variant,
+            r"
+enum Foo {}
+fn foo(x: Foo) {
+    match x {
+        Foo::Bar$0 =>
+    }
+}
+",
+            r"
+enum Foo {
+    Bar,
+}
+fn foo(x: Foo) {
+    match x {
+        Foo::Bar =>
+    }
+}
+",
+        )
+    }
 }
diff --git a/crates/syntax/src/ast/edit_in_place.rs b/crates/syntax/src/ast/edit_in_place.rs
index e3e928aecd4..8efd58e2c39 100644
--- a/crates/syntax/src/ast/edit_in_place.rs
+++ b/crates/syntax/src/ast/edit_in_place.rs
@@ -11,7 +11,7 @@ use crate::{
     ted::{self, Position},
     AstNode, AstToken, Direction,
     SyntaxKind::{ATTR, COMMENT, WHITESPACE},
-    SyntaxNode,
+    SyntaxNode, SyntaxToken,
 };
 
 use super::HasName;
@@ -506,19 +506,7 @@ impl ast::RecordExprFieldList {
 
         let position = match self.fields().last() {
             Some(last_field) => {
-                let comma = match last_field
-                    .syntax()
-                    .siblings_with_tokens(Direction::Next)
-                    .filter_map(|it| it.into_token())
-                    .find(|it| it.kind() == T![,])
-                {
-                    Some(it) => it,
-                    None => {
-                        let comma = ast::make::token(T![,]);
-                        ted::insert(Position::after(last_field.syntax()), &comma);
-                        comma
-                    }
-                };
+                let comma = get_or_insert_comma_after(last_field.syntax());
                 Position::after(comma)
             }
             None => match self.l_curly_token() {
@@ -579,19 +567,8 @@ impl ast::RecordPatFieldList {
 
         let position = match self.fields().last() {
             Some(last_field) => {
-                let comma = match last_field
-                    .syntax()
-                    .siblings_with_tokens(Direction::Next)
-                    .filter_map(|it| it.into_token())
-                    .find(|it| it.kind() == T![,])
-                {
-                    Some(it) => it,
-                    None => {
-                        let comma = ast::make::token(T![,]);
-                        ted::insert(Position::after(last_field.syntax()), &comma);
-                        comma
-                    }
-                };
+                let syntax = last_field.syntax();
+                let comma = get_or_insert_comma_after(syntax);
                 Position::after(comma)
             }
             None => match self.l_curly_token() {
@@ -606,12 +583,53 @@ impl ast::RecordPatFieldList {
         }
     }
 }
+
+fn get_or_insert_comma_after(syntax: &SyntaxNode) -> SyntaxToken {
+    let comma = match syntax
+        .siblings_with_tokens(Direction::Next)
+        .filter_map(|it| it.into_token())
+        .find(|it| it.kind() == T![,])
+    {
+        Some(it) => it,
+        None => {
+            let comma = ast::make::token(T![,]);
+            ted::insert(Position::after(syntax), &comma);
+            comma
+        }
+    };
+    comma
+}
+
 impl ast::StmtList {
     pub fn push_front(&self, statement: ast::Stmt) {
         ted::insert(Position::after(self.l_curly_token().unwrap()), statement.syntax());
     }
 }
 
+impl ast::VariantList {
+    pub fn add_variant(&self, variant: ast::Variant) {
+        let (indent, position) = match self.variants().last() {
+            Some(last_item) => (
+                IndentLevel::from_node(last_item.syntax()),
+                Position::after(get_or_insert_comma_after(last_item.syntax())),
+            ),
+            None => match self.l_curly_token() {
+                Some(l_curly) => {
+                    normalize_ws_between_braces(self.syntax());
+                    (IndentLevel::from_token(&l_curly) + 1, Position::after(&l_curly))
+                }
+                None => (IndentLevel::single(), Position::last_child_of(self.syntax())),
+            },
+        };
+        let elements: Vec<SyntaxElement<_>> = vec![
+            make::tokens::whitespace(&format!("{}{}", "\n", indent)).into(),
+            variant.syntax().clone().into(),
+            ast::make::token(T![,]).into(),
+        ];
+        ted::insert_all(position, elements);
+    }
+}
+
 fn normalize_ws_between_braces(node: &SyntaxNode) -> Option<()> {
     let l = node
         .children_with_tokens()
@@ -661,6 +679,9 @@ impl<N: AstNode + Clone> Indent for N {}
 mod tests {
     use std::fmt;
 
+    use stdx::trim_indent;
+    use test_utils::assert_eq_text;
+
     use crate::SourceFile;
 
     use super::*;
@@ -714,4 +735,100 @@ mod tests {
         }",
         );
     }
+
+    #[test]
+    fn add_variant_to_empty_enum() {
+        let variant = make::variant(make::name("Bar"), None).clone_for_update();
+
+        check_add_variant(
+            r#"
+enum Foo {}
+"#,
+            r#"
+enum Foo {
+    Bar,
+}
+"#,
+            variant,
+        );
+    }
+
+    #[test]
+    fn add_variant_to_non_empty_enum() {
+        let variant = make::variant(make::name("Baz"), None).clone_for_update();
+
+        check_add_variant(
+            r#"
+enum Foo {
+    Bar,
+}
+"#,
+            r#"
+enum Foo {
+    Bar,
+    Baz,
+}
+"#,
+            variant,
+        );
+    }
+
+    #[test]
+    fn add_variant_with_tuple_field_list() {
+        let variant = make::variant(
+            make::name("Baz"),
+            Some(ast::FieldList::TupleFieldList(make::tuple_field_list(std::iter::once(
+                make::tuple_field(None, make::ty("bool")),
+            )))),
+        )
+        .clone_for_update();
+
+        check_add_variant(
+            r#"
+enum Foo {
+    Bar,
+}
+"#,
+            r#"
+enum Foo {
+    Bar,
+    Baz(bool),
+}
+"#,
+            variant,
+        );
+    }
+
+    #[test]
+    fn add_variant_with_record_field_list() {
+        let variant = make::variant(
+            make::name("Baz"),
+            Some(ast::FieldList::RecordFieldList(make::record_field_list(std::iter::once(
+                make::record_field(None, make::name("x"), make::ty("bool")),
+            )))),
+        )
+        .clone_for_update();
+
+        check_add_variant(
+            r#"
+enum Foo {
+    Bar,
+}
+"#,
+            r#"
+enum Foo {
+    Bar,
+    Baz { x: bool },
+}
+"#,
+            variant,
+        );
+    }
+
+    fn check_add_variant(before: &str, expected: &str, variant: ast::Variant) {
+        let enum_ = ast_mut_from_text::<ast::Enum>(before);
+        enum_.variant_list().map(|it| it.add_variant(variant));
+        let after = enum_.to_string();
+        assert_eq_text!(&trim_indent(expected.trim()), &trim_indent(&after.trim()));
+    }
 }
diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs
index 5908dda8e63..037de876d45 100644
--- a/crates/syntax/src/ast/make.rs
+++ b/crates/syntax/src/ast/make.rs
@@ -745,7 +745,10 @@ pub fn tuple_field(visibility: Option<ast::Visibility>, ty: ast::Type) -> ast::T
 pub fn variant(name: ast::Name, field_list: Option<ast::FieldList>) -> ast::Variant {
     let field_list = match field_list {
         None => String::new(),
-        Some(it) => format!("{}", it),
+        Some(it) => match it {
+            ast::FieldList::RecordFieldList(record) => format!(" {}", record),
+            ast::FieldList::TupleFieldList(tuple) => format!("{}", tuple),
+        },
     };
     ast_from_text(&format!("enum f {{ {}{} }}", name, field_list))
 }