about summary refs log tree commit diff
diff options
context:
space:
mode:
authorLukas Wirth <lukastw97@gmail.com>2021-09-25 18:39:43 +0200
committerLukas Wirth <lukastw97@gmail.com>2021-09-25 18:39:43 +0200
commit1ccb21a0ca935bec27e6bda95f3c48c4eabcd514 (patch)
tree4d4dbc854ab82e03bd4074e88042285173dd28d3
parent13da3d93f9be7200dda0635b0822e56b965194c5 (diff)
downloadrust-1ccb21a0ca935bec27e6bda95f3c48c4eabcd514.tar.gz
rust-1ccb21a0ca935bec27e6bda95f3c48c4eabcd514.zip
feat: Implement inline callers assist
-rw-r--r--crates/base_db/src/fixture.rs3
-rw-r--r--crates/ide_assists/src/handlers/inline_call.rs554
-rw-r--r--crates/ide_assists/src/lib.rs1
-rw-r--r--crates/ide_assists/src/tests/generated.rs37
4 files changed, 455 insertions, 140 deletions
diff --git a/crates/base_db/src/fixture.rs b/crates/base_db/src/fixture.rs
index 27be3746b71..cc1de54dfb3 100644
--- a/crates/base_db/src/fixture.rs
+++ b/crates/base_db/src/fixture.rs
@@ -161,7 +161,8 @@ impl ChangeFixture {
         }
 
         if crates.is_empty() {
-            let crate_root = default_crate_root.unwrap();
+            let crate_root = default_crate_root
+                .expect("missing default crate root, specify a main.rs or lib.rs");
             crate_graph.add_crate_root(
                 crate_root,
                 Edition::CURRENT,
diff --git a/crates/ide_assists/src/handlers/inline_call.rs b/crates/ide_assists/src/handlers/inline_call.rs
index 33e029c236f..40231bd9480 100644
--- a/crates/ide_assists/src/handlers/inline_call.rs
+++ b/crates/ide_assists/src/handlers/inline_call.rs
@@ -1,10 +1,13 @@
 use ast::make;
-use hir::{HasSource, PathResolution, TypeInfo};
-use ide_db::{defs::Definition, path_transform::PathTransform, search::FileReference};
+use hir::{db::HirDatabase, HasSource, PathResolution, Semantics, TypeInfo};
+use ide_db::{
+    base_db::FileId, defs::Definition, path_transform::PathTransform, search::FileReference,
+    RootDatabase,
+};
 use itertools::izip;
 use syntax::{
     ast::{self, edit_in_place::Indent, ArgListOwner},
-    ted, AstNode,
+    ted, AstNode, SyntaxNode,
 };
 
 use crate::{
@@ -12,6 +15,132 @@ use crate::{
     AssistId, AssistKind,
 };
 
+// Assist: inline_into_callers
+//
+// Inline a function or method body into all of its callers where possible, creating a `let` statement per parameter
+// unless the parameter can be inlined. The parameter will be inlined either if it the supplied argument is a simple local
+// or if the parameter is only accessed inside the function body once.
+// If all calls can be inlined the function will be removed.
+//
+// ```
+// fn print(_: &str) {}
+// fn foo$0(word: &str) {
+//     if !word.is_empty() {
+//         print(word);
+//     }
+// }
+// fn bar() {
+//     foo("안녕하세요");
+//     foo("여러분");
+// }
+// ```
+// ->
+// ```
+// fn print(_: &str) {}
+//
+// fn bar() {
+//     {
+//         let word = "안녕하세요";
+//         if !word.is_empty() {
+//             print(word);
+//         }
+//     };
+//     {
+//         let word = "여러분";
+//         if !word.is_empty() {
+//             print(word);
+//         }
+//     };
+// }
+// ```
+pub(crate) fn inline_into_callers(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
+    let name = ctx.find_node_at_offset::<ast::Name>()?;
+    let func_syn = name.syntax().parent().and_then(ast::Fn::cast)?;
+    let func_body = func_syn.body()?;
+    let param_list = func_syn.param_list()?;
+    let function = ctx.sema.to_def(&func_syn)?;
+    let params = get_fn_params(ctx.sema.db, function, &param_list)?;
+
+    let usages = Definition::ModuleDef(hir::ModuleDef::Function(function)).usages(&ctx.sema);
+    if !usages.at_least_one() {
+        return None;
+    }
+
+    acc.add(
+        AssistId("inline_into_callers", AssistKind::RefactorInline),
+        "Inline into all callers",
+        name.syntax().text_range(),
+        |builder| {
+            let def_file = ctx.frange.file_id;
+            let usages =
+                Definition::ModuleDef(hir::ModuleDef::Function(function)).usages(&ctx.sema);
+            let mut usages = usages.all();
+            let current_file_usage = usages.references.remove(&def_file);
+
+            let mut can_remove = true;
+            let mut inline_refs = |file_id, refs: Vec<FileReference>| {
+                builder.edit_file(file_id);
+                let count = refs.len();
+                let name_refs = refs.into_iter().filter_map(|file_ref| match file_ref.name {
+                    ast::NameLike::NameRef(name_ref) => Some(name_ref),
+                    _ => None,
+                });
+                let call_infos = name_refs.filter_map(|name_ref| {
+                    let parent = name_ref.syntax().parent()?;
+                    if let Some(call) = ast::MethodCallExpr::cast(parent.clone()) {
+                        let receiver = call.receiver()?;
+                        let mut arguments = vec![receiver];
+                        arguments.extend(call.arg_list()?.args());
+                        Some(CallInfo {
+                            generic_arg_list: call.generic_arg_list(),
+                            node: CallExprNode::MethodCallExpr(call),
+                            arguments,
+                        })
+                    } else if let Some(segment) = ast::PathSegment::cast(parent) {
+                        let path = segment.syntax().parent().and_then(ast::Path::cast)?;
+                        let path = path.syntax().parent().and_then(ast::PathExpr::cast)?;
+                        let call = path.syntax().parent().and_then(ast::CallExpr::cast)?;
+
+                        Some(CallInfo {
+                            arguments: call.arg_list()?.args().collect(),
+                            node: CallExprNode::Call(call),
+                            generic_arg_list: segment.generic_arg_list(),
+                        })
+                    } else {
+                        None
+                    }
+                });
+                let replaced = call_infos
+                    .map(|call_info| {
+                        let replacement =
+                            inline(&ctx.sema, def_file, function, &func_body, &params, &call_info);
+
+                        builder.replace_ast(
+                            match call_info.node {
+                                CallExprNode::Call(it) => ast::Expr::CallExpr(it),
+                                CallExprNode::MethodCallExpr(it) => ast::Expr::MethodCallExpr(it),
+                            },
+                            replacement,
+                        );
+                    })
+                    .count();
+                can_remove &= replaced == count;
+            };
+            for (file_id, refs) in usages.into_iter() {
+                inline_refs(file_id, refs);
+            }
+            if let Some(refs) = current_file_usage {
+                inline_refs(def_file, refs);
+            } else {
+                builder.edit_file(def_file);
+            }
+            if can_remove {
+                builder.delete(func_syn.syntax().text_range());
+            }
+        },
+    )
+}
+
 // Assist: inline_call
 //
 // Inlines a function or method body creating a `let` statement per parameter unless the parameter
@@ -34,8 +163,9 @@ use crate::{
 // }
 // ```
 pub(crate) fn inline_call(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
-    let (label, function, arguments, generic_arg_list, expr) =
+    let (label, function, call_info) =
         if let Some(path_expr) = ctx.find_node_at_offset::<ast::PathExpr>() {
+            // FIXME make applicable only on nameref
             let call = path_expr.syntax().parent().and_then(ast::CallExpr::cast)?;
             let path = path_expr.path()?;
 
@@ -47,9 +177,11 @@ pub(crate) fn inline_call(acc: &mut Assists, ctx: &AssistContext) -> Option<()>
             (
                 format!("Inline `{}`", path),
                 function,
-                call.arg_list()?.args().collect(),
-                path.segment().and_then(|it| it.generic_arg_list()),
-                ast::Expr::CallExpr(call),
+                CallInfo {
+                    arguments: call.arg_list()?.args().collect(),
+                    node: CallExprNode::Call(call),
+                    generic_arg_list: path.segment().and_then(|it| it.generic_arg_list()),
+                },
             )
         } else {
             let name_ref: ast::NameRef = ctx.find_node_at_offset()?;
@@ -61,27 +193,73 @@ pub(crate) fn inline_call(acc: &mut Assists, ctx: &AssistContext) -> Option<()>
             (
                 format!("Inline `{}`", name_ref),
                 function,
-                arguments,
-                call.generic_arg_list(),
-                ast::Expr::MethodCallExpr(call),
+                CallInfo {
+                    generic_arg_list: call.generic_arg_list(),
+                    node: CallExprNode::MethodCallExpr(call),
+                    arguments,
+                },
             )
         };
 
-    inline_(acc, ctx, label, function, arguments, expr, generic_arg_list)
+    let hir::InFile { value: function_source, file_id } = function.source(ctx.db())?;
+    let fn_body = function_source.body()?;
+    let param_list = function_source.param_list()?;
+
+    let params = get_fn_params(ctx.sema.db, function, &param_list)?;
+
+    if call_info.arguments.len() != params.len() {
+        // Can't inline the function because they've passed the wrong number of
+        // arguments to this function
+        cov_mark::hit!(inline_call_incorrect_number_of_arguments);
+        return None;
+    }
+
+    let syntax = call_info.node.syntax().clone();
+    acc.add(
+        AssistId("inline_call", AssistKind::RefactorInline),
+        label,
+        syntax.text_range(),
+        |builder| {
+            let file_id = file_id.original_file(ctx.sema.db);
+            let replacement = inline(&ctx.sema, file_id, function, &fn_body, &params, &call_info);
+
+            builder.replace_ast(
+                match call_info.node {
+                    CallExprNode::Call(it) => ast::Expr::CallExpr(it),
+                    CallExprNode::MethodCallExpr(it) => ast::Expr::MethodCallExpr(it),
+                },
+                replacement,
+            );
+        },
+    )
 }
 
-pub(crate) fn inline_(
-    acc: &mut Assists,
-    ctx: &AssistContext,
-    label: String,
-    function: hir::Function,
-    arg_list: Vec<ast::Expr>,
-    expr: ast::Expr,
+enum CallExprNode {
+    Call(ast::CallExpr),
+    MethodCallExpr(ast::MethodCallExpr),
+}
+
+impl CallExprNode {
+    fn syntax(&self) -> &SyntaxNode {
+        match self {
+            CallExprNode::Call(it) => it.syntax(),
+            CallExprNode::MethodCallExpr(it) => it.syntax(),
+        }
+    }
+}
+
+struct CallInfo {
+    node: CallExprNode,
+    arguments: Vec<ast::Expr>,
     generic_arg_list: Option<ast::GenericArgList>,
-) -> Option<()> {
-    let hir::InFile { value: function_source, file_id } = function.source(ctx.db())?;
-    let param_list = function_source.param_list()?;
-    let mut assoc_fn_params = function.assoc_fn_params(ctx.sema.db).into_iter();
+}
+
+fn get_fn_params(
+    db: &dyn HirDatabase,
+    function: hir::Function,
+    param_list: &ast::ParamList,
+) -> Option<Vec<(ast::Pat, Option<ast::Type>, hir::Param)>> {
+    let mut assoc_fn_params = function.assoc_fn_params(db).into_iter();
 
     let mut params = Vec::new();
     if let Some(self_param) = param_list.self_param() {
@@ -101,131 +279,116 @@ pub(crate) fn inline_(
         params.push((param.pat()?, param.ty(), assoc_fn_params.next()?));
     }
 
-    if arg_list.len() != params.len() {
-        // Can't inline the function because they've passed the wrong number of
-        // arguments to this function
-        cov_mark::hit!(inline_call_incorrect_number_of_arguments);
-        return None;
-    }
-
-    let fn_body = function_source.body()?;
-
-    acc.add(
-        AssistId("inline_call", AssistKind::RefactorInline),
-        label,
-        expr.syntax().text_range(),
-        |builder| {
-            let body = fn_body.clone_for_update();
+    Some(params)
+}
 
-            let file_id = file_id.original_file(ctx.sema.db);
-            let usages_for_locals = |local| {
-                Definition::Local(local)
-                    .usages(&ctx.sema)
-                    .all()
-                    .references
-                    .remove(&file_id)
-                    .unwrap_or_default()
-                    .into_iter()
-            };
-            // Contains the nodes of usages of parameters.
-            // If the inner Vec for a parameter is empty it either means there are no usages or that the parameter
-            // has a pattern that does not allow inlining
-            let param_use_nodes: Vec<Vec<_>> = params
-                .iter()
-                .map(|(pat, _, param)| {
-                    if !matches!(pat, ast::Pat::IdentPat(pat) if pat.is_simple_ident()) {
-                        return Vec::new();
-                    }
-                    usages_for_locals(param.as_local(ctx.sema.db))
-                        .map(|FileReference { name, range, .. }| match name {
-                            ast::NameLike::NameRef(_) => body
-                                .syntax()
-                                .covering_element(range)
-                                .ancestors()
-                                .nth(3)
-                                .and_then(ast::PathExpr::cast),
-                            _ => None,
-                        })
-                        .collect::<Option<Vec<_>>>()
-                        .unwrap_or_default()
-                })
-                .collect();
-
-            // Rewrite `self` to `this`
-            if param_list.self_param().is_some() {
-                let this = || make::name_ref("this").syntax().clone_for_update();
-                usages_for_locals(params[0].2.as_local(ctx.sema.db))
-                    .flat_map(|FileReference { name, range, .. }| match name {
-                        ast::NameLike::NameRef(_) => Some(body.syntax().covering_element(range)),
-                        _ => None,
-                    })
-                    .for_each(|it| {
-                        ted::replace(it, &this());
-                    })
+fn inline(
+    sema: &Semantics<RootDatabase>,
+    function_def_file_id: FileId,
+    function: hir::Function,
+    fn_body: &ast::BlockExpr,
+    params: &[(ast::Pat, Option<ast::Type>, hir::Param)],
+    CallInfo { node, arguments, generic_arg_list }: &CallInfo,
+) -> ast::Expr {
+    let body = fn_body.clone_for_update();
+    let usages_for_locals = |local| {
+        Definition::Local(local)
+            .usages(&sema)
+            .all()
+            .references
+            .remove(&function_def_file_id)
+            .unwrap_or_default()
+            .into_iter()
+    };
+    let param_use_nodes: Vec<Vec<_>> = params
+        .iter()
+        .map(|(pat, _, param)| {
+            if !matches!(pat, ast::Pat::IdentPat(pat) if pat.is_simple_ident()) {
+                return Vec::new();
             }
-
-            // Inline parameter expressions or generate `let` statements depending on whether inlining works or not.
-            for ((pat, param_ty, _), usages, expr) in izip!(params, param_use_nodes, arg_list).rev()
+            usages_for_locals(param.as_local(sema.db))
+                .map(|FileReference { name, range, .. }| match name {
+                    ast::NameLike::NameRef(_) => body
+                        .syntax()
+                        .covering_element(range)
+                        .ancestors()
+                        .nth(3)
+                        .and_then(ast::PathExpr::cast),
+                    _ => None,
+                })
+                .collect::<Option<Vec<_>>>()
+                .unwrap_or_default()
+        })
+        .collect();
+    if function.self_param(sema.db).is_some() {
+        let this = || make::name_ref("this").syntax().clone_for_update();
+        usages_for_locals(params[0].2.as_local(sema.db))
+            .flat_map(|FileReference { name, range, .. }| match name {
+                ast::NameLike::NameRef(_) => Some(body.syntax().covering_element(range)),
+                _ => None,
+            })
+            .for_each(|it| {
+                ted::replace(it, &this());
+            })
+    }
+    // Inline parameter expressions or generate `let` statements depending on whether inlining works or not.
+    for ((pat, param_ty, _), usages, expr) in izip!(params, param_use_nodes, arguments).rev() {
+        let expr_is_name_ref = matches!(&expr,
+            ast::Expr::PathExpr(expr)
+                if expr.path().and_then(|path| path.as_single_name_ref()).is_some()
+        );
+        match &*usages {
+            // inline single use closure arguments
+            [usage]
+                if matches!(expr, ast::Expr::ClosureExpr(_))
+                    && usage.syntax().parent().and_then(ast::Expr::cast).is_some() =>
             {
-                let expr_is_name_ref = matches!(&expr,
-                    ast::Expr::PathExpr(expr)
-                        if expr.path().and_then(|path| path.as_single_name_ref()).is_some()
-                );
-                match &*usages {
-                    // inline single use closure arguments
-                    [usage]
-                        if matches!(expr, ast::Expr::ClosureExpr(_))
-                            && usage.syntax().parent().and_then(ast::Expr::cast).is_some() =>
-                    {
-                        cov_mark::hit!(inline_call_inline_closure);
-                        let expr = make::expr_paren(expr);
-                        ted::replace(usage.syntax(), expr.syntax().clone_for_update());
-                    }
-                    // inline single use literals
-                    [usage] if matches!(expr, ast::Expr::Literal(_)) => {
-                        cov_mark::hit!(inline_call_inline_literal);
-                        ted::replace(usage.syntax(), expr.syntax().clone_for_update());
-                    }
-                    // inline direct local arguments
-                    [_, ..] if expr_is_name_ref => {
-                        cov_mark::hit!(inline_call_inline_locals);
-                        usages.into_iter().for_each(|usage| {
-                            ted::replace(usage.syntax(), &expr.syntax().clone_for_update());
-                        });
-                    }
-                    // cant inline, emit a let statement
-                    _ => {
-                        let ty = ctx
-                            .sema
-                            .type_of_expr(&expr)
-                            .filter(TypeInfo::has_adjustment)
-                            .and(param_ty);
-                        body.push_front(
-                            make::let_stmt(pat, ty, Some(expr)).clone_for_update().into(),
-                        )
-                    }
-                }
+                cov_mark::hit!(inline_call_inline_closure);
+                let expr = make::expr_paren(expr.clone());
+                ted::replace(usage.syntax(), expr.syntax().clone_for_update());
+            }
+            // inline single use literals
+            [usage] if matches!(expr, ast::Expr::Literal(_)) => {
+                cov_mark::hit!(inline_call_inline_literal);
+                ted::replace(usage.syntax(), expr.syntax().clone_for_update());
             }
-            if let Some(generic_arg_list) = generic_arg_list {
-                PathTransform::function_call(
-                    &ctx.sema.scope(expr.syntax()),
-                    &ctx.sema.scope(fn_body.syntax()),
-                    function,
-                    generic_arg_list,
+            // inline direct local arguments
+            [_, ..] if expr_is_name_ref => {
+                cov_mark::hit!(inline_call_inline_locals);
+                usages.into_iter().for_each(|usage| {
+                    ted::replace(usage.syntax(), &expr.syntax().clone_for_update());
+                });
+            }
+            // cant inline, emit a let statement
+            _ => {
+                let ty =
+                    sema.type_of_expr(expr).filter(TypeInfo::has_adjustment).and(param_ty.clone());
+                body.push_front(
+                    make::let_stmt(pat.clone(), ty, Some(expr.clone())).clone_for_update().into(),
                 )
-                .apply(body.syntax());
             }
+        }
+    }
+    if let Some(generic_arg_list) = generic_arg_list.clone() {
+        PathTransform::function_call(
+            &sema.scope(node.syntax()),
+            &sema.scope(fn_body.syntax()),
+            function,
+            generic_arg_list,
+        )
+        .apply(body.syntax());
+    }
 
-            let original_indentation = expr.indent_level();
-            body.reindent_to(original_indentation);
+    let original_indentation = match node {
+        CallExprNode::Call(it) => it.indent_level(),
+        CallExprNode::MethodCallExpr(it) => it.indent_level(),
+    };
+    body.reindent_to(original_indentation);
 
-            let replacement = match body.tail_expr() {
-                Some(expr) if body.statements().next().is_none() => expr,
-                _ => ast::Expr::BlockExpr(body),
-            };
-            builder.replace_ast(expr, replacement);
-        },
-    )
+    match body.tail_expr() {
+        Some(expr) if body.statements().next().is_none() => expr,
+        _ => ast::Expr::BlockExpr(body),
+    }
 }
 
 #[cfg(test)]
@@ -694,4 +857,117 @@ fn main() {
 "#,
         );
     }
+
+    #[test]
+    fn inline_callers() {
+        check_assist(
+            inline_into_callers,
+            r#"
+fn do_the_math$0(b: u32) -> u32 {
+    let foo = 10;
+    foo * b + foo
+}
+fn foo() {
+    do_the_math(0);
+    let bar = 10;
+    do_the_math(bar);
+}
+"#,
+            r#"
+
+fn foo() {
+    {
+        let foo = 10;
+        foo * 0 + foo
+    };
+    let bar = 10;
+    {
+        let foo = 10;
+        foo * bar + foo
+    };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn inline_callers_across_files() {
+        check_assist(
+            inline_into_callers,
+            r#"
+//- /lib.rs
+mod foo;
+fn do_the_math$0(b: u32) -> u32 {
+    let foo = 10;
+    foo * b + foo
+}
+//- /foo.rs
+use super::do_the_math;
+fn foo() {
+    do_the_math(0);
+    let bar = 10;
+    do_the_math(bar);
+}
+"#,
+            r#"
+use super::do_the_math;
+fn foo() {
+    {
+        let foo = 10;
+        foo * 0 + foo
+    };
+    let bar = 10;
+    {
+        let foo = 10;
+        foo * bar + foo
+    };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn inline_callers_across_files_with_def_file() {
+        check_assist(
+            inline_into_callers,
+            r#"
+//- /lib.rs
+mod foo;
+fn do_the_math$0(b: u32) -> u32 {
+    let foo = 10;
+    foo * b + foo
+}
+fn bar(a: u32, b: u32) -> u32 {
+    do_the_math(0);
+}
+//- /foo.rs
+use super::do_the_math;
+fn foo() {
+    do_the_math(0);
+}
+"#,
+            r#"
+//- /lib.rs
+mod foo;
+fn do_the_math(b: u32) -> u32 {
+    let foo = 10;
+    foo * b + foo
+}
+fn bar(a: u32, b: u32) -> u32 {
+    {
+        let foo = 10;
+        foo * 0 + foo
+    };
+}
+//- /foo.rs
+use super::do_the_math;
+fn foo() {
+    {
+        let foo = 10;
+        foo * 0 + foo
+    };
+}
+"#,
+        );
+    }
 }
diff --git a/crates/ide_assists/src/lib.rs b/crates/ide_assists/src/lib.rs
index b2c8cd76941..cad2a6d8530 100644
--- a/crates/ide_assists/src/lib.rs
+++ b/crates/ide_assists/src/lib.rs
@@ -216,6 +216,7 @@ mod handlers {
             generate_is_empty_from_len::generate_is_empty_from_len,
             generate_new::generate_new,
             inline_call::inline_call,
+            inline_call::inline_into_callers,
             inline_local_variable::inline_local_variable,
             introduce_named_generic::introduce_named_generic,
             introduce_named_lifetime::introduce_named_lifetime,
diff --git a/crates/ide_assists/src/tests/generated.rs b/crates/ide_assists/src/tests/generated.rs
index 95a68ca9893..8974b8099d0 100644
--- a/crates/ide_assists/src/tests/generated.rs
+++ b/crates/ide_assists/src/tests/generated.rs
@@ -1052,6 +1052,43 @@ fn foo(name: Option<&str>) {
 }
 
 #[test]
+fn doctest_inline_into_callers() {
+    check_doc_test(
+        "inline_into_callers",
+        r#####"
+fn print(_: &str) {}
+fn foo$0(word: &str) {
+    if !word.is_empty() {
+        print(word);
+    }
+}
+fn bar() {
+    foo("안녕하세요");
+    foo("여러분");
+}
+"#####,
+        r#####"
+fn print(_: &str) {}
+
+fn bar() {
+    {
+        let word = "안녕하세요";
+        if !word.is_empty() {
+            print(word);
+        }
+    };
+    {
+        let word = "여러분";
+        if !word.is_empty() {
+            print(word);
+        }
+    };
+}
+"#####,
+    )
+}
+
+#[test]
 fn doctest_inline_local_variable() {
     check_doc_test(
         "inline_local_variable",