about summary refs log tree commit diff
diff options
context:
space:
mode:
authorDropDemBits <r3usrlnd@gmail.com>2023-11-14 17:35:24 -0500
committerDropDemBits <r3usrlnd@gmail.com>2023-11-14 17:35:24 -0500
commitdf629627c5bd63ea8f44eca1751f2026baf65e5b (patch)
tree27a9f3fbe3f7f33d6ab6d8e4a53b5a2b7efed155
parent787ca888e35eaa351c9755f4f4de75acca8bf836 (diff)
downloadrust-df629627c5bd63ea8f44eca1751f2026baf65e5b.tar.gz
rust-df629627c5bd63ea8f44eca1751f2026baf65e5b.zip
Add tests for `LetStmt::set_ty`
-rw-r--r--crates/syntax/src/ast/edit_in_place.rs62
1 files changed, 48 insertions, 14 deletions
diff --git a/crates/syntax/src/ast/edit_in_place.rs b/crates/syntax/src/ast/edit_in_place.rs
index edb55a2f136..37d8212042d 100644
--- a/crates/syntax/src/ast/edit_in_place.rs
+++ b/crates/syntax/src/ast/edit_in_place.rs
@@ -663,25 +663,28 @@ impl ast::LetStmt {
 
                     ted::remove(existing_ty.syntax());
                 }
+
+                // Remove any trailing ws
+                if let Some(last) = self.syntax().last_token().filter(|it| it.kind() == WHITESPACE)
+                {
+                    last.detach();
+                }
             }
             Some(new_ty) => {
                 if self.colon_token().is_none() {
-                    let mut to_insert: Vec<SyntaxElement> = vec![];
-
-                    let position = match self.pat() {
-                        Some(pat) => Position::after(pat.syntax()),
-                        None => {
-                            to_insert.push(make::tokens::single_space().into());
-                            Position::after(self.let_token().unwrap())
-                        }
-                    };
-
-                    to_insert.push(make::token(T![:]).into());
-
-                    ted::insert_all_raw(position, to_insert);
+                    ted::insert_raw(
+                        Position::after(
+                            self.pat().expect("let stmt should have a pattern").syntax(),
+                        ),
+                        make::token(T![:]),
+                    );
                 }
 
-                ted::insert(Position::after(self.colon_token().unwrap()), new_ty.syntax());
+                if let Some(old_ty) = self.ty() {
+                    ted::replace(old_ty.syntax(), new_ty.syntax());
+                } else {
+                    ted::insert(Position::after(self.colon_token().unwrap()), new_ty.syntax());
+                }
             }
         }
     }
@@ -1023,6 +1026,37 @@ mod tests {
     }
 
     #[test]
+    fn test_let_stmt_set_ty() {
+        #[track_caller]
+        fn check(before: &str, expected: &str, ty: Option<ast::Type>) {
+            let ty = ty.map(|it| it.clone_for_update());
+
+            let let_stmt = ast_mut_from_text::<ast::LetStmt>(&format!("fn f() {{ {before} }}"));
+            let_stmt.set_ty(ty);
+
+            let after = ast_mut_from_text::<ast::LetStmt>(&format!("fn f() {{ {expected} }}"));
+            assert_eq!(let_stmt.to_string(), after.to_string(), "{let_stmt:#?}\n!=\n{after:#?}");
+        }
+
+        // adding
+        check("let a;", "let a: ();", Some(make::ty_tuple([])));
+        // no semicolon due to it being eaten during error recovery
+        check("let a:", "let a: ()", Some(make::ty_tuple([])));
+
+        // replacing
+        check("let a: u8;", "let a: ();", Some(make::ty_tuple([])));
+        check("let a: u8 = 3;", "let a: () = 3;", Some(make::ty_tuple([])));
+        check("let a: = 3;", "let a: () = 3;", Some(make::ty_tuple([])));
+
+        // removing
+        check("let a: u8;", "let a;", None);
+        check("let a:;", "let a;", None);
+
+        check("let a: u8 = 3;", "let a = 3;", None);
+        check("let a: = 3;", "let a = 3;", None);
+    }
+
+    #[test]
     fn add_variant_to_empty_enum() {
         let variant = make::variant(make::name("Bar"), None).clone_for_update();