about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/ide-assists/src/handlers/unqualify_method_call.rs52
-rw-r--r--crates/ide-assists/src/tests/generated.rs2
2 files changed, 53 insertions, 1 deletions
diff --git a/crates/ide-assists/src/handlers/unqualify_method_call.rs b/crates/ide-assists/src/handlers/unqualify_method_call.rs
index e9d4e270cdc..0bf1782a489 100644
--- a/crates/ide-assists/src/handlers/unqualify_method_call.rs
+++ b/crates/ide-assists/src/handlers/unqualify_method_call.rs
@@ -1,3 +1,4 @@
+use ide_db::imports::insert_use::ImportScope;
 use syntax::{
     ast::{self, make, AstNode, HasArgList},
     TextRange,
@@ -17,6 +18,8 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
 // ```
 // ->
 // ```
+// use std::ops::Add;
+//
 // fn main() {
 //     1.add(2);
 // }
@@ -38,7 +41,7 @@ pub(crate) fn unqualify_method_call(acc: &mut Assists, ctx: &AssistContext<'_>)
     let first_arg = args_iter.next()?;
     let second_arg = args_iter.next();
 
-    _ = path.qualifier()?;
+    let qualifier = path.qualifier()?;
     let method_name = path.segment()?.name_ref()?;
 
     let res = ctx.sema.resolve_path(&path)?;
@@ -76,10 +79,51 @@ pub(crate) fn unqualify_method_call(acc: &mut Assists, ctx: &AssistContext<'_>)
                 edit.insert(close, ")");
             }
             edit.replace(replace_comma, format!(".{method_name}("));
+            add_import(qualifier, ctx, edit);
         },
     )
 }
 
+fn add_import(
+    qualifier: ast::Path,
+    ctx: &AssistContext<'_>,
+    edit: &mut ide_db::source_change::SourceChangeBuilder,
+) {
+    if let Some(path_segment) = qualifier.segment() {
+        // for `<i32 as std::ops::Add>`
+        let path_type = path_segment.syntax().children().filter_map(ast::PathType::cast).last();
+        let import = match path_type {
+            Some(it) => {
+                if let Some(path) = it.path() {
+                    path
+                } else {
+                    return;
+                }
+            }
+            None => qualifier,
+        };
+
+        // in case for `<_>`
+        if import.coloncolon_token().is_none() {
+            return;
+        }
+
+        let scope = ide_db::imports::insert_use::ImportScope::find_insert_use_container(
+            import.syntax(),
+            &ctx.sema,
+        );
+
+        if let Some(scope) = scope {
+            let scope = match scope {
+                ImportScope::File(it) => ImportScope::File(edit.make_mut(it)),
+                ImportScope::Module(it) => ImportScope::Module(edit.make_mut(it)),
+                ImportScope::Block(it) => ImportScope::Block(edit.make_mut(it)),
+            };
+            ide_db::imports::insert_use::insert_use(&scope, import, &ctx.config.insert_use);
+        }
+    }
+}
+
 fn needs_parens_as_receiver(expr: &ast::Expr) -> bool {
     // Make `(expr).dummy()`
     let dummy_call = make::expr_method_call(
@@ -127,6 +171,8 @@ fn f() { S.f(S); }"#,
 //- minicore: add
 fn f() { <u32 as core::ops::Add>::$0add(2, 2); }"#,
             r#"
+use core::ops::Add;
+
 fn f() { 2.add(2); }"#,
         );
 
@@ -136,6 +182,8 @@ fn f() { 2.add(2); }"#,
 //- minicore: add
 fn f() { core::ops::Add::$0add(2, 2); }"#,
             r#"
+use core::ops::Add;
+
 fn f() { 2.add(2); }"#,
         );
 
@@ -179,6 +227,8 @@ impl core::ops::Deref for S {
 }
 fn f() { core::ops::Deref::$0deref(&S); }"#,
             r#"
+use core::ops::Deref;
+
 struct S;
 impl core::ops::Deref for S {
     type Target = S;
diff --git a/crates/ide-assists/src/tests/generated.rs b/crates/ide-assists/src/tests/generated.rs
index e9d0d373ee7..8523632acfb 100644
--- a/crates/ide-assists/src/tests/generated.rs
+++ b/crates/ide-assists/src/tests/generated.rs
@@ -2948,6 +2948,8 @@ fn main() {
 mod std { pub mod ops { pub trait Add { fn add(self, _: Self) {} } impl Add for i32 {} } }
 "#####,
         r#####"
+use std::ops::Add;
+
 fn main() {
     1.add(2);
 }