about summary refs log tree commit diff
diff options
context:
space:
mode:
authorLukas Wirth <lukastw97@gmail.com>2022-01-31 11:56:42 +0100
committerLukas Wirth <lukastw97@gmail.com>2022-01-31 11:56:42 +0100
commit6194092086045cb430e3391dc0bca0ac2368bdd2 (patch)
tree35183f8618ef20e198cec24a7bbc0139c9a8104a
parentddf7b70a0f8f7fc1e49d2bf0365752be3b4aab8b (diff)
downloadrust-6194092086045cb430e3391dc0bca0ac2368bdd2.tar.gz
rust-6194092086045cb430e3391dc0bca0ac2368bdd2.zip
Complete local fn and closure params from surrounding locals scope
-rw-r--r--crates/hir/src/semantics.rs4
-rw-r--r--crates/ide_completion/src/completions/fn_param.rs141
-rw-r--r--crates/ide_completion/src/context.rs90
-rw-r--r--crates/ide_completion/src/render/pattern.rs2
-rw-r--r--crates/ide_completion/src/tests/fn_param.rs50
5 files changed, 199 insertions, 88 deletions
diff --git a/crates/hir/src/semantics.rs b/crates/hir/src/semantics.rs
index eefc12570d7..a210574a06e 100644
--- a/crates/hir/src/semantics.rs
+++ b/crates/hir/src/semantics.rs
@@ -389,8 +389,8 @@ impl<'db, DB: HirDatabase> Semantics<'db, DB> {
         self.imp.scope(node)
     }
 
-    pub fn scope_at_offset(&self, token: &SyntaxToken, offset: TextSize) -> SemanticsScope<'db> {
-        self.imp.scope_at_offset(&token.parent().unwrap(), offset)
+    pub fn scope_at_offset(&self, node: &SyntaxNode, offset: TextSize) -> SemanticsScope<'db> {
+        self.imp.scope_at_offset(&node, offset)
     }
 
     pub fn scope_for_def(&self, def: Trait) -> SemanticsScope<'db> {
diff --git a/crates/ide_completion/src/completions/fn_param.rs b/crates/ide_completion/src/completions/fn_param.rs
index a55bab67e11..961681c20cb 100644
--- a/crates/ide_completion/src/completions/fn_param.rs
+++ b/crates/ide_completion/src/completions/fn_param.rs
@@ -1,9 +1,11 @@
 //! See [`complete_fn_param`].
 
+use hir::HirDisplay;
 use rustc_hash::FxHashMap;
 use syntax::{
+    algo,
     ast::{self, HasModuleItem},
-    match_ast, AstNode, SyntaxKind,
+    match_ast, AstNode, Direction, SyntaxKind,
 };
 
 use crate::{
@@ -15,14 +17,48 @@ use crate::{
 /// functions in a file have a `spam: &mut Spam` parameter, a completion with
 /// `spam: &mut Spam` insert text/label and `spam` lookup string will be
 /// suggested.
+///
+/// Also complete parameters for closure or local functions from the surrounding defined locals.
 pub(crate) fn complete_fn_param(acc: &mut Completions, ctx: &CompletionContext) -> Option<()> {
-    let param_of_fn =
-        matches!(ctx.pattern_ctx, Some(PatternContext { is_param: Some(ParamKind::Function), .. }));
+    let (param_list, _, param_kind) = match &ctx.pattern_ctx {
+        Some(PatternContext { param_ctx: Some(kind), .. }) => kind,
+        _ => return None,
+    };
+
+    let comma_wrapper = comma_wrapper(ctx);
+    let mut add_new_item_to_acc = |label: &str, lookup: String| {
+        let mk_item = |label: &str| {
+            CompletionItem::new(CompletionItemKind::Binding, ctx.source_range(), label)
+        };
+        let mut item = match &comma_wrapper {
+            Some(fmt) => mk_item(&fmt(&label)),
+            None => mk_item(label),
+        };
+        item.lookup_by(lookup);
+        item.add_to(acc)
+    };
 
-    if !param_of_fn {
-        return None;
+    match param_kind {
+        ParamKind::Function(function) => {
+            fill_fn_params(ctx, function, &param_list, add_new_item_to_acc);
+        }
+        ParamKind::Closure(closure) => {
+            let stmt_list = closure.syntax().ancestors().find_map(ast::StmtList::cast)?;
+            params_from_stmt_list_scope(ctx, stmt_list, |name, ty| {
+                add_new_item_to_acc(&format!("{name}: {ty}"), name.to_string());
+            });
+        }
     }
 
+    Some(())
+}
+
+fn fill_fn_params(
+    ctx: &CompletionContext,
+    function: &ast::Fn,
+    param_list: &ast::ParamList,
+    mut add_new_item_to_acc: impl FnMut(&str, String),
+) {
     let mut file_params = FxHashMap::default();
 
     let mut extract_params = |f: ast::Fn| {
@@ -56,23 +92,46 @@ pub(crate) fn complete_fn_param(acc: &mut Completions, ctx: &CompletionContext)
         };
     }
 
-    let function = ctx.token.ancestors().find_map(ast::Fn::cast)?;
-    let param_list = function.param_list()?;
+    if let Some(stmt_list) = function.syntax().parent().and_then(ast::StmtList::cast) {
+        params_from_stmt_list_scope(ctx, stmt_list, |name, ty| {
+            file_params.entry(format!("{name}: {ty}")).or_insert(name.to_string());
+        });
+    }
 
     remove_duplicated(&mut file_params, param_list.params());
-
     let self_completion_items = ["self", "&self", "mut self", "&mut self"];
     if should_add_self_completions(ctx, param_list) {
-        self_completion_items.into_iter().for_each(|self_item| {
-            add_new_item_to_acc(ctx, acc, self_item.to_string(), self_item.to_string())
-        });
+        self_completion_items
+            .into_iter()
+            .for_each(|self_item| add_new_item_to_acc(self_item, self_item.to_string()));
     }
 
-    file_params.into_iter().try_for_each(|(whole_param, binding)| {
-        Some(add_new_item_to_acc(ctx, acc, surround_with_commas(ctx, whole_param), binding))
-    })?;
+    file_params
+        .into_iter()
+        .for_each(|(whole_param, binding)| add_new_item_to_acc(&whole_param, binding));
+}
 
-    Some(())
+fn params_from_stmt_list_scope(
+    ctx: &CompletionContext,
+    stmt_list: ast::StmtList,
+    mut cb: impl FnMut(hir::Name, String),
+) {
+    let syntax_node = match stmt_list.syntax().last_child() {
+        Some(it) => it,
+        None => return,
+    };
+    let scope = ctx.sema.scope_at_offset(stmt_list.syntax(), syntax_node.text_range().end());
+    let module = match scope.module() {
+        Some(it) => it,
+        None => return,
+    };
+    scope.process_all_names(&mut |name, def| {
+        if let hir::ScopeDef::Local(local) = def {
+            if let Ok(ty) = local.ty(ctx.db).display_source_code(ctx.db, module.into()) {
+                cb(name, ty);
+            }
+        }
+    });
 }
 
 fn remove_duplicated(
@@ -96,52 +155,32 @@ fn remove_duplicated(
     })
 }
 
-fn should_add_self_completions(ctx: &CompletionContext, param_list: ast::ParamList) -> bool {
+fn should_add_self_completions(ctx: &CompletionContext, param_list: &ast::ParamList) -> bool {
     let inside_impl = ctx.impl_def.is_some();
     let no_params = param_list.params().next().is_none() && param_list.self_param().is_none();
 
     inside_impl && no_params
 }
 
-fn surround_with_commas(ctx: &CompletionContext, param: String) -> String {
-    match fallible_surround_with_commas(ctx, &param) {
-        Some(surrounded) => surrounded,
-        // fallback to the original parameter
-        None => param,
-    }
-}
-
-fn fallible_surround_with_commas(ctx: &CompletionContext, param: &str) -> Option<String> {
-    let next_token = {
+fn comma_wrapper(ctx: &CompletionContext) -> Option<impl Fn(&str) -> String> {
+    let next_token_kind = {
         let t = ctx.token.next_token()?;
-        match t.kind() {
-            SyntaxKind::WHITESPACE => t.next_token()?,
-            _ => t,
-        }
+        let t = algo::skip_whitespace_token(t, Direction::Next)?;
+        t.kind()
     };
-
-    let trailing_comma_missing = matches!(next_token.kind(), SyntaxKind::IDENT);
-    let trailing = if trailing_comma_missing { "," } else { "" };
-
-    let previous_token = if matches!(ctx.token.kind(), SyntaxKind::IDENT | SyntaxKind::WHITESPACE) {
-        ctx.previous_token.as_ref()?
-    } else {
-        &ctx.token
+    let prev_token_kind = {
+        let t = ctx.previous_token.clone()?;
+        let t = algo::skip_whitespace_token(t, Direction::Prev)?;
+        t.kind()
     };
 
-    let needs_leading = !matches!(previous_token.kind(), SyntaxKind::L_PAREN | SyntaxKind::COMMA);
-    let leading = if needs_leading { ", " } else { "" };
+    let has_trailing_comma =
+        matches!(next_token_kind, SyntaxKind::COMMA | SyntaxKind::R_PAREN | SyntaxKind::PIPE);
+    let trailing = if has_trailing_comma { "" } else { "," };
 
-    Some(format!("{}{}{}", leading, param, trailing))
-}
+    let has_leading_comma =
+        matches!(prev_token_kind, SyntaxKind::COMMA | SyntaxKind::L_PAREN | SyntaxKind::PIPE);
+    let leading = if has_leading_comma { "" } else { ", " };
 
-fn add_new_item_to_acc(
-    ctx: &CompletionContext,
-    acc: &mut Completions,
-    label: String,
-    lookup: String,
-) {
-    let mut item = CompletionItem::new(CompletionItemKind::Binding, ctx.source_range(), label);
-    item.lookup_by(lookup);
-    item.add_to(acc)
+    Some(move |param: &_| format!("{}{}{}", leading, param, trailing))
 }
diff --git a/crates/ide_completion/src/context.rs b/crates/ide_completion/src/context.rs
index 9eec4fd0c96..59c16c08d6c 100644
--- a/crates/ide_completion/src/context.rs
+++ b/crates/ide_completion/src/context.rs
@@ -27,6 +27,8 @@ use crate::{
     CompletionConfig,
 };
 
+const COMPLETION_MARKER: &str = "intellijRulezz";
+
 #[derive(Copy, Clone, Debug, PartialEq, Eq)]
 pub(crate) enum PatternRefutability {
     Refutable,
@@ -68,7 +70,7 @@ pub(crate) struct PathCompletionContext {
 #[derive(Debug)]
 pub(super) struct PatternContext {
     pub(super) refutability: PatternRefutability,
-    pub(super) is_param: Option<ParamKind>,
+    pub(super) param_ctx: Option<(ast::ParamList, ast::Param, ParamKind)>,
     pub(super) has_type_ascription: bool,
 }
 
@@ -80,10 +82,10 @@ pub(super) enum LifetimeContext {
     LabelDef,
 }
 
-#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, PartialEq, Eq)]
 pub(crate) enum ParamKind {
-    Function,
-    Closure,
+    Function(ast::Fn),
+    Closure(ast::ClosureExpr),
 }
 
 /// `CompletionContext` is created early during completion to figure out, where
@@ -382,7 +384,7 @@ impl<'a> CompletionContext<'a> {
         // actual completion.
         let file_with_fake_ident = {
             let parse = db.parse(file_id);
-            let edit = Indel::insert(offset, "intellijRulezz".to_string());
+            let edit = Indel::insert(offset, COMPLETION_MARKER.to_string());
             parse.reparse(&edit).tree()
         };
         let fake_ident_token =
@@ -390,7 +392,7 @@ impl<'a> CompletionContext<'a> {
 
         let original_token = original_file.syntax().token_at_offset(offset).left_biased()?;
         let token = sema.descend_into_macros_single(original_token.clone());
-        let scope = sema.scope_at_offset(&token, offset);
+        let scope = sema.scope_at_offset(&token.parent()?, offset);
         let krate = scope.krate();
         let mut locals = vec![];
         scope.process_all_names(&mut |name, scope| {
@@ -723,7 +725,7 @@ impl<'a> CompletionContext<'a> {
                 }
             }
             ast::NameLike::Name(name) => {
-                self.pattern_ctx = Self::classify_name(&self.sema, name);
+                self.pattern_ctx = Self::classify_name(&self.sema, original_file, name);
             }
         }
     }
@@ -750,7 +752,11 @@ impl<'a> CompletionContext<'a> {
         })
     }
 
-    fn classify_name(_sema: &Semantics<RootDatabase>, name: ast::Name) -> Option<PatternContext> {
+    fn classify_name(
+        _sema: &Semantics<RootDatabase>,
+        original_file: &SyntaxNode,
+        name: ast::Name,
+    ) -> Option<PatternContext> {
         let bind_pat = name.syntax().parent().and_then(ast::IdentPat::cast)?;
         let is_name_in_field_pat = bind_pat
             .syntax()
@@ -763,7 +769,7 @@ impl<'a> CompletionContext<'a> {
         if !bind_pat.is_simple_ident() {
             return None;
         }
-        Some(pattern_context_for(bind_pat.into()))
+        Some(pattern_context_for(original_file, bind_pat.into()))
     }
 
     fn classify_name_ref(
@@ -799,15 +805,15 @@ impl<'a> CompletionContext<'a> {
                     },
                     ast::TupleStructPat(it) => {
                         path_ctx.has_call_parens = true;
-                        pat_ctx = Some(pattern_context_for(it.into()));
+                        pat_ctx = Some(pattern_context_for(original_file, it.into()));
                         Some(PathKind::Pat)
                     },
                     ast::RecordPat(it) => {
-                        pat_ctx = Some(pattern_context_for(it.into()));
+                        pat_ctx = Some(pattern_context_for(original_file, it.into()));
                         Some(PathKind::Pat)
                     },
                     ast::PathPat(it) => {
-                        pat_ctx = Some(pattern_context_for(it.into()));
+                        pat_ctx = Some(pattern_context_for(original_file, it.into()));
                         Some(PathKind::Pat)
                     },
                     ast::MacroCall(it) => it.excl_token().and(Some(PathKind::Mac)),
@@ -824,12 +830,7 @@ impl<'a> CompletionContext<'a> {
             path_ctx.use_tree_parent = use_tree_parent;
             path_ctx.qualifier = path
                 .segment()
-                .and_then(|it| {
-                    find_node_with_range::<ast::PathSegment>(
-                        original_file,
-                        it.syntax().text_range(),
-                    )
-                })
+                .and_then(|it| find_node_in_file(original_file, &it))
                 .map(|it| it.parent_path());
             return Some((path_ctx, pat_ctx));
         }
@@ -864,7 +865,7 @@ impl<'a> CompletionContext<'a> {
     }
 }
 
-fn pattern_context_for(pat: ast::Pat) -> PatternContext {
+fn pattern_context_for(original_file: &SyntaxNode, pat: ast::Pat) -> PatternContext {
     let mut is_param = None;
     let (refutability, has_type_ascription) =
     pat
@@ -877,18 +878,21 @@ fn pattern_context_for(pat: ast::Pat) -> PatternContext {
                 match node {
                     ast::LetStmt(let_) => return (PatternRefutability::Irrefutable, let_.ty().is_some()),
                     ast::Param(param) => {
-                        let is_closure_param = param
-                            .syntax()
-                            .ancestors()
-                            .nth(2)
-                            .and_then(ast::ClosureExpr::cast)
-                            .is_some();
-                        is_param = Some(if is_closure_param {
-                            ParamKind::Closure
-                        } else {
-                            ParamKind::Function
-                        });
-                        return (PatternRefutability::Irrefutable, param.ty().is_some())
+                        let has_type_ascription = param.ty().is_some();
+                        is_param = (|| {
+                            let fake_param_list = param.syntax().parent().and_then(ast::ParamList::cast)?;
+                            let param_list = find_node_in_file_compensated(original_file, &fake_param_list)?;
+                            let param_list_owner = param_list.syntax().parent()?;
+                            let kind = match_ast! {
+                                match param_list_owner {
+                                    ast::ClosureExpr(closure) => ParamKind::Closure(closure),
+                                    ast::Fn(fn_) => ParamKind::Function(fn_),
+                                    _ => return None,
+                                }
+                            };
+                            Some((param_list, param, kind))
+                        })();
+                        return (PatternRefutability::Irrefutable, has_type_ascription)
                     },
                     ast::MatchArm(_) => PatternRefutability::Refutable,
                     ast::Condition(_) => PatternRefutability::Refutable,
@@ -898,11 +902,29 @@ fn pattern_context_for(pat: ast::Pat) -> PatternContext {
             };
             (refutability, false)
         });
-    PatternContext { refutability, is_param, has_type_ascription }
+    PatternContext { refutability, param_ctx: is_param, has_type_ascription }
+}
+
+fn find_node_in_file<N: AstNode>(syntax: &SyntaxNode, node: &N) -> Option<N> {
+    let syntax_range = syntax.text_range();
+    let range = node.syntax().text_range();
+    let intersection = range.intersect(syntax_range)?;
+    syntax.covering_element(intersection).ancestors().find_map(N::cast)
 }
 
-fn find_node_with_range<N: AstNode>(syntax: &SyntaxNode, range: TextRange) -> Option<N> {
-    syntax.covering_element(range).ancestors().find_map(N::cast)
+/// Compensates for the offset introduced by the fake ident
+/// This is wrong if `node` comes before the insertion point! Use `find_node_in_file` instead.
+fn find_node_in_file_compensated<N: AstNode>(syntax: &SyntaxNode, node: &N) -> Option<N> {
+    let syntax_range = syntax.text_range();
+    let range = node.syntax().text_range();
+    let end = range.end().checked_sub(TextSize::try_from(COMPLETION_MARKER.len()).ok()?)?;
+    if end < range.start() {
+        return None;
+    }
+    let range = TextRange::new(range.start(), end);
+    // our inserted ident could cause `range` to be go outside of the original syntax, so cap it
+    let intersection = range.intersect(syntax_range)?;
+    syntax.covering_element(intersection).ancestors().find_map(N::cast)
 }
 
 fn path_or_use_tree_qualifier(path: &ast::Path) -> Option<(ast::Path, bool)> {
diff --git a/crates/ide_completion/src/render/pattern.rs b/crates/ide_completion/src/render/pattern.rs
index 888a5b4b0a7..e486d9f2b91 100644
--- a/crates/ide_completion/src/render/pattern.rs
+++ b/crates/ide_completion/src/render/pattern.rs
@@ -87,7 +87,7 @@ fn render_pat(
     if matches!(
         ctx.completion.pattern_ctx,
         Some(PatternContext {
-            is_param: Some(ParamKind::Function),
+            param_ctx: Some((.., ParamKind::Function(_))),
             has_type_ascription: false,
             ..
         })
diff --git a/crates/ide_completion/src/tests/fn_param.rs b/crates/ide_completion/src/tests/fn_param.rs
index 940cecf395d..662fbe309bc 100644
--- a/crates/ide_completion/src/tests/fn_param.rs
+++ b/crates/ide_completion/src/tests/fn_param.rs
@@ -156,3 +156,53 @@ impl A {
         "#]],
     )
 }
+
+// doesn't complete qux due to there being no expression after
+// see source_analyzer::adjust comment
+#[test]
+fn local_fn_shows_locals_for_params() {
+    check(
+        r#"
+fn outer() {
+    let foo = 3;
+    {
+        let bar = 3;
+        fn inner($0) {}
+        let baz = 3;
+        let qux = 3;
+    }
+    let fez = 3;
+}
+"#,
+        expect![[r#"
+            bn foo: i32
+            bn baz: i32
+            bn bar: i32
+            kw mut
+        "#]],
+    )
+}
+
+#[test]
+fn closure_shows_locals_for_params() {
+    check(
+        r#"
+fn outer() {
+    let foo = 3;
+    {
+        let bar = 3;
+        |$0| {};
+        let baz = 3;
+        let qux = 3;
+    }
+    let fez = 3;
+}
+"#,
+        expect![[r#"
+            bn baz: i32
+            bn bar: i32
+            bn foo: i32
+            kw mut
+        "#]],
+    )
+}