about summary refs log tree commit diff
diff options
context:
space:
mode:
authorLukas Wirth <lukastw97@gmail.com>2024-03-16 17:49:59 +0100
committerLukas Wirth <lukastw97@gmail.com>2024-03-16 17:49:59 +0100
commitd69a81fddbbe3c06c135ca7ae596ffc02b24cf8e (patch)
treebaf9b9a489aa308503c8f617722b7604dbea9ae0
parent0dd89d7ee7800c53dc10cec941c4e740562d54df (diff)
downloadrust-d69a81fddbbe3c06c135ca7ae596ffc02b24cf8e.tar.gz
rust-d69a81fddbbe3c06c135ca7ae596ffc02b24cf8e.zip
fix: Fix wrong where clause rendering on hover
-rw-r--r--crates/hir/src/display.rs88
-rw-r--r--crates/ide/src/hover/tests.rs230
2 files changed, 210 insertions, 108 deletions
diff --git a/crates/hir/src/display.rs b/crates/hir/src/display.rs
index cdc0db8653c..c5d44c11f2c 100644
--- a/crates/hir/src/display.rs
+++ b/crates/hir/src/display.rs
@@ -159,6 +159,7 @@ impl HirDisplay for Adt {
 impl HirDisplay for Struct {
     fn hir_fmt(&self, f: &mut HirFormatter<'_>) -> Result<(), HirDisplayError> {
         let module_id = self.module(f.db).id;
+        // FIXME: Render repr if its set explicitly?
         write_visibility(module_id, self.visibility(f.db), f)?;
         f.write_str("struct ")?;
         write!(f, "{}", self.name(f.db).display(f.db.upcast()))?;
@@ -166,37 +167,40 @@ impl HirDisplay for Struct {
         write_generic_params(def_id, f)?;
 
         let variant_data = self.variant_data(f.db);
-        if let StructKind::Tuple = variant_data.kind() {
-            f.write_char('(')?;
-            let mut it = variant_data.fields().iter().peekable();
-
-            while let Some((id, _)) = it.next() {
-                let field = Field { parent: (*self).into(), id };
-                write_visibility(module_id, field.visibility(f.db), f)?;
-                field.ty(f.db).hir_fmt(f)?;
-                if it.peek().is_some() {
-                    f.write_str(", ")?;
-                }
-            }
-
-            f.write_str(");")?;
-        }
+        match variant_data.kind() {
+            StructKind::Tuple => {
+                f.write_char('(')?;
+                let mut it = variant_data.fields().iter().peekable();
 
-        write_where_clause(def_id, f)?;
+                while let Some((id, _)) = it.next() {
+                    let field = Field { parent: (*self).into(), id };
+                    write_visibility(module_id, field.visibility(f.db), f)?;
+                    field.ty(f.db).hir_fmt(f)?;
+                    if it.peek().is_some() {
+                        f.write_str(", ")?;
+                    }
+                }
 
-        if let StructKind::Record = variant_data.kind() {
-            let fields = self.fields(f.db);
-            if fields.is_empty() {
-                f.write_str(" {}")?;
-            } else {
-                f.write_str(" {\n")?;
-                for field in self.fields(f.db) {
-                    f.write_str("    ")?;
-                    field.hir_fmt(f)?;
-                    f.write_str(",\n")?;
+                f.write_char(')')?;
+                write_where_clause(def_id, f)?;
+            }
+            StructKind::Record => {
+                let has_where_clause = write_where_clause(def_id, f)?;
+                let fields = self.fields(f.db);
+                f.write_char(if !has_where_clause { ' ' } else { '\n' })?;
+                if fields.is_empty() {
+                    f.write_str("{}")?;
+                } else {
+                    f.write_str("{\n")?;
+                    for field in self.fields(f.db) {
+                        f.write_str("    ")?;
+                        field.hir_fmt(f)?;
+                        f.write_str(",\n")?;
+                    }
+                    f.write_str("}")?;
                 }
-                f.write_str("}")?;
             }
+            StructKind::Unit => _ = write_where_clause(def_id, f)?,
         }
 
         Ok(())
@@ -210,11 +214,12 @@ impl HirDisplay for Enum {
         write!(f, "{}", self.name(f.db).display(f.db.upcast()))?;
         let def_id = GenericDefId::AdtId(AdtId::EnumId(self.id));
         write_generic_params(def_id, f)?;
-        write_where_clause(def_id, f)?;
+        let has_where_clause = write_where_clause(def_id, f)?;
 
         let variants = self.variants(f.db);
         if !variants.is_empty() {
-            f.write_str(" {\n")?;
+            f.write_char(if !has_where_clause { ' ' } else { '\n' })?;
+            f.write_str("{\n")?;
             for variant in variants {
                 f.write_str("    ")?;
                 variant.hir_fmt(f)?;
@@ -234,11 +239,12 @@ impl HirDisplay for Union {
         write!(f, "{}", self.name(f.db).display(f.db.upcast()))?;
         let def_id = GenericDefId::AdtId(AdtId::UnionId(self.id));
         write_generic_params(def_id, f)?;
-        write_where_clause(def_id, f)?;
+        let has_where_clause = write_where_clause(def_id, f)?;
 
         let fields = self.fields(f.db);
         if !fields.is_empty() {
-            f.write_str(" {\n")?;
+            f.write_char(if !has_where_clause { ' ' } else { '\n' })?;
+            f.write_str("{\n")?;
             for field in self.fields(f.db) {
                 f.write_str("    ")?;
                 field.hir_fmt(f)?;
@@ -446,7 +452,10 @@ fn write_generic_params(
     Ok(())
 }
 
-fn write_where_clause(def: GenericDefId, f: &mut HirFormatter<'_>) -> Result<(), HirDisplayError> {
+fn write_where_clause(
+    def: GenericDefId,
+    f: &mut HirFormatter<'_>,
+) -> Result<bool, HirDisplayError> {
     let params = f.db.generic_params(def);
 
     // unnamed type targets are displayed inline with the argument itself, e.g. `f: impl Y`.
@@ -465,7 +474,7 @@ fn write_where_clause(def: GenericDefId, f: &mut HirFormatter<'_>) -> Result<(),
         });
 
     if !has_displayable_predicate {
-        return Ok(());
+        return Ok(false);
     }
 
     let write_target = |target: &WherePredicateTypeTarget, f: &mut HirFormatter<'_>| match target {
@@ -543,7 +552,7 @@ fn write_where_clause(def: GenericDefId, f: &mut HirFormatter<'_>) -> Result<(),
     // End of final predicate. There must be at least one predicate here.
     f.write_char(',')?;
 
-    Ok(())
+    Ok(true)
 }
 
 impl HirDisplay for Const {
@@ -594,19 +603,20 @@ impl HirDisplay for Trait {
         write!(f, "trait {}", data.name.display(f.db.upcast()))?;
         let def_id = GenericDefId::TraitId(self.id);
         write_generic_params(def_id, f)?;
-        write_where_clause(def_id, f)?;
+        let has_where_clause = write_where_clause(def_id, f)?;
 
         if let Some(limit) = f.entity_limit {
             let assoc_items = self.items(f.db);
             let count = assoc_items.len().min(limit);
+            f.write_char(if !has_where_clause { ' ' } else { '\n' })?;
             if count == 0 {
                 if assoc_items.is_empty() {
-                    f.write_str(" {}")?;
+                    f.write_str("{}")?;
                 } else {
-                    f.write_str(" { /* … */ }")?;
+                    f.write_str("{ /* … */ }")?;
                 }
             } else {
-                f.write_str(" {\n")?;
+                f.write_str("{\n")?;
                 for item in &assoc_items[..count] {
                     f.write_str("    ")?;
                     match item {
@@ -651,7 +661,6 @@ impl HirDisplay for TypeAlias {
         write!(f, "type {}", data.name.display(f.db.upcast()))?;
         let def_id = GenericDefId::TypeAliasId(self.id);
         write_generic_params(def_id, f)?;
-        write_where_clause(def_id, f)?;
         if !data.bounds.is_empty() {
             f.write_str(": ")?;
             f.write_joined(data.bounds.iter(), " + ")?;
@@ -660,6 +669,7 @@ impl HirDisplay for TypeAlias {
             f.write_str(" = ")?;
             ty.hir_fmt(f)?;
         }
+        write_where_clause(def_id, f)?;
         Ok(())
     }
 }
diff --git a/crates/ide/src/hover/tests.rs b/crates/ide/src/hover/tests.rs
index 051a96233a0..4451e31870f 100644
--- a/crates/ide/src/hover/tests.rs
+++ b/crates/ide/src/hover/tests.rs
@@ -819,7 +819,7 @@ fn foo(foo: Foo) {
 fn hover_tuple_struct() {
     check(
         r#"
-struct Foo$0(pub u32)
+struct Foo$0(pub u32) where u32: Copy;
 "#,
         expect![[r#"
             *Foo*
@@ -830,7 +830,99 @@ struct Foo$0(pub u32)
 
             ```rust
             // size = 4, align = 4
-            struct Foo(pub u32);
+            struct Foo(pub u32)
+            where
+                u32: Copy,
+            ```
+        "#]],
+    );
+}
+
+#[test]
+fn hover_record_struct() {
+    check(
+        r#"
+struct Foo$0 { field: u32 }
+"#,
+        expect![[r#"
+            *Foo*
+
+            ```rust
+            test
+            ```
+
+            ```rust
+            // size = 4, align = 4
+            struct Foo {
+                field: u32,
+            }
+            ```
+        "#]],
+    );
+    check(
+        r#"
+struct Foo$0 where u32: Copy { field: u32 }
+"#,
+        expect![[r#"
+            *Foo*
+
+            ```rust
+            test
+            ```
+
+            ```rust
+            // size = 4, align = 4
+            struct Foo
+            where
+                u32: Copy,
+            {
+                field: u32,
+            }
+            ```
+        "#]],
+    );
+}
+
+#[test]
+fn hover_unit_struct() {
+    check(
+        r#"
+struct Foo$0 where u32: Copy;
+"#,
+        expect![[r#"
+            *Foo*
+
+            ```rust
+            test
+            ```
+
+            ```rust
+            // size = 0, align = 1
+            struct Foo
+            where
+                u32: Copy,
+            ```
+        "#]],
+    );
+}
+
+#[test]
+fn hover_type_alias() {
+    check(
+        r#"
+type Fo$0o: Trait = S where T: Trait;
+"#,
+        expect![[r#"
+            *Foo*
+
+            ```rust
+            test
+            ```
+
+            ```rust
+            type Foo: Trait = S
+            where
+                T: Trait,
             ```
         "#]],
     );
@@ -2540,7 +2632,7 @@ fn main() { let s$0t = S{ f1:Arg(0) }; }
                                 focus_range: 7..10,
                                 name: "Arg",
                                 kind: Struct,
-                                description: "struct Arg(u32);",
+                                description: "struct Arg(u32)",
                             },
                         },
                         HoverGotoTypeData {
@@ -2599,7 +2691,7 @@ fn main() { let s$0t = S{ f1: S{ f1: Arg(0) } }; }
                                 focus_range: 7..10,
                                 name: "Arg",
                                 kind: Struct,
-                                description: "struct Arg(u32);",
+                                description: "struct Arg(u32)",
                             },
                         },
                         HoverGotoTypeData {
@@ -2648,7 +2740,7 @@ fn main() { let s$0t = (A(1), B(2), M::C(3) ); }
                                 focus_range: 7..8,
                                 name: "A",
                                 kind: Struct,
-                                description: "struct A(u32);",
+                                description: "struct A(u32)",
                             },
                         },
                         HoverGotoTypeData {
@@ -2661,7 +2753,7 @@ fn main() { let s$0t = (A(1), B(2), M::C(3) ); }
                                 focus_range: 22..23,
                                 name: "B",
                                 kind: Struct,
-                                description: "struct B(u32);",
+                                description: "struct B(u32)",
                             },
                         },
                         HoverGotoTypeData {
@@ -2675,7 +2767,7 @@ fn main() { let s$0t = (A(1), B(2), M::C(3) ); }
                                 name: "C",
                                 kind: Struct,
                                 container_name: "M",
-                                description: "pub struct C(u32);",
+                                description: "pub struct C(u32)",
                             },
                         },
                     ],
@@ -3331,26 +3423,26 @@ struct Foo<const BAR: Bar>;
 impl<const BAR: Bar> Foo<BAR$0> {}
 "#,
         expect![[r#"
-                [
-                    GoToType(
-                        [
-                            HoverGotoTypeData {
-                                mod_path: "test::Bar",
-                                nav: NavigationTarget {
-                                    file_id: FileId(
-                                        0,
-                                    ),
-                                    full_range: 0..11,
-                                    focus_range: 7..10,
-                                    name: "Bar",
-                                    kind: Struct,
-                                    description: "struct Bar",
-                                },
+            [
+                GoToType(
+                    [
+                        HoverGotoTypeData {
+                            mod_path: "test::Bar",
+                            nav: NavigationTarget {
+                                file_id: FileId(
+                                    0,
+                                ),
+                                full_range: 0..11,
+                                focus_range: 7..10,
+                                name: "Bar",
+                                kind: Struct,
+                                description: "struct Bar",
                             },
-                        ],
-                    ),
-                ]
-            "#]],
+                        },
+                    ],
+                ),
+            ]
+        "#]],
     );
 }
 
@@ -3396,26 +3488,26 @@ impl Foo {
 }
 "#,
         expect![[r#"
-                [
-                    GoToType(
-                        [
-                            HoverGotoTypeData {
-                                mod_path: "test::Foo",
-                                nav: NavigationTarget {
-                                    file_id: FileId(
-                                        0,
-                                    ),
-                                    full_range: 0..11,
-                                    focus_range: 7..10,
-                                    name: "Foo",
-                                    kind: Struct,
-                                    description: "struct Foo",
-                                },
+            [
+                GoToType(
+                    [
+                        HoverGotoTypeData {
+                            mod_path: "test::Foo",
+                            nav: NavigationTarget {
+                                file_id: FileId(
+                                    0,
+                                ),
+                                full_range: 0..11,
+                                focus_range: 7..10,
+                                name: "Foo",
+                                kind: Struct,
+                                description: "struct Foo",
                             },
-                        ],
-                    ),
-                ]
-            "#]],
+                        },
+                    ],
+                ),
+            ]
+        "#]],
     );
 }
 
@@ -3498,7 +3590,7 @@ struct S$0T<const C: usize = 1, T = Foo>(T);
             ```
 
             ```rust
-            struct ST<const C: usize = 1, T = Foo>(T);
+            struct ST<const C: usize = 1, T = Foo>(T)
             ```
         "#]],
     );
@@ -3519,7 +3611,7 @@ struct S$0T<const C: usize = {40 + 2}, T = Foo>(T);
             ```
 
             ```rust
-            struct ST<const C: usize = {const}, T = Foo>(T);
+            struct ST<const C: usize = {const}, T = Foo>(T)
             ```
         "#]],
     );
@@ -3541,7 +3633,7 @@ struct S$0T<const C: usize = VAL, T = Foo>(T);
             ```
 
             ```rust
-            struct ST<const C: usize = VAL, T = Foo>(T);
+            struct ST<const C: usize = VAL, T = Foo>(T)
             ```
         "#]],
     );
@@ -5931,26 +6023,26 @@ fn foo() {
 }
 "#,
         expect![[r#"
-                [
-                    GoToType(
-                        [
-                            HoverGotoTypeData {
-                                mod_path: "test::Foo",
-                                nav: NavigationTarget {
-                                    file_id: FileId(
-                                        0,
-                                    ),
-                                    full_range: 0..11,
-                                    focus_range: 7..10,
-                                    name: "Foo",
-                                    kind: Struct,
-                                    description: "struct Foo",
-                                },
+            [
+                GoToType(
+                    [
+                        HoverGotoTypeData {
+                            mod_path: "test::Foo",
+                            nav: NavigationTarget {
+                                file_id: FileId(
+                                    0,
+                                ),
+                                full_range: 0..11,
+                                focus_range: 7..10,
+                                name: "Foo",
+                                kind: Struct,
+                                description: "struct Foo",
                             },
-                        ],
-                    ),
-                ]
-            "#]],
+                        },
+                    ],
+                ),
+            ]
+        "#]],
     );
 }
 
@@ -6166,7 +6258,7 @@ pub struct Foo(i32);
 
             ```rust
             // size = 4, align = 4
-            pub struct Foo(i32);
+            pub struct Foo(i32)
             ```
 
             ---
@@ -6191,7 +6283,7 @@ pub struct Foo<T>(T);
             ```
 
             ```rust
-            pub struct Foo<T>(T);
+            pub struct Foo<T>(T)
             ```
 
             ---