about summary refs log tree commit diff
diff options
context:
space:
mode:
authorRyo Yoshida <low.ryoshida@gmail.com>2023-05-29 19:44:31 +0900
committerRyo Yoshida <low.ryoshida@gmail.com>2023-05-29 20:10:54 +0900
commit033e6ac57a5fb650e6f5240e7d1b8cc7841ff53b (patch)
tree84bcf484f5fd54aaa270d2f55d5891761f5c2570
parentab9347542c4f584952a5f554a18e1f92188b2fdb (diff)
downloadrust-033e6ac57a5fb650e6f5240e7d1b8cc7841ff53b.tar.gz
rust-033e6ac57a5fb650e6f5240e7d1b8cc7841ff53b.zip
Verify name references more rigidly
Previously we didn't verify that record expressions/patterns that were
found did actually point to the struct we're operating on. Moreover,
when that record expressions/patterns had missing child nodes, we would
continue traversing their ancestor nodes.
-rw-r--r--crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs172
1 files changed, 124 insertions, 48 deletions
diff --git a/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs b/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs
index ce31d1d891d..00a4e0530d2 100644
--- a/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs
+++ b/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs
@@ -1,9 +1,9 @@
 use either::Either;
-use ide_db::defs::Definition;
+use ide_db::{defs::Definition, search::FileReference};
 use itertools::Itertools;
 use syntax::{
     ast::{self, AstNode, HasGenericParams, HasVisibility},
-    match_ast, SyntaxKind, SyntaxNode,
+    match_ast, SyntaxKind,
 };
 
 use crate::{assist_context::SourceChangeBuilder, AssistContext, AssistId, AssistKind, Assists};
@@ -141,57 +141,70 @@ fn edit_struct_references(
     };
     let usages = strukt_def.usages(&ctx.sema).include_self_refs().all();
 
-    let edit_node = |edit: &mut SourceChangeBuilder, node: SyntaxNode| -> Option<()> {
-        match_ast! {
-            match node {
-                ast::RecordPat(record_struct_pat) => {
-                    let Some(fr) = ctx.sema.original_range_opt(record_struct_pat.syntax()) else {
-                        // We've found the node to replace, so we should return `Some` even if the
-                        // replacement failed to stop the ancestor node traversal.
-                        return Some(());
-                    };
-                    edit.replace(
-                        fr.range,
-                        ast::make::tuple_struct_pat(
-                            record_struct_pat.path()?,
-                            record_struct_pat
-                                .record_pat_field_list()?
-                                .fields()
-                                .filter_map(|pat| pat.pat())
-                        )
-                        .to_string()
-                    );
-                },
-                ast::RecordExpr(record_expr) => {
-                    let Some(fr) = ctx.sema.original_range_opt(record_expr.syntax()) else {
-                        // See the comment above.
-                        return Some(());
-                    };
-                    let path = record_expr.path()?;
-                    let args = record_expr
-                        .record_expr_field_list()?
-                        .fields()
-                        .filter_map(|f| f.expr())
-                        .join(", ");
-
-                    edit.replace(fr.range, format!("{path}({args})"));
-                },
-                _ => return None,
-            }
-        }
-        Some(())
-    };
-
     for (file_id, refs) in usages {
         edit.edit_file(file_id);
         for r in refs {
-            for node in r.name.syntax().ancestors() {
-                if edit_node(edit, node).is_some() {
-                    break;
-                }
-            }
+            process_struct_name_reference(ctx, r, edit);
+        }
+    }
+}
+
+fn process_struct_name_reference(
+    ctx: &AssistContext<'_>,
+    r: FileReference,
+    edit: &mut SourceChangeBuilder,
+) -> Option<()> {
+    // First check if it's the last semgnet of a path that directly belongs to a record
+    // expression/pattern.
+    let name_ref = r.name.as_name_ref()?;
+    let path_segment = name_ref.syntax().parent().and_then(ast::PathSegment::cast)?;
+    // A `PathSegment` always belongs to a `Path`, so there's at least one `Path` at this point.
+    let full_path =
+        path_segment.syntax().parent()?.ancestors().map_while(ast::Path::cast).last().unwrap();
+
+    if full_path.segment().unwrap().name_ref()? != *name_ref {
+        // `name_ref` isn't the last segment of the path, so `full_path` doesn't point to the
+        // struct we want to edit.
+        return None;
+    }
+
+    let parent = full_path.syntax().parent()?;
+    match_ast! {
+        match parent {
+            ast::RecordPat(record_struct_pat) => {
+                // When we failed to get the original range for the whole struct expression node,
+                // we can't provide any reasonable edit. Leave it untouched.
+                let file_range = ctx.sema.original_range_opt(record_struct_pat.syntax())?;
+                edit.replace(
+                    file_range.range,
+                    ast::make::tuple_struct_pat(
+                        record_struct_pat.path()?,
+                        record_struct_pat
+                            .record_pat_field_list()?
+                            .fields()
+                            .filter_map(|pat| pat.pat())
+                    )
+                    .to_string()
+                );
+            },
+            ast::RecordExpr(record_expr) => {
+                // When we failed to get the original range for the whole struct pattern node,
+                // we can't provide any reasonable edit. Leave it untouched.
+                let file_range = ctx.sema.original_range_opt(record_expr.syntax())?;
+                let path = record_expr.path()?;
+                let args = record_expr
+                    .record_expr_field_list()?
+                    .fields()
+                    .filter_map(|f| f.expr())
+                    .join(", ");
+
+                edit.replace(file_range.range, format!("{path}({args})"));
+            },
+            _ => {}
         }
     }
+
+    Some(())
 }
 
 fn edit_field_references(
@@ -901,4 +914,67 @@ fn test() {
 "#,
         );
     }
+
+    #[test]
+    fn struct_name_ref_may_not_be_part_of_struct_expr_or_struct_pat() {
+        check_assist(
+            convert_named_struct_to_tuple_struct,
+            r#"
+struct $0Struct {
+    inner: i32,
+}
+struct Outer<T> {
+    value: T,
+}
+fn foo<T>() -> T { loop {} }
+
+fn test() {
+    Outer {
+        value: foo::<Struct>();
+    }
+}
+
+trait HasAssoc {
+    type Assoc;
+    fn test();
+}
+impl HasAssoc for Struct {
+    type Assoc = Outer<i32>;
+    fn test() {
+        let a = Self::Assoc {
+            value: 42,
+        };
+        let Self::Assoc { value } = a;
+    }
+}
+"#,
+            r#"
+struct Struct(i32);
+struct Outer<T> {
+    value: T,
+}
+fn foo<T>() -> T { loop {} }
+
+fn test() {
+    Outer {
+        value: foo::<Struct>();
+    }
+}
+
+trait HasAssoc {
+    type Assoc;
+    fn test();
+}
+impl HasAssoc for Struct {
+    type Assoc = Outer<i32>;
+    fn test() {
+        let a = Self::Assoc {
+            value: 42,
+        };
+        let Self::Assoc { value } = a;
+    }
+}
+"#,
+        );
+    }
 }