about summary refs log tree commit diff
diff options
context:
space:
mode:
authorLukas Wirth <lukastw97@gmail.com>2024-05-18 12:35:55 +0200
committerLukas Wirth <lukastw97@gmail.com>2024-05-18 12:35:55 +0200
commit7045044da359af8be28f44a8acfaa69f6b2682a9 (patch)
tree2eb6be2b8d9fe301406f0d05904ebc95efd27d8a
parent9ff4ffb81763a1532d792c81f1e4f61f8d04440b (diff)
downloadrust-7045044da359af8be28f44a8acfaa69f6b2682a9.tar.gz
rust-7045044da359af8be28f44a8acfaa69f6b2682a9.zip
Allow hir::Param to refer to other entity params aside from functions
-rw-r--r--src/tools/rust-analyzer/crates/hir/src/lib.rs127
-rw-r--r--src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_method_eager_lazy.rs4
-rw-r--r--src/tools/rust-analyzer/crates/ide-assists/src/utils/suggest_name.rs7
-rw-r--r--src/tools/rust-analyzer/crates/ide-db/src/active_parameter.rs14
-rw-r--r--src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/format_string.rs1
-rw-r--r--src/tools/rust-analyzer/crates/ide/src/inlay_hints/param_name.rs8
-rw-r--r--src/tools/rust-analyzer/crates/ide/src/signature_help.rs17
-rw-r--r--src/tools/rust-analyzer/crates/ide/src/syntax_highlighting/escape.rs35
-rw-r--r--src/tools/rust-analyzer/crates/syntax/src/ast/token_ext.rs19
9 files changed, 114 insertions, 118 deletions
diff --git a/src/tools/rust-analyzer/crates/hir/src/lib.rs b/src/tools/rust-analyzer/crates/hir/src/lib.rs
index a902ae2d336..49e6241f7aa 100644
--- a/src/tools/rust-analyzer/crates/hir/src/lib.rs
+++ b/src/tools/rust-analyzer/crates/hir/src/lib.rs
@@ -35,7 +35,7 @@ pub mod term_search;
 
 mod display;
 
-use std::{iter, mem::discriminant, ops::ControlFlow};
+use std::{mem::discriminant, ops::ControlFlow};
 
 use arrayvec::ArrayVec;
 use base_db::{CrateDisplayName, CrateId, CrateOrigin, FileId};
@@ -52,7 +52,6 @@ use hir_def::{
     path::ImportAlias,
     per_ns::PerNs,
     resolver::{HasResolver, Resolver},
-    src::HasSource as _,
     AssocItemId, AssocItemLoc, AttrDefId, ConstId, ConstParamId, CrateRootModuleId, DefWithBodyId,
     EnumId, EnumVariantId, ExternCrateId, FunctionId, GenericDefId, GenericParamId, HasModule,
     ImplId, InTypeConstId, ItemContainerId, LifetimeParamId, LocalFieldId, Lookup, MacroExpander,
@@ -1965,7 +1964,7 @@ impl Function {
             .enumerate()
             .map(|(idx, ty)| {
                 let ty = Type { env: environment.clone(), ty: ty.clone() };
-                Param { func: self, ty, idx }
+                Param { func: Callee::Def(CallableDefId::FunctionId(self.id)), ty, idx }
             })
             .collect()
     }
@@ -1991,7 +1990,7 @@ impl Function {
             .skip(skip)
             .map(|(idx, ty)| {
                 let ty = Type { env: environment.clone(), ty: ty.clone() };
-                Param { func: self, ty, idx }
+                Param { func: Callee::Def(CallableDefId::FunctionId(self.id)), ty, idx }
             })
             .collect()
     }
@@ -2037,7 +2036,7 @@ impl Function {
             .skip(skip)
             .map(|(idx, ty)| {
                 let ty = Type { env: environment.clone(), ty: ty.clone() };
-                Param { func: self, ty, idx }
+                Param { func: Callee::Def(CallableDefId::FunctionId(self.id)), ty, idx }
             })
             .collect()
     }
@@ -2167,17 +2166,24 @@ impl From<hir_ty::Mutability> for Access {
 
 #[derive(Clone, PartialEq, Eq, Hash, Debug)]
 pub struct Param {
-    func: Function,
+    func: Callee,
     /// The index in parameter list, including self parameter.
     idx: usize,
     ty: Type,
 }
 
 impl Param {
-    pub fn parent_fn(&self) -> Function {
-        self.func
+    pub fn parent_fn(&self) -> Option<Function> {
+        match self.func {
+            Callee::Def(CallableDefId::FunctionId(f)) => Some(f.into()),
+            _ => None,
+        }
     }
 
+    // pub fn parent_closure(&self) -> Option<Closure> {
+    //     self.func.as_ref().right().cloned()
+    // }
+
     pub fn index(&self) -> usize {
         self.idx
     }
@@ -2191,7 +2197,11 @@ impl Param {
     }
 
     pub fn as_local(&self, db: &dyn HirDatabase) -> Option<Local> {
-        let parent = DefWithBodyId::FunctionId(self.func.into());
+        let parent = match self.func {
+            Callee::Def(CallableDefId::FunctionId(it)) => DefWithBodyId::FunctionId(it),
+            Callee::Closure(closure) => db.lookup_intern_closure(closure.into()).0,
+            _ => return None,
+        };
         let body = db.body(parent);
         if let Some(self_param) = body.self_param.filter(|_| self.idx == 0) {
             Some(Local { parent, binding_id: self_param })
@@ -2205,18 +2215,45 @@ impl Param {
     }
 
     pub fn pattern_source(&self, db: &dyn HirDatabase) -> Option<ast::Pat> {
-        self.source(db).and_then(|p| p.value.pat())
+        self.source(db).and_then(|p| p.value.right()?.pat())
     }
 
-    pub fn source(&self, db: &dyn HirDatabase) -> Option<InFile<ast::Param>> {
-        let InFile { file_id, value } = self.func.source(db)?;
-        let params = value.param_list()?;
-        if params.self_param().is_some() {
-            params.params().nth(self.idx.checked_sub(params.self_param().is_some() as usize)?)
-        } else {
-            params.params().nth(self.idx)
+    pub fn source(
+        &self,
+        db: &dyn HirDatabase,
+    ) -> Option<InFile<Either<ast::SelfParam, ast::Param>>> {
+        match self.func {
+            Callee::Def(CallableDefId::FunctionId(func)) => {
+                let InFile { file_id, value } = Function { id: func }.source(db)?;
+                let params = value.param_list()?;
+                if let Some(self_param) = params.self_param() {
+                    if let Some(idx) = self.idx.checked_sub(1 as usize) {
+                        params.params().nth(idx).map(Either::Right)
+                    } else {
+                        Some(Either::Left(self_param))
+                    }
+                } else {
+                    params.params().nth(self.idx).map(Either::Right)
+                }
+                .map(|value| InFile { file_id, value })
+            }
+            Callee::Closure(closure) => {
+                let InternedClosure(owner, expr_id) = db.lookup_intern_closure(closure.into());
+                let (_, source_map) = db.body_with_source_map(owner);
+                let ast @ InFile { file_id, value } = source_map.expr_syntax(expr_id).ok()?;
+                let root = db.parse_or_expand(file_id);
+                match value.to_node(&root) {
+                    ast::Expr::ClosureExpr(it) => it
+                        .param_list()?
+                        .params()
+                        .nth(self.idx)
+                        .map(Either::Right)
+                        .map(|value| InFile { file_id: ast.file_id, value }),
+                    _ => None,
+                }
+            }
+            _ => None,
         }
-        .map(|value| InFile { file_id, value })
     }
 }
 
@@ -4919,7 +4956,7 @@ pub struct Callable {
     pub(crate) is_bound_method: bool,
 }
 
-#[derive(Debug)]
+#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
 enum Callee {
     Def(CallableDefId),
     Closure(ClosureId),
@@ -4960,43 +4997,15 @@ impl Callable {
     pub fn n_params(&self) -> usize {
         self.sig.params().len() - if self.is_bound_method { 1 } else { 0 }
     }
-    pub fn params(
-        &self,
-        db: &dyn HirDatabase,
-    ) -> Vec<(Option<Either<ast::SelfParam, ast::Pat>>, Type)> {
-        let types = self
-            .sig
+    pub fn params(&self) -> Vec<Param> {
+        self.sig
             .params()
             .iter()
+            .enumerate()
             .skip(if self.is_bound_method { 1 } else { 0 })
-            .map(|ty| self.ty.derived(ty.clone()));
-        let map_param = |it: ast::Param| it.pat().map(Either::Right);
-        let patterns = match self.callee {
-            Callee::Def(CallableDefId::FunctionId(func)) => {
-                let src = func.lookup(db.upcast()).source(db.upcast());
-                src.value.param_list().map(|param_list| {
-                    param_list
-                        .self_param()
-                        .map(|it| Some(Either::Left(it)))
-                        .filter(|_| !self.is_bound_method)
-                        .into_iter()
-                        .chain(param_list.params().map(map_param))
-                })
-            }
-            Callee::Closure(closure_id) => match closure_source(db, closure_id) {
-                Some(src) => src.param_list().map(|param_list| {
-                    param_list
-                        .self_param()
-                        .map(|it| Some(Either::Left(it)))
-                        .filter(|_| !self.is_bound_method)
-                        .into_iter()
-                        .chain(param_list.params().map(map_param))
-                }),
-                None => None,
-            },
-            _ => None,
-        };
-        patterns.into_iter().flatten().chain(iter::repeat(None)).zip(types).collect()
+            .map(|(idx, ty)| (idx, self.ty.derived(ty.clone())))
+            .map(|(idx, ty)| Param { func: self.callee, idx, ty })
+            .collect()
     }
     pub fn return_type(&self) -> Type {
         self.ty.derived(self.sig.ret().clone())
@@ -5006,18 +5015,6 @@ impl Callable {
     }
 }
 
-fn closure_source(db: &dyn HirDatabase, closure: ClosureId) -> Option<ast::ClosureExpr> {
-    let InternedClosure(owner, expr_id) = db.lookup_intern_closure(closure.into());
-    let (_, source_map) = db.body_with_source_map(owner);
-    let ast = source_map.expr_syntax(expr_id).ok()?;
-    let root = ast.file_syntax(db.upcast());
-    let expr = ast.value.to_node(&root);
-    match expr {
-        ast::Expr::ClosureExpr(it) => Some(it),
-        _ => None,
-    }
-}
-
 #[derive(Clone, Debug, Eq, PartialEq)]
 pub struct Layout(Arc<TyLayout>, Arc<TargetDataLayout>);
 
diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_method_eager_lazy.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_method_eager_lazy.rs
index 7f3b0d75883..37ea5123a71 100644
--- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_method_eager_lazy.rs
+++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_method_eager_lazy.rs
@@ -114,10 +114,10 @@ pub(crate) fn replace_with_eager_method(acc: &mut Assists, ctx: &AssistContext<'
     let callable = ctx.sema.resolve_method_call_as_callable(&call)?;
     let (_, receiver_ty) = callable.receiver_param(ctx.sema.db)?;
     let n_params = callable.n_params() + 1;
-    let params = callable.params(ctx.sema.db);
+    let params = callable.params();
 
     // FIXME: Check that the arg is of the form `() -> T`
-    if !params.first()?.1.impls_fnonce(ctx.sema.db) {
+    if !params.first()?.ty().impls_fnonce(ctx.sema.db) {
         return None;
     }
 
diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/utils/suggest_name.rs b/src/tools/rust-analyzer/crates/ide-assists/src/utils/suggest_name.rs
index 1859825b3d6..23a06404f30 100644
--- a/src/tools/rust-analyzer/crates/ide-assists/src/utils/suggest_name.rs
+++ b/src/tools/rust-analyzer/crates/ide-assists/src/utils/suggest_name.rs
@@ -253,11 +253,8 @@ fn from_param(expr: &ast::Expr, sema: &Semantics<'_, RootDatabase>) -> Option<St
     };
 
     let (idx, _) = arg_list.args().find_position(|it| it == expr).unwrap();
-    let (pat, _) = func.params(sema.db).into_iter().nth(idx)?;
-    let pat = match pat? {
-        either::Either::Right(pat) => pat,
-        _ => return None,
-    };
+    let param = func.params().into_iter().nth(idx)?;
+    let pat = param.source(sema.db)?.value.right()?.pat()?;
     let name = var_name_from_pat(&pat)?;
     normalize(&name.to_string())
 }
diff --git a/src/tools/rust-analyzer/crates/ide-db/src/active_parameter.rs b/src/tools/rust-analyzer/crates/ide-db/src/active_parameter.rs
index 5780b5a5bb9..abc60a77a56 100644
--- a/src/tools/rust-analyzer/crates/ide-db/src/active_parameter.rs
+++ b/src/tools/rust-analyzer/crates/ide-db/src/active_parameter.rs
@@ -1,7 +1,7 @@
 //! This module provides functionality for querying callable information about a token.
 
 use either::Either;
-use hir::{Semantics, Type};
+use hir::{InFile, Semantics, Type};
 use parser::T;
 use syntax::{
     ast::{self, HasArgList, HasName},
@@ -13,7 +13,7 @@ use crate::RootDatabase;
 #[derive(Debug)]
 pub struct ActiveParameter {
     pub ty: Type,
-    pub pat: Option<Either<ast::SelfParam, ast::Pat>>,
+    pub src: Option<InFile<Either<ast::SelfParam, ast::Param>>>,
 }
 
 impl ActiveParameter {
@@ -22,18 +22,18 @@ impl ActiveParameter {
         let (signature, active_parameter) = callable_for_token(sema, token)?;
 
         let idx = active_parameter?;
-        let mut params = signature.params(sema.db);
+        let mut params = signature.params();
         if idx >= params.len() {
             cov_mark::hit!(too_many_arguments);
             return None;
         }
-        let (pat, ty) = params.swap_remove(idx);
-        Some(ActiveParameter { ty, pat })
+        let param = params.swap_remove(idx);
+        Some(ActiveParameter { ty: param.ty().clone(), src: param.source(sema.db) })
     }
 
     pub fn ident(&self) -> Option<ast::Name> {
-        self.pat.as_ref().and_then(|param| match param {
-            Either::Right(ast::Pat::IdentPat(ident)) => ident.name(),
+        self.src.as_ref().and_then(|param| match param.value.as_ref().right()?.pat()? {
+            ast::Pat::IdentPat(ident) => ident.name(),
             _ => None,
         })
     }
diff --git a/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/format_string.rs b/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/format_string.rs
index 8302b015dda..92478ef480d 100644
--- a/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/format_string.rs
+++ b/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/format_string.rs
@@ -41,6 +41,7 @@ pub enum FormatSpecifier {
     Escape,
 }
 
+// FIXME: Remove this, we can use rustc_format_parse instead
 pub fn lex_format_specifiers(
     string: &ast::String,
     mut callback: &mut dyn FnMut(TextRange, FormatSpecifier),
diff --git a/src/tools/rust-analyzer/crates/ide/src/inlay_hints/param_name.rs b/src/tools/rust-analyzer/crates/ide/src/inlay_hints/param_name.rs
index 96e845b2f32..20e8aca8491 100644
--- a/src/tools/rust-analyzer/crates/ide/src/inlay_hints/param_name.rs
+++ b/src/tools/rust-analyzer/crates/ide/src/inlay_hints/param_name.rs
@@ -24,15 +24,15 @@ pub(super) fn hints(
 
     let (callable, arg_list) = get_callable(sema, &expr)?;
     let hints = callable
-        .params(sema.db)
+        .params()
         .into_iter()
         .zip(arg_list.args())
-        .filter_map(|((param, _ty), arg)| {
+        .filter_map(|(p, arg)| {
             // Only annotate hints for expressions that exist in the original file
             let range = sema.original_range_opt(arg.syntax())?;
-            let (param_name, name_syntax) = match param.as_ref()? {
+            let (param_name, name_syntax) = match p.source(sema.db)?.value.as_ref() {
                 Either::Left(pat) => (pat.name()?, pat.name()),
-                Either::Right(pat) => match pat {
+                Either::Right(param) => match param.pat()? {
                     ast::Pat::IdentPat(it) => (it.name()?, it.name()),
                     _ => return None,
                 },
diff --git a/src/tools/rust-analyzer/crates/ide/src/signature_help.rs b/src/tools/rust-analyzer/crates/ide/src/signature_help.rs
index b2eb5a5fff1..05e605f6e4a 100644
--- a/src/tools/rust-analyzer/crates/ide/src/signature_help.rs
+++ b/src/tools/rust-analyzer/crates/ide/src/signature_help.rs
@@ -210,12 +210,15 @@ fn signature_help_for_call(
             format_to!(res.signature, "{}", self_param.display(db))
         }
         let mut buf = String::new();
-        for (idx, (pat, ty)) in callable.params(db).into_iter().enumerate() {
+        for (idx, p) in callable.params().into_iter().enumerate() {
             buf.clear();
-            if let Some(pat) = pat {
-                match pat {
-                    Either::Left(_self) => format_to!(buf, "self: "),
-                    Either::Right(pat) => format_to!(buf, "{}: ", pat),
+            if let Some(param) = p.source(sema.db) {
+                match param.value {
+                    Either::Right(param) => match param.pat() {
+                        Some(pat) => format_to!(buf, "{}: ", pat),
+                        None => format_to!(buf, "?: "),
+                    },
+                    Either::Left(_) => format_to!(buf, "self: "),
                 }
             }
             // APITs (argument position `impl Trait`s) are inferred as {unknown} as the user is
@@ -223,9 +226,9 @@ fn signature_help_for_call(
             // In that case, fall back to render definitions of the respective parameters.
             // This is overly conservative: we do not substitute known type vars
             // (see FIXME in tests::impl_trait) and falling back on any unknowns.
-            match (ty.contains_unknown(), fn_params.as_deref()) {
+            match (p.ty().contains_unknown(), fn_params.as_deref()) {
                 (true, Some(fn_params)) => format_to!(buf, "{}", fn_params[idx].ty().display(db)),
-                _ => format_to!(buf, "{}", ty.display(db)),
+                _ => format_to!(buf, "{}", p.ty().display(db)),
             }
             res.push_call_param(&buf);
         }
diff --git a/src/tools/rust-analyzer/crates/ide/src/syntax_highlighting/escape.rs b/src/tools/rust-analyzer/crates/ide/src/syntax_highlighting/escape.rs
index 0439e509d21..2f387968c96 100644
--- a/src/tools/rust-analyzer/crates/ide/src/syntax_highlighting/escape.rs
+++ b/src/tools/rust-analyzer/crates/ide/src/syntax_highlighting/escape.rs
@@ -9,8 +9,9 @@ pub(super) fn highlight_escape_string<T: IsString>(
     string: &T,
     start: TextSize,
 ) {
+    let text = string.text();
     string.escaped_char_ranges(&mut |piece_range, char| {
-        if string.text()[piece_range.start().into()..].starts_with('\\') {
+        if text[piece_range.start().into()..].starts_with('\\') {
             let highlight = match char {
                 Ok(_) => HlTag::EscapeSequence,
                 Err(_) => HlTag::InvalidEscapeSequence,
@@ -33,17 +34,15 @@ pub(super) fn highlight_escape_char(stack: &mut Highlights, char: &Char, start:
     }
 
     let text = char.text();
-    if !text.starts_with('\'') || !text.ends_with('\'') {
+    let Some(text) = text
+        .strip_prefix('\'')
+        .and_then(|it| it.strip_suffix('\''))
+        .filter(|it| it.starts_with('\\'))
+    else {
         return;
-    }
-
-    let text = &text[1..text.len() - 1];
-    if !text.starts_with('\\') {
-        return;
-    }
+    };
 
-    let range =
-        TextRange::new(start + TextSize::from(1), start + TextSize::from(text.len() as u32 + 1));
+    let range = TextRange::at(start + TextSize::from(1), TextSize::from(text.len() as u32));
     stack.add(HlRange { range, highlight: HlTag::EscapeSequence.into(), binding_hash: None })
 }
 
@@ -54,16 +53,14 @@ pub(super) fn highlight_escape_byte(stack: &mut Highlights, byte: &Byte, start:
     }
 
     let text = byte.text();
-    if !text.starts_with("b'") || !text.ends_with('\'') {
+    let Some(text) = text
+        .strip_prefix("b'")
+        .and_then(|it| it.strip_suffix('\''))
+        .filter(|it| it.starts_with('\\'))
+    else {
         return;
-    }
-
-    let text = &text[2..text.len() - 1];
-    if !text.starts_with('\\') {
-        return;
-    }
+    };
 
-    let range =
-        TextRange::new(start + TextSize::from(2), start + TextSize::from(text.len() as u32 + 2));
+    let range = TextRange::at(start + TextSize::from(2), TextSize::from(text.len() as u32));
     stack.add(HlRange { range, highlight: HlTag::EscapeSequence.into(), binding_hash: None })
 }
diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/token_ext.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/token_ext.rs
index 16599881d64..1ce548f8fc7 100644
--- a/src/tools/rust-analyzer/crates/syntax/src/ast/token_ext.rs
+++ b/src/tools/rust-analyzer/crates/syntax/src/ast/token_ext.rs
@@ -8,6 +8,7 @@ use std::{
 use rustc_lexer::unescape::{
     unescape_byte, unescape_char, unescape_mixed, unescape_unicode, EscapeError, MixedUnit, Mode,
 };
+use stdx::always;
 
 use crate::{
     ast::{self, AstToken},
@@ -181,25 +182,25 @@ pub trait IsString: AstToken {
         self.quote_offsets().map(|it| it.quotes.1)
     }
     fn escaped_char_ranges(&self, cb: &mut dyn FnMut(TextRange, Result<char, EscapeError>)) {
-        let text_range_no_quotes = match self.text_range_between_quotes() {
-            Some(it) => it,
-            None => return,
-        };
+        let Some(text_range_no_quotes) = self.text_range_between_quotes() else { return };
 
         let start = self.syntax().text_range().start();
         let text = &self.text()[text_range_no_quotes - start];
         let offset = text_range_no_quotes.start() - start;
 
         unescape_unicode(text, Self::MODE, &mut |range, unescaped_char| {
-            let text_range =
-                TextRange::new(range.start.try_into().unwrap(), range.end.try_into().unwrap());
-            cb(text_range + offset, unescaped_char);
+            if let Some((s, e)) = range.start.try_into().ok().zip(range.end.try_into().ok()) {
+                cb(TextRange::new(s, e) + offset, unescaped_char);
+            }
         });
     }
     fn map_range_up(&self, range: TextRange) -> Option<TextRange> {
         let contents_range = self.text_range_between_quotes()?;
-        assert!(TextRange::up_to(contents_range.len()).contains_range(range));
-        Some(range + contents_range.start())
+        if always!(TextRange::up_to(contents_range.len()).contains_range(range)) {
+            Some(range + contents_range.start())
+        } else {
+            None
+        }
     }
 }