about summary refs log tree commit diff
diff options
context:
space:
mode:
authorJonas Schievink <jonas.schievink@ferrous-systems.com>2022-05-19 18:53:08 +0200
committerJonas Schievink <jonas.schievink@ferrous-systems.com>2022-05-19 18:53:08 +0200
commit52ff863abcc8a6cb06689b25590616982504d916 (patch)
treea8fc55e2835bb8f897af746ae1cbcee92973775e
parenteba26af9f1af8566755f97259124e0c8d78b6c85 (diff)
downloadrust-52ff863abcc8a6cb06689b25590616982504d916.tar.gz
rust-52ff863abcc8a6cb06689b25590616982504d916.zip
Teach `Callable` about closures properly
-rw-r--r--crates/hir/src/lib.rs71
-rw-r--r--crates/ide/src/inlay_hints.rs20
-rw-r--r--crates/ide/src/signature_help.rs8
3 files changed, 77 insertions, 22 deletions
diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs
index 12e06bf4aca..3f62a2cd334 100644
--- a/crates/hir/src/lib.rs
+++ b/crates/hir/src/lib.rs
@@ -62,9 +62,9 @@ use hir_ty::{
     subst_prefix,
     traits::FnTrait,
     AliasEq, AliasTy, BoundVar, CallableDefId, CallableSig, Canonical, CanonicalVarKinds, Cast,
-    DebruijnIndex, GenericArgData, InEnvironment, Interner, ParamKind, QuantifiedWhereClause,
-    Scalar, Solution, Substitution, TraitEnvironment, TraitRefExt, Ty, TyBuilder, TyDefId, TyExt,
-    TyKind, TyVariableKind, WhereClause,
+    ClosureId, DebruijnIndex, GenericArgData, InEnvironment, Interner, ParamKind,
+    QuantifiedWhereClause, Scalar, Solution, Substitution, TraitEnvironment, TraitRefExt, Ty,
+    TyBuilder, TyDefId, TyExt, TyKind, TyVariableKind, WhereClause,
 };
 use itertools::Itertools;
 use nameres::diagnostics::DefDiagnosticKind;
@@ -2819,10 +2819,14 @@ impl Type {
     }
 
     pub fn as_callable(&self, db: &dyn HirDatabase) -> Option<Callable> {
-        let def = self.ty.callable_def(db);
+        let callee = match self.ty.kind(Interner) {
+            TyKind::Closure(id, _) => Callee::Closure(*id),
+            TyKind::Function(_) => Callee::FnPtr,
+            _ => Callee::Def(self.ty.callable_def(db)?),
+        };
 
         let sig = self.ty.callable_sig(db)?;
-        Some(Callable { ty: self.clone(), sig, def, is_bound_method: false })
+        Some(Callable { ty: self.clone(), sig, callee, is_bound_method: false })
     }
 
     pub fn is_closure(&self) -> bool {
@@ -3265,34 +3269,43 @@ impl Type {
     }
 }
 
-// FIXME: closures
 #[derive(Debug)]
 pub struct Callable {
     ty: Type,
     sig: CallableSig,
-    def: Option<CallableDefId>,
+    callee: Callee,
     pub(crate) is_bound_method: bool,
 }
 
+#[derive(Debug)]
+enum Callee {
+    Def(CallableDefId),
+    Closure(ClosureId),
+    FnPtr,
+}
+
 pub enum CallableKind {
     Function(Function),
     TupleStruct(Struct),
     TupleEnumVariant(Variant),
     Closure,
+    FnPtr,
 }
 
 impl Callable {
     pub fn kind(&self) -> CallableKind {
-        match self.def {
-            Some(CallableDefId::FunctionId(it)) => CallableKind::Function(it.into()),
-            Some(CallableDefId::StructId(it)) => CallableKind::TupleStruct(it.into()),
-            Some(CallableDefId::EnumVariantId(it)) => CallableKind::TupleEnumVariant(it.into()),
-            None => CallableKind::Closure,
+        use Callee::*;
+        match self.callee {
+            Def(CallableDefId::FunctionId(it)) => CallableKind::Function(it.into()),
+            Def(CallableDefId::StructId(it)) => CallableKind::TupleStruct(it.into()),
+            Def(CallableDefId::EnumVariantId(it)) => CallableKind::TupleEnumVariant(it.into()),
+            Closure(_) => CallableKind::Closure,
+            FnPtr => CallableKind::FnPtr,
         }
     }
     pub fn receiver_param(&self, db: &dyn HirDatabase) -> Option<ast::SelfParam> {
-        let func = match self.def {
-            Some(CallableDefId::FunctionId(it)) if self.is_bound_method => it,
+        let func = match self.callee {
+            Callee::Def(CallableDefId::FunctionId(it)) if self.is_bound_method => it,
             _ => return None,
         };
         let src = func.lookup(db.upcast()).source(db.upcast());
@@ -3312,8 +3325,9 @@ impl Callable {
             .iter()
             .skip(if self.is_bound_method { 1 } else { 0 })
             .map(|ty| self.ty.derived(ty.clone()));
-        let patterns = match self.def {
-            Some(CallableDefId::FunctionId(func)) => {
+        let map_param = |it: ast::Param| it.pat().map(Either::Right);
+        let patterns = match self.callee {
+            Callee::Def(CallableDefId::FunctionId(func)) => {
                 let src = func.lookup(db.upcast()).source(db.upcast());
                 src.value.param_list().map(|param_list| {
                     param_list
@@ -3321,9 +3335,20 @@ impl Callable {
                         .map(|it| Some(Either::Left(it)))
                         .filter(|_| !self.is_bound_method)
                         .into_iter()
-                        .chain(param_list.params().map(|it| it.pat().map(Either::Right)))
+                        .chain(param_list.params().map(map_param))
                 })
             }
+            Callee::Closure(closure_id) => match closure_source(db, closure_id) {
+                Some(src) => src.param_list().map(|param_list| {
+                    param_list
+                        .self_param()
+                        .map(|it| Some(Either::Left(it)))
+                        .filter(|_| !self.is_bound_method)
+                        .into_iter()
+                        .chain(param_list.params().map(map_param))
+                }),
+                None => None,
+            },
             _ => None,
         };
         patterns.into_iter().flatten().chain(iter::repeat(None)).zip(types).collect()
@@ -3333,6 +3358,18 @@ impl Callable {
     }
 }
 
+fn closure_source(db: &dyn HirDatabase, closure: ClosureId) -> Option<ast::ClosureExpr> {
+    let (owner, expr_id) = db.lookup_intern_closure(closure.into());
+    let (_, source_map) = db.body_with_source_map(owner);
+    let ast = source_map.expr_syntax(expr_id).ok()?;
+    let root = ast.file_syntax(db.upcast());
+    let expr = ast.value.to_node(&root);
+    match expr {
+        ast::Expr::ClosureExpr(it) => Some(it),
+        _ => None,
+    }
+}
+
 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum BindingMode {
     Move,
diff --git a/crates/ide/src/inlay_hints.rs b/crates/ide/src/inlay_hints.rs
index 3cb60d9e446..47f1a08b6fb 100644
--- a/crates/ide/src/inlay_hints.rs
+++ b/crates/ide/src/inlay_hints.rs
@@ -1170,6 +1170,23 @@ fn main() {
     }
 
     #[test]
+    fn param_hints_on_closure() {
+        check_params(
+            r#"
+fn main() {
+    let clo = |a: u8, b: u8| a + b;
+    clo(
+        1,
+      //^ a
+        2,
+      //^ b
+    );
+}
+            "#,
+        );
+    }
+
+    #[test]
     fn param_name_similar_to_fn_name_still_hints() {
         check_params(
             r#"
@@ -2000,7 +2017,8 @@ fn main() {
 
     ;
 
-    let _: i32 = multiply(1, 2);
+    let _: i32 = multiply(1,  2);
+                        //^ a ^ b
     let multiply_ref = &multiply;
       //^^^^^^^^^^^^ &|i32, i32| -> i32
 
diff --git a/crates/ide/src/signature_help.rs b/crates/ide/src/signature_help.rs
index 32e7c59b2a5..cb38f48f32a 100644
--- a/crates/ide/src/signature_help.rs
+++ b/crates/ide/src/signature_help.rs
@@ -149,7 +149,7 @@ fn signature_help_for_call(
                 variant.name(db)
             );
         }
-        hir::CallableKind::Closure => (),
+        hir::CallableKind::Closure | hir::CallableKind::FnPtr => (),
     }
 
     res.signature.push('(');
@@ -189,7 +189,7 @@ fn signature_help_for_call(
         hir::CallableKind::Function(func) if callable.return_type().contains_unknown() => {
             render(func.ret_type(db))
         }
-        hir::CallableKind::Function(_) | hir::CallableKind::Closure => {
+        hir::CallableKind::Function(_) | hir::CallableKind::Closure | hir::CallableKind::FnPtr => {
             render(callable.return_type())
         }
         hir::CallableKind::TupleStruct(_) | hir::CallableKind::TupleEnumVariant(_) => {}
@@ -914,8 +914,8 @@ fn main() {
 }
         "#,
             expect![[r#"
-                (S) -> i32
-                 ^
+                (s: S) -> i32
+                 ^^^^
             "#]],
         )
     }