about summary refs log tree commit diff
diff options
context:
space:
mode:
authorroife <roifewu@gmail.com>2024-01-04 13:31:00 +0800
committerroife <roifewu@gmail.com>2024-01-05 13:59:34 +0800
commitbf0c4acef4a352408e07b39eb0d5292c31640b6a (patch)
treec6f5c866641e0e045c185dcb52ed4ee0e4879081
parent426d2842c1f0e5cc5e34bb37c7ac3ee0945f9746 (diff)
downloadrust-bf0c4acef4a352408e07b39eb0d5292c31640b6a.tar.gz
rust-bf0c4acef4a352408e07b39eb0d5292c31640b6a.zip
internal: refactor generate_delegate_trait and add comments
-rw-r--r--crates/ide-assists/src/handlers/generate_delegate_trait.rs326
-rw-r--r--crates/syntax/src/ast/edit_in_place.rs3
-rw-r--r--crates/syntax/src/ast/make.rs3
3 files changed, 142 insertions, 190 deletions
diff --git a/crates/ide-assists/src/handlers/generate_delegate_trait.rs b/crates/ide-assists/src/handlers/generate_delegate_trait.rs
index 0d34502add9..af2b60d22a5 100644
--- a/crates/ide-assists/src/handlers/generate_delegate_trait.rs
+++ b/crates/ide-assists/src/handlers/generate_delegate_trait.rs
@@ -17,7 +17,7 @@ use syntax::{
         self,
         edit::{self, AstNodeEdit},
         make, AssocItem, GenericArgList, GenericParamList, HasGenericParams, HasName,
-        HasTypeBounds, HasVisibility as astHasVisibility, Path,
+        HasTypeBounds, HasVisibility as astHasVisibility, Path, WherePred,
     },
     ted::{self, Position},
     AstNode, NodeOrToken, SmolStr, SyntaxKind,
@@ -217,9 +217,9 @@ impl Struct {
             };
 
             acc.add_group(
-                &GroupLabel(format!("Generate delegate impls for field `{}`", field.name)),
+                &GroupLabel(format!("Generate delegate trait impls for field `{}`", field.name)),
                 AssistId("generate_delegate_trait", ide_db::assists::AssistKind::Generate),
-                format!("Generate delegate impl `{}` for `{}`", signature, field.name),
+                format!("Generate delegate trait impl `{}` for `{}`", signature, field.name),
                 field.range,
                 |builder| {
                     builder.insert(
@@ -243,12 +243,12 @@ fn generate_impl(
     let db = ctx.db();
     let ast_strukt = &strukt.strukt;
     let strukt_ty = make::ty_path(make::ext::ident_path(&strukt.name.to_string()));
+    let strukt_params = ast_strukt.generic_param_list();
 
     match delegee {
         Delegee::Bound(delegee) => {
             let bound_def = ctx.sema.source(delegee.to_owned())?.value;
             let bound_params = bound_def.generic_param_list();
-            let strukt_params = ast_strukt.generic_param_list();
 
             delegate = make::impl_trait(
                 delegee.is_unsafe(db),
@@ -295,63 +295,77 @@ fn generate_impl(
         }
         Delegee::Impls(trait_, old_impl) => {
             let old_impl = ctx.sema.source(old_impl.to_owned())?.value;
+            let old_impl_params = old_impl.generic_param_list();
+
+            // 1) Resolve conflicts between generic parameters in old_impl and
+            // those in strukt.
+            //
+            // These generics parameters will also be used in `field_ty` and
+            // `where_clauses`, so we should substitude arguments in them as well.
+            let strukt_params = resolve_name_conflicts(strukt_params, &old_impl_params);
+            let (field_ty, ty_where_clause) = match &strukt_params {
+                Some(strukt_params) => {
+                    let args = strukt_params.to_generic_args();
+                    let field_ty = rename_strukt_args(ctx, ast_strukt, field_ty, &args)?;
+                    let where_clause = ast_strukt
+                        .where_clause()
+                        .and_then(|wc| Some(rename_strukt_args(ctx, ast_strukt, &wc, &args)?));
+                    (field_ty, where_clause)
+                }
+                None => (field_ty.clone_for_update(), None),
+            };
+
+            // 2) Handle instantiated generics in `field_ty`.
+
+            // 2.1) Some generics used in `self_ty` may be instantiated, so they
+            // are no longer generics, we should remove and instantiate those
+            // generics in advance.
 
             // `old_trait_args` contains names of generic args for trait in `old_impl`
-            let old_trait_args = old_impl
+            let old_impl_trait_args = old_impl
                 .trait_()?
                 .generic_arg_list()
                 .map(|l| l.generic_args().map(|arg| arg.to_string()))
                 .map_or_else(|| FxHashSet::default(), |it| it.collect());
 
-            let old_impl_params = old_impl.generic_param_list();
-
-            // Resolve conflicts with generic parameters in strukt.
-            // These generics parameters will also be used in `field_ty` and `where_clauses`,
-            // so we should substitude arguments in them as well.
-            let (renamed_strukt_params, field_ty, ty_where_clause) = if let Some(strukt_params) =
-                resolve_conflicts_for_strukt(ast_strukt, old_impl_params.as_ref())
-            {
-                let strukt_args = strukt_params.to_generic_args();
-                let field_ty =
-                    subst_name_in_strukt(ctx, ast_strukt, field_ty, strukt_args.clone())?;
-                let wc = ast_strukt
-                    .where_clause()
-                    .and_then(|wc| Some(subst_name_in_strukt(ctx, ast_strukt, &wc, strukt_args)?));
-                (Some(strukt_params), field_ty, wc)
-            } else {
-                (None, field_ty.clone_for_update(), None)
-            };
-
-            // Some generics used in `field_ty` may be instantiated, so they are no longer
-            // `generics`. We should remove them from generics params, and use the rest params.
-            let trait_gen_params =
-                remove_instantiated_params(&old_impl.self_ty()?, old_impl_params, &old_trait_args);
+            let trait_gen_params = remove_instantiated_params(
+                &old_impl.self_ty()?,
+                old_impl_params.clone(),
+                &old_impl_trait_args,
+            );
 
-            // Generate generic args that applied to current impl, this step will also remove unused params
-            let args_for_impl =
-                get_args_for_impl(&old_impl, &field_ty, &trait_gen_params, &old_trait_args);
+            // 2.2) Generate generic args applied on impl.
+            let transform_args = generate_args_for_impl(
+                old_impl_params,
+                &old_impl.self_ty()?,
+                &field_ty,
+                &trait_gen_params,
+                &old_impl_trait_args,
+            );
 
+            // 2.3) Instantiate generics with `transform_impl`, this step also
+            // remove unused params.
             let mut trait_gen_args = old_impl.trait_()?.generic_arg_list();
-            if let Some(arg_list) = &mut trait_gen_args {
-                *arg_list = arg_list.clone_for_update();
-                transform_impl(ctx, ast_strukt, &old_impl, &args_for_impl, &arg_list.syntax())?;
+            if let Some(trait_args) = &mut trait_gen_args {
+                *trait_args = trait_args.clone_for_update();
+                transform_impl(ctx, ast_strukt, &old_impl, &transform_args, &trait_args.syntax())?;
             }
 
-            let mut type_gen_args =
-                renamed_strukt_params.clone().map(|params| params.to_generic_args());
+            let mut type_gen_args = strukt_params.clone().map(|params| params.to_generic_args());
             if let Some(type_args) = &mut type_gen_args {
                 *type_args = type_args.clone_for_update();
-                transform_impl(ctx, ast_strukt, &old_impl, &args_for_impl, &type_args.syntax())?;
+                transform_impl(ctx, ast_strukt, &old_impl, &transform_args, &type_args.syntax())?;
             }
 
             let path_type = make::ty(&trait_.name(db).to_smol_str()).clone_for_update();
-            transform_impl(ctx, ast_strukt, &old_impl, &args_for_impl, &path_type.syntax())?;
+            transform_impl(ctx, ast_strukt, &old_impl, &transform_args, &path_type.syntax())?;
 
+            // 3) Generate delegate trait impl
             delegate = make::impl_trait(
                 trait_.is_unsafe(db),
                 trait_gen_params,
                 trait_gen_args,
-                renamed_strukt_params,
+                strukt_params,
                 type_gen_args,
                 trait_.is_auto(db),
                 path_type,
@@ -369,24 +383,23 @@ fn generate_impl(
                 delegate.trait_()?.to_string()
             ));
 
+            // 4) Transform associated items in delegte trait impl
             let delegate_assoc_items = delegate.get_or_create_assoc_item_list();
             for item in old_impl
                 .get_or_create_assoc_item_list()
                 .assoc_items()
                 .filter(|item| matches!(item, AssocItem::MacroCall(_)).not())
             {
-                let assoc = process_assoc_item(
-                    transform_assoc_item(ctx, ast_strukt, &old_impl, &args_for_impl, item)?,
-                    qualified_path_type.clone(),
-                    &field_name,
-                )?;
+                let item = item.clone_for_update();
+                transform_impl(ctx, ast_strukt, &old_impl, &transform_args, item.syntax())?;
 
+                let assoc = process_assoc_item(item, qualified_path_type.clone(), &field_name)?;
                 delegate_assoc_items.add_item(assoc);
             }
 
-            // Remove unused where clauses
+            // 5) Remove useless where clauses
             if let Some(wc) = delegate.where_clause() {
-                remove_useless_where_clauses(&delegate, wc)?;
+                remove_useless_where_clauses(&delegate.trait_()?, &delegate.self_ty()?, wc);
             }
         }
     }
@@ -394,32 +407,6 @@ fn generate_impl(
     Some(delegate)
 }
 
-fn transform_assoc_item(
-    ctx: &AssistContext<'_>,
-    strukt: &ast::Struct,
-    old_impl: &ast::Impl,
-    args: &Option<GenericArgList>,
-    item: AssocItem,
-) -> Option<AssocItem> {
-    let source_scope = ctx.sema.scope(&item.syntax()).unwrap();
-    let target_scope = ctx.sema.scope(&strukt.syntax())?;
-    let hir_old_impl = ctx.sema.to_impl_def(old_impl)?;
-    let item = item.clone_for_update();
-    let transform = args.as_ref().map_or_else(
-        || PathTransform::generic_transformation(&target_scope, &source_scope),
-        |args| {
-            PathTransform::impl_transformation(
-                &target_scope,
-                &source_scope,
-                hir_old_impl,
-                args.clone(),
-            )
-        },
-    );
-    transform.apply(&item.syntax());
-    Some(item)
-}
-
 fn transform_impl(
     ctx: &AssistContext<'_>,
     strukt: &ast::Struct,
@@ -463,11 +450,11 @@ fn remove_instantiated_params(
                     .segments()
                     .filter_map(|seg| seg.generic_arg_list())
                     .flat_map(|it| it.generic_args())
-                    // However, if the param is also used in the trait arguments, it shouldn't be removed.
+                    // However, if the param is also used in the trait arguments,
+                    // it shouldn't be removed now, which will be instantiated in
+                    // later `path_transform`
                     .filter(|arg| !old_trait_args.contains(&arg.to_string()))
-                    .for_each(|arg| {
-                        new_gpl.remove_generic_arg(&arg);
-                    });
+                    .for_each(|arg| new_gpl.remove_generic_arg(&arg));
                 (new_gpl.generic_params().count() > 0).then_some(new_gpl)
             })
         }
@@ -475,49 +462,37 @@ fn remove_instantiated_params(
     }
 }
 
-fn remove_useless_where_clauses(delegate: &ast::Impl, wc: ast::WhereClause) -> Option<()> {
-    let trait_args =
-        delegate.trait_()?.generic_arg_list().map(|trait_args| trait_args.generic_args());
-    let strukt_args =
-        delegate.self_ty()?.generic_arg_list().map(|strukt_args| strukt_args.generic_args());
-    let used_generic_names = match (trait_args, strukt_args) {
-        (None, None) => None,
-        (None, Some(y)) => Some(y.map(|arg| arg.to_string()).collect::<FxHashSet<_>>()),
-        (Some(x), None) => Some(x.map(|arg| arg.to_string()).collect::<FxHashSet<_>>()),
-        (Some(x), Some(y)) => Some(x.chain(y).map(|arg| arg.to_string()).collect::<FxHashSet<_>>()),
+fn remove_useless_where_clauses(trait_ty: &ast::Type, self_ty: &ast::Type, wc: ast::WhereClause) {
+    let live_generics = [trait_ty, self_ty]
+        .into_iter()
+        .flat_map(|ty| ty.generic_arg_list())
+        .flat_map(|gal| gal.generic_args())
+        .map(|x| x.to_string())
+        .collect::<FxHashSet<_>>();
+
+    // Keep where-clauses that have generics after substitution, and remove the
+    // rest.
+    let has_live_generics = |pred: &WherePred| {
+        pred.syntax()
+            .descendants_with_tokens()
+            .filter_map(|e| e.into_token())
+            .any(|e| e.kind() == SyntaxKind::IDENT && live_generics.contains(&e.to_string()))
+            .not()
     };
-
-    // Keep clauses that have generic clauses after substitution, and remove the rest
-    if let Some(used_generic_names) = used_generic_names {
-        wc.predicates()
-            .filter(|pred| {
-                pred.syntax()
-                    .descendants_with_tokens()
-                    .filter_map(|e| e.into_token())
-                    .find(|e| {
-                        e.kind() == SyntaxKind::IDENT && used_generic_names.contains(&e.to_string())
-                    })
-                    .is_none()
-            })
-            .for_each(|pred| {
-                wc.remove_predicate(pred);
-            });
-    } else {
-        wc.predicates().for_each(|pred| wc.remove_predicate(pred));
-    }
+    wc.predicates().filter(has_live_generics).for_each(|pred| wc.remove_predicate(pred));
 
     if wc.predicates().count() == 0 {
         // Remove useless whitespaces
-        wc.syntax()
-            .siblings_with_tokens(syntax::Direction::Prev)
-            .skip(1)
-            .take_while(|node_or_tok| node_or_tok.kind() == SyntaxKind::WHITESPACE)
-            .for_each(|ws| ted::remove(ws));
-        wc.syntax()
-            .siblings_with_tokens(syntax::Direction::Next)
-            .skip(1)
-            .take_while(|node_or_tok| node_or_tok.kind() == SyntaxKind::WHITESPACE)
+        [syntax::Direction::Prev, syntax::Direction::Next]
+            .into_iter()
+            .flat_map(|dir| {
+                wc.syntax()
+                    .siblings_with_tokens(dir)
+                    .skip(1)
+                    .take_while(|node_or_tok| node_or_tok.kind() == SyntaxKind::WHITESPACE)
+            })
             .for_each(|ws| ted::remove(ws));
+
         ted::insert(
             ted::Position::after(wc.syntax()),
             NodeOrToken::Token(make::token(SyntaxKind::WHITESPACE)),
@@ -525,84 +500,63 @@ fn remove_useless_where_clauses(delegate: &ast::Impl, wc: ast::WhereClause) -> O
         // Remove where clause
         ted::remove(wc.syntax());
     }
-
-    Some(())
 }
 
-fn get_args_for_impl(
-    old_impl: &ast::Impl,
+// Generate generic args that should be apply to current impl.
+//
+// For exmaple, say we have implementation `impl<A, B, C> Trait for B<A>`,
+// and `b: B<T>` in struct `S<T>`. Then the `A` should be instantiated to `T`.
+// While the last two generic args `B` and `C` doesn't change, it remains
+// `<B, C>`. So we apply `<T, B, C>` as generic arguments to impl.
+fn generate_args_for_impl(
+    old_impl_gpl: Option<GenericParamList>,
+    self_ty: &ast::Type,
     field_ty: &ast::Type,
     trait_params: &Option<GenericParamList>,
     old_trait_args: &FxHashSet<String>,
 ) -> Option<ast::GenericArgList> {
-    // Generate generic args that should be apply to current impl
-    //
-    // For exmaple, if we have `impl<A, B, C> Trait for B<A>`, and `b: B<T>` in `S<T>`,
-    // then the generic `A` should be renamed to `T`. While the last two generic args
-    // doesn't change, it renames <B, C>. So we apply `<T, B C>` as generic arguments
-    // to impl.
-    let old_impl_params = old_impl.generic_param_list();
-    let self_ty = old_impl.self_ty();
-
-    if let (Some(old_impl_gpl), Some(self_ty)) = (old_impl_params, self_ty) {
-        // Make pair of the arguments of `field_ty` and `old_strukt_args` to
-        // get the list for substitution
-        let mut arg_substs = FxHashMap::default();
-
-        match field_ty {
-            field_ty @ ast::Type::PathType(_) => {
-                let field_args = field_ty.generic_arg_list();
-                if let (Some(field_args), Some(old_impl_args)) =
-                    (field_args, self_ty.generic_arg_list())
-                {
-                    field_args.generic_args().zip(old_impl_args.generic_args()).for_each(
-                        |(field_arg, impl_arg)| {
-                            arg_substs.entry(impl_arg.to_string()).or_insert(field_arg);
-                        },
-                    )
-                }
+    let Some(old_impl_args) = old_impl_gpl.map(|gpl| gpl.to_generic_args().generic_args()) else {
+        return None;
+    };
+    // Create pairs of the args of `self_ty` and corresponding `field_ty` to
+    // form the substitution list
+    let mut arg_substs = FxHashMap::default();
+
+    match field_ty {
+        field_ty @ ast::Type::PathType(_) => {
+            let field_args = field_ty.generic_arg_list().map(|gal| gal.generic_args());
+            let self_ty_args = self_ty.generic_arg_list().map(|gal| gal.generic_args());
+            if let (Some(field_args), Some(self_ty_args)) = (field_args, self_ty_args) {
+                self_ty_args.zip(field_args).for_each(|(self_ty_arg, field_arg)| {
+                    arg_substs.entry(self_ty_arg.to_string()).or_insert(field_arg);
+                })
             }
-            _ => {}
         }
-
-        let args = old_impl_gpl
-            .to_generic_args()
-            .generic_args()
-            .map(|old_arg| {
-                arg_substs.get(&old_arg.to_string()).map_or_else(
-                    || old_arg.clone(),
-                    |replace_with| {
-                        // The old_arg will be replaced, so it becomes redundant
-                        let old_arg_name = old_arg.to_string();
-                        if old_trait_args.contains(&old_arg_name) {
-                            // However, we should check type bounds and where clauses on old_arg,
-                            // if it has type bound, we should keep the type bound.
-                            // match trait_params.and_then(|params| params.remove_generic_arg(&old_arg)) {
-                            //     Some(ast::GenericParam::TypeParam(ty)) => {
-                            //         ty.type_bound_list().and_then(|bounds| )
-                            //     }
-                            //     _ => {}
-                            // }
-                            if let Some(params) = trait_params {
-                                params.remove_generic_arg(&old_arg);
-                            }
-                        }
-                        replace_with.clone()
-                    },
-                )
-            })
-            .collect_vec();
-        args.is_empty().not().then(|| make::generic_arg_list(args.into_iter()))
-    } else {
-        None
+        _ => {}
     }
+
+    let args = old_impl_args
+        .map(|old_arg| {
+            arg_substs.get(&old_arg.to_string()).map_or_else(
+                || old_arg.clone(),
+                |replace_with| {
+                    // The old_arg will be replaced, so it becomes redundant
+                    if trait_params.is_some() && old_trait_args.contains(&old_arg.to_string()) {
+                        trait_params.as_ref().unwrap().remove_generic_arg(&old_arg)
+                    }
+                    replace_with.clone()
+                },
+            )
+        })
+        .collect_vec();
+    args.is_empty().not().then(|| make::generic_arg_list(args.into_iter()))
 }
 
-fn subst_name_in_strukt<N>(
+fn rename_strukt_args<N>(
     ctx: &AssistContext<'_>,
     strukt: &ast::Struct,
     item: &N,
-    args: GenericArgList,
+    args: &GenericArgList,
 ) -> Option<N>
 where
     N: ast::AstNode,
@@ -611,9 +565,11 @@ where
     let hir_adt = hir::Adt::from(hir_strukt);
 
     let item = item.clone_for_update();
-    let item_scope = ctx.sema.scope(item.syntax())?;
-    let transform = PathTransform::adt_transformation(&item_scope, &item_scope, hir_adt, args);
+    let scope = ctx.sema.scope(item.syntax())?;
+
+    let transform = PathTransform::adt_transformation(&scope, &scope, hir_adt, args.clone());
     transform.apply(&item.syntax());
+
     Some(item)
 }
 
@@ -627,16 +583,16 @@ fn has_self_type(trait_: hir::Trait, ctx: &AssistContext<'_>) -> Option<()> {
         .map(|_| ())
 }
 
-fn resolve_conflicts_for_strukt(
-    strukt: &ast::Struct,
-    old_impl_params: Option<&ast::GenericParamList>,
+fn resolve_name_conflicts(
+    strukt_params: Option<ast::GenericParamList>,
+    old_impl_params: &Option<ast::GenericParamList>,
 ) -> Option<ast::GenericParamList> {
-    match (strukt.generic_param_list(), old_impl_params) {
+    match (strukt_params, old_impl_params) {
         (Some(old_strukt_params), Some(old_impl_params)) => {
             let params = make::generic_param_list(std::iter::empty()).clone_for_update();
 
             for old_strukt_param in old_strukt_params.generic_params() {
-                // Get old name from `strukt``
+                // Get old name from `strukt`
                 let mut name = SmolStr::from(match &old_strukt_param {
                     ast::GenericParam::ConstParam(c) => c.name()?.to_string(),
                     ast::GenericParam::LifetimeParam(l) => {
diff --git a/crates/syntax/src/ast/edit_in_place.rs b/crates/syntax/src/ast/edit_in_place.rs
index 4c2878f49f0..aa56f10d609 100644
--- a/crates/syntax/src/ast/edit_in_place.rs
+++ b/crates/syntax/src/ast/edit_in_place.rs
@@ -293,13 +293,12 @@ impl ast::GenericParamList {
     }
 
     /// Removes the corresponding generic arg
-    pub fn remove_generic_arg(&self, generic_arg: &ast::GenericArg) -> Option<GenericParam> {
+    pub fn remove_generic_arg(&self, generic_arg: &ast::GenericArg) {
         let param_to_remove = self.find_generic_arg(generic_arg);
 
         if let Some(param) = &param_to_remove {
             self.remove_generic_param(param.clone());
         }
-        param_to_remove
     }
 
     /// Constructs a matching [`ast::GenericArgList`]
diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs
index 2abbfc81f67..b1dd1fe8c82 100644
--- a/crates/syntax/src/ast/make.rs
+++ b/crates/syntax/src/ast/make.rs
@@ -263,9 +263,6 @@ pub fn impl_(
     ast_from_text(&format!("impl{gen_params} {path_type}{tr_gen_args}{where_clause}{{{}}}", body))
 }
 
-// FIXME : We must make *_gen_args' type ast::GenericArgList but in order to do so we must implement in `edit_in_place.rs`
-// `add_generic_arg()` just like `add_generic_param()`
-// is implemented for `ast::GenericParamList`
 pub fn impl_trait(
     is_unsafe: bool,
     trait_gen_params: Option<ast::GenericParamList>,