about summary refs log tree commit diff
diff options
context:
space:
mode:
authorDropDemBits <r3usrlnd@gmail.com>2023-11-13 20:30:58 -0500
committerDropDemBits <r3usrlnd@gmail.com>2023-11-13 20:41:06 -0500
commit787ca888e35eaa351c9755f4f4de75acca8bf836 (patch)
tree1226c6ac7ebfb15c2717a87a670b45ee6edef658
parent6f68cd33947c3f361e47c24904605e49e5637eba (diff)
downloadrust-787ca888e35eaa351c9755f4f4de75acca8bf836.tar.gz
rust-787ca888e35eaa351c9755f4f4de75acca8bf836.zip
Add `IdentPat::set_pat`
Needed so that the `tuple_pat` node gets added to the syntax tree,
which is required as we're using structured snippets.
-rw-r--r--crates/ide-assists/src/handlers/destructure_tuple_binding.rs10
-rw-r--r--crates/syntax/src/ast/edit_in_place.rs75
2 files changed, 76 insertions, 9 deletions
diff --git a/crates/ide-assists/src/handlers/destructure_tuple_binding.rs b/crates/ide-assists/src/handlers/destructure_tuple_binding.rs
index 2dc30e685a1..65b497e83aa 100644
--- a/crates/ide-assists/src/handlers/destructure_tuple_binding.rs
+++ b/crates/ide-assists/src/handlers/destructure_tuple_binding.rs
@@ -197,15 +197,7 @@ impl AssignmentEdit {
     fn apply(self) {
         // with sub_pattern: keep original tuple and add subpattern: `tup @ (_0, _1)`
         if self.in_sub_pattern {
-            ted::insert_all_raw(
-                ted::Position::after(self.ident_pat.syntax()),
-                vec![
-                    make::tokens::single_space().into(),
-                    make::token(T![@]).into(),
-                    make::tokens::single_space().into(),
-                    self.tuple_pat.syntax().clone().into(),
-                ],
-            )
+            self.ident_pat.set_pat(Some(self.tuple_pat.into()))
         } else {
             ted::replace(self.ident_pat.syntax(), self.tuple_pat.syntax())
         }
diff --git a/crates/syntax/src/ast/edit_in_place.rs b/crates/syntax/src/ast/edit_in_place.rs
index b9059a527d0..edb55a2f136 100644
--- a/crates/syntax/src/ast/edit_in_place.rs
+++ b/crates/syntax/src/ast/edit_in_place.rs
@@ -846,6 +846,53 @@ fn normalize_ws_between_braces(node: &SyntaxNode) -> Option<()> {
     Some(())
 }
 
+impl ast::IdentPat {
+    pub fn set_pat(&self, pat: Option<ast::Pat>) {
+        match pat {
+            None => {
+                if let Some(at_token) = self.at_token() {
+                    // Remove `@ Pat`
+                    let start = at_token.clone().into();
+                    let end = self
+                        .pat()
+                        .map(|it| it.syntax().clone().into())
+                        .unwrap_or_else(|| at_token.into());
+
+                    ted::remove_all(start..=end);
+
+                    // Remove any trailing ws
+                    if let Some(last) =
+                        self.syntax().last_token().filter(|it| it.kind() == WHITESPACE)
+                    {
+                        last.detach();
+                    }
+                }
+            }
+            Some(pat) => {
+                if let Some(old_pat) = self.pat() {
+                    // Replace existing pattern
+                    ted::replace(old_pat.syntax(), pat.syntax())
+                } else if let Some(at_token) = self.at_token() {
+                    // Have an `@` token but not a pattern yet
+                    ted::insert(ted::Position::after(at_token), pat.syntax());
+                } else {
+                    // Don't have an `@`, should have a name
+                    let name = self.name().unwrap();
+
+                    ted::insert_all(
+                        ted::Position::after(name.syntax()),
+                        vec![
+                            make::token(T![@]).into(),
+                            make::tokens::single_space().into(),
+                            pat.syntax().clone().into(),
+                        ],
+                    )
+                }
+            }
+        }
+    }
+}
+
 pub trait HasVisibilityEdit: ast::HasVisibility {
     fn set_visibility(&self, visbility: ast::Visibility) {
         match self.visibility() {
@@ -948,6 +995,34 @@ mod tests {
     }
 
     #[test]
+    fn test_ident_pat_set_pat() {
+        #[track_caller]
+        fn check(before: &str, expected: &str, pat: Option<ast::Pat>) {
+            let pat = pat.map(|it| it.clone_for_update());
+
+            let ident_pat = ast_mut_from_text::<ast::IdentPat>(&format!("fn f() {{ {before} }}"));
+            ident_pat.set_pat(pat);
+
+            let after = ast_mut_from_text::<ast::IdentPat>(&format!("fn f() {{ {expected} }}"));
+            assert_eq!(ident_pat.to_string(), after.to_string());
+        }
+
+        // replacing
+        check("let a @ _;", "let a @ ();", Some(make::tuple_pat([]).into()));
+
+        // note: no trailing semicolon is added for the below tests since it
+        // seems to be picked up by the ident pat during error recovery?
+
+        // adding
+        check("let a ", "let a @ ()", Some(make::tuple_pat([]).into()));
+        check("let a @ ", "let a @ ()", Some(make::tuple_pat([]).into()));
+
+        // removing
+        check("let a @ ()", "let a", None);
+        check("let a @ ", "let a", None);
+    }
+
+    #[test]
     fn add_variant_to_empty_enum() {
         let variant = make::variant(make::name("Bar"), None).clone_for_update();