about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors[bot] <26634292+bors[bot]@users.noreply.github.com>2022-02-22 18:46:12 +0000
committerGitHub <noreply@github.com>2022-02-22 18:46:12 +0000
commit0b53744f2d7e0694cd7207cca632fd6de1dc5bff (patch)
tree8a9d1f752fdebad3b3d97836a1f455d604372f15
parent033f91e75d515936cfc0cc529b121f1ca3792179 (diff)
parent0db0dec9993b510efeb61cb1d8ff113270d4ca51 (diff)
downloadrust-0b53744f2d7e0694cd7207cca632fd6de1dc5bff.tar.gz
rust-0b53744f2d7e0694cd7207cca632fd6de1dc5bff.zip
Merge #11461
11461: Extract struct from enum variant filters generics r=jo-goro a=jo-goro

Fixes #11452.

This PR updates extract_struct_from_enum_variant. Extracting a struct `A` form an enum like
```rust
enum X<'a, 'b> {
    A { a: &'a () },
    B { b: &'b () },
}
```
will now be correctly generated as
```rust
struct A<'a> { a: &'a () }

enum X<'a, 'b> {
    A(A<'a>),
    B { b: &'b () },
}
```
instead of the previous
```rust
struct A<'a, 'b>{ a: &'a () } // <- should not have 'b

enum X<'a, 'b> {
    A(A<'a, 'b>),
    B { b: &'b () },
}
```

This also works for generic type parameters and const generics.

Bounds are also copied, however I have not yet implemented a filter for unneeded bounds. Extracting `B` from the following enum
```rust
enum X<'a, 'b: 'a> {
    A { a: &'a () },
    B { b: &'b () },
}
```
will be generated as 
```rust
struct B<'b: 'a> { b: &'b () } // <- should be `struct B<'b> { b: &'b () }`

enum X<'a, 'b: 'a> {
    A { a: &'a () },
    B(B<'b>),
}
```

Extracting bounds with where clauses is also still not implemented.

Co-authored-by: Jonas Goronczy <goronczy.jonas@gmail.com>
-rw-r--r--crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs232
1 files changed, 203 insertions, 29 deletions
diff --git a/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs b/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs
index 82e0970cc4b..1cdd4187af4 100644
--- a/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs
+++ b/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs
@@ -11,15 +11,14 @@ use ide_db::{
     search::FileReference,
     RootDatabase,
 };
-use itertools::Itertools;
+use itertools::{Itertools, Position};
 use rustc_hash::FxHashSet;
 use syntax::{
     ast::{
         self, edit::IndentLevel, edit_in_place::Indent, make, AstNode, HasAttrs, HasGenericParams,
-        HasName, HasTypeBounds, HasVisibility,
+        HasName, HasVisibility,
     },
-    match_ast,
-    ted::{self, Position},
+    match_ast, ted, SyntaxElement,
     SyntaxKind::*,
     SyntaxNode, T,
 };
@@ -106,7 +105,12 @@ pub(crate) fn extract_struct_from_enum_variant(
             }
 
             let indent = enum_ast.indent_level();
-            let def = create_struct_def(variant_name.clone(), &variant, &field_list, &enum_ast);
+            let generic_params = enum_ast
+                .generic_param_list()
+                .and_then(|known_generics| extract_generic_params(&known_generics, &field_list));
+            let generics = generic_params.as_ref().map(|generics| generics.clone_for_update());
+            let def =
+                create_struct_def(variant_name.clone(), &variant, &field_list, generics, &enum_ast);
             def.reindent_to(indent);
 
             let start_offset = &variant.parent_enum().syntax().clone();
@@ -118,7 +122,7 @@ pub(crate) fn extract_struct_from_enum_variant(
                 ],
             );
 
-            update_variant(&variant, enum_ast.generic_param_list());
+            update_variant(&variant, generic_params.map(|g| g.clone_for_update()));
         },
     )
 }
@@ -159,10 +163,77 @@ fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &Va
         .any(|(name, _)| name.to_string() == variant_name.to_string())
 }
 
+fn extract_generic_params(
+    known_generics: &ast::GenericParamList,
+    field_list: &Either<ast::RecordFieldList, ast::TupleFieldList>,
+) -> Option<ast::GenericParamList> {
+    let mut generics = known_generics.generic_params().map(|param| (param, false)).collect_vec();
+
+    let tagged_one = match field_list {
+        Either::Left(field_list) => field_list
+            .fields()
+            .filter_map(|f| f.ty())
+            .fold(false, |tagged, ty| tag_generics_in_variant(&ty, &mut generics) || tagged),
+        Either::Right(field_list) => field_list
+            .fields()
+            .filter_map(|f| f.ty())
+            .fold(false, |tagged, ty| tag_generics_in_variant(&ty, &mut generics) || tagged),
+    };
+
+    let generics = generics.into_iter().filter_map(|(param, tag)| tag.then(|| param));
+    tagged_one.then(|| make::generic_param_list(generics))
+}
+
+fn tag_generics_in_variant(ty: &ast::Type, generics: &mut [(ast::GenericParam, bool)]) -> bool {
+    let mut tagged_one = false;
+
+    for token in ty.syntax().descendants_with_tokens().filter_map(SyntaxElement::into_token) {
+        for (param, tag) in generics.iter_mut().filter(|(_, tag)| !tag) {
+            match param {
+                ast::GenericParam::LifetimeParam(lt)
+                    if matches!(token.kind(), T![lifetime_ident]) =>
+                {
+                    if let Some(lt) = lt.lifetime() {
+                        if lt.text().as_str() == token.text() {
+                            *tag = true;
+                            tagged_one = true;
+                            break;
+                        }
+                    }
+                }
+                param if matches!(token.kind(), T![ident]) => {
+                    if match param {
+                        ast::GenericParam::ConstParam(konst) => konst
+                            .name()
+                            .map(|name| name.text().as_str() == token.text())
+                            .unwrap_or_default(),
+                        ast::GenericParam::TypeParam(ty) => ty
+                            .name()
+                            .map(|name| name.text().as_str() == token.text())
+                            .unwrap_or_default(),
+                        ast::GenericParam::LifetimeParam(lt) => lt
+                            .lifetime()
+                            .map(|lt| lt.text().as_str() == token.text())
+                            .unwrap_or_default(),
+                    } {
+                        *tag = true;
+                        tagged_one = true;
+                        break;
+                    }
+                }
+                _ => (),
+            }
+        }
+    }
+
+    tagged_one
+}
+
 fn create_struct_def(
     variant_name: ast::Name,
     variant: &ast::Variant,
     field_list: &Either<ast::RecordFieldList, ast::TupleFieldList>,
+    generics: Option<ast::GenericParamList>,
     enum_: &ast::Enum,
 ) -> ast::Struct {
     let enum_vis = enum_.visibility();
@@ -204,9 +275,7 @@ fn create_struct_def(
 
     field_list.reindent_to(IndentLevel::single());
 
-    // FIXME: This uses all the generic params of the enum, but the variant might not use all of them.
-    let strukt = make::struct_(enum_vis, variant_name, enum_.generic_param_list(), field_list)
-        .clone_for_update();
+    let strukt = make::struct_(enum_vis, variant_name, generics, field_list).clone_for_update();
 
     // FIXME: Consider making this an actual function somewhere (like in `AttrsOwnerEdit`) after some deliberation
     let attrs_and_docs = |node: &SyntaxNode| {
@@ -233,36 +302,53 @@ fn create_struct_def(
             _ => tok,
         })
         .collect();
-    ted::insert_all(Position::first_child_of(strukt.syntax()), variant_attrs);
+    ted::insert_all(ted::Position::first_child_of(strukt.syntax()), variant_attrs);
 
     // copy attributes from enum
     ted::insert_all(
-        Position::first_child_of(strukt.syntax()),
+        ted::Position::first_child_of(strukt.syntax()),
         enum_.attrs().map(|it| it.syntax().clone_for_update().into()).collect(),
     );
     strukt
 }
 
-fn update_variant(variant: &ast::Variant, generic: Option<ast::GenericParamList>) -> Option<()> {
+fn update_variant(variant: &ast::Variant, generics: Option<ast::GenericParamList>) -> Option<()> {
     let name = variant.name()?;
-    let ty = match generic {
-        // FIXME: This uses all the generic params of the enum, but the variant might not use all of them.
-        Some(gpl) => {
-            let gpl = gpl.clone_for_update();
-            gpl.generic_params().for_each(|gp| {
-                let tbl = match gp {
-                    ast::GenericParam::LifetimeParam(it) => it.type_bound_list(),
-                    ast::GenericParam::TypeParam(it) => it.type_bound_list(),
-                    ast::GenericParam::ConstParam(_) => return,
-                };
-                if let Some(tbl) = tbl {
-                    tbl.remove();
+    let ty = generics
+        .filter(|generics| generics.generic_params().count() > 0)
+        .map(|generics| {
+            let mut generic_str = String::with_capacity(8);
+
+            for (p, more) in generics.generic_params().with_position().map(|p| match p {
+                Position::First(p) | Position::Middle(p) => (p, true),
+                Position::Last(p) | Position::Only(p) => (p, false),
+            }) {
+                match p {
+                    ast::GenericParam::ConstParam(konst) => {
+                        if let Some(name) = konst.name() {
+                            generic_str.push_str(name.text().as_str());
+                        }
+                    }
+                    ast::GenericParam::LifetimeParam(lt) => {
+                        if let Some(lt) = lt.lifetime() {
+                            generic_str.push_str(lt.text().as_str());
+                        }
+                    }
+                    ast::GenericParam::TypeParam(ty) => {
+                        if let Some(name) = ty.name() {
+                            generic_str.push_str(name.text().as_str());
+                        }
+                    }
                 }
-            });
-            make::ty(&format!("{}<{}>", name.text(), gpl.generic_params().join(", ")))
-        }
-        None => make::ty(&name.text()),
-    };
+                if more {
+                    generic_str.push_str(", ");
+                }
+            }
+
+            make::ty(&format!("{}<{}>", &name.text(), &generic_str))
+        })
+        .unwrap_or_else(|| make::ty(&name.text()));
+
     let tuple_field = make::tuple_field(None, ty);
     let replacement = make::variant(
         name,
@@ -902,4 +988,92 @@ enum A { $0One(u8, u32) }
     fn test_extract_not_applicable_no_field_named() {
         check_assist_not_applicable(extract_struct_from_enum_variant, r"enum A { $0None {} }");
     }
+
+    #[test]
+    fn test_extract_struct_only_copies_needed_generics() {
+        check_assist(
+            extract_struct_from_enum_variant,
+            r#"
+enum X<'a, 'b, 'x> {
+    $0A { a: &'a &'x mut () },
+    B { b: &'b () },
+    C { c: () },
+}
+"#,
+            r#"
+struct A<'a, 'x>{ a: &'a &'x mut () }
+
+enum X<'a, 'b, 'x> {
+    A(A<'a, 'x>),
+    B { b: &'b () },
+    C { c: () },
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn test_extract_struct_with_liftime_type_const() {
+        check_assist(
+            extract_struct_from_enum_variant,
+            r#"
+enum X<'b, T, V, const C: usize> {
+    $0A { a: T, b: X<'b>, c: [u8; C] },
+    D { d: V },
+}
+"#,
+            r#"
+struct A<'b, T, const C: usize>{ a: T, b: X<'b>, c: [u8; C] }
+
+enum X<'b, T, V, const C: usize> {
+    A(A<'b, T, C>),
+    D { d: V },
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn test_extract_struct_without_generics() {
+        check_assist(
+            extract_struct_from_enum_variant,
+            r#"
+enum X<'a, 'b> {
+    A { a: &'a () },
+    B { b: &'b () },
+    $0C { c: () },
+}
+"#,
+            r#"
+struct C{ c: () }
+
+enum X<'a, 'b> {
+    A { a: &'a () },
+    B { b: &'b () },
+    C(C),
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn test_extract_struct_keeps_trait_bounds() {
+        check_assist(
+            extract_struct_from_enum_variant,
+            r#"
+enum En<T: TraitT, V: TraitV> {
+    $0A { a: T },
+    B { b: V },
+}
+"#,
+            r#"
+struct A<T: TraitT>{ a: T }
+
+enum En<T: TraitT, V: TraitV> {
+    A(A<T>),
+    B { b: V },
+}
+"#,
+        );
+    }
 }