 crates/hir-def/src/macro_expansion_tests/mbe/matching.rs | 10
 crates/hir-def/src/macro_expansion_tests/proc_macros.rs  |  6
 crates/hir-expand/src/fixup.rs                           | 96
 crates/mbe/src/syntax_bridge.rs                          | 89
 crates/mbe/src/syntax_bridge/tests.rs (new)              | 93
 5 files changed, 229 insertions(+), 65 deletions(-)
diff --git a/crates/hir-def/src/macro_expansion_tests/mbe/matching.rs b/crates/hir-def/src/macro_expansion_tests/mbe/matching.rs
index bc162d0fa20..fc90c6e9f37 100644
--- a/crates/hir-def/src/macro_expansion_tests/mbe/matching.rs
+++ b/crates/hir-def/src/macro_expansion_tests/mbe/matching.rs
@@ -94,11 +94,11 @@ macro_rules! m {
     ($($s:stmt)*) => (stringify!($($s |)*);)
 }
 stringify!(;
-|;
-|92|;
-|let x = 92|;
+| ;
+|92| ;
+|let x = 92| ;
 |loop {}
-|;
+| ;
 |);
 "#]],
     );
@@ -118,7 +118,7 @@ m!(.. .. ..);
 macro_rules! m {
     ($($p:pat)*) => (stringify!($($p |)*);)
 }
-stringify!(.. .. ..|);
+stringify!(.. .. .. |);
 "#]],
     );
 }
diff --git a/crates/hir-def/src/macro_expansion_tests/proc_macros.rs b/crates/hir-def/src/macro_expansion_tests/proc_macros.rs
index 029821e5e87..118c14ed843 100644
--- a/crates/hir-def/src/macro_expansion_tests/proc_macros.rs
+++ b/crates/hir-def/src/macro_expansion_tests/proc_macros.rs
@@ -82,14 +82,14 @@ fn attribute_macro_syntax_completion_2() {
 #[proc_macros::identity_when_valid]
 fn foo() { bar.; blub }
 "#,
-        expect![[r##"
+        expect![[r#"
 #[proc_macros::identity_when_valid]
 fn foo() { bar.; blub }
 
 fn foo() {
-    bar.;
+    bar. ;
     blub
-}"##]],
+}"#]],
     );
 }
 
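Both test-file updates above are expectation churn from the spacing change in `crates/mbe/src/syntax_bridge.rs` further down: a punct is now `Joint` only when the token after it could fuse with it into a longer operator. The `|` in the transcriber `($($s |)*)` is followed by `)`, and the fixed-up `.` is followed by an identifier, so both become `Alone` and print with a trailing space. (The `r##…##` delimiters also shrink to `r#…#`, since the expected text contains no `"#`.) As a hedged illustration of what `Joint` vs `Alone` means when a token stream is printed — using the `proc_macro2` crate, not rust-analyzer's internal `tt` types:

```rust
// Sketch: `Joint` glues a punct to its successor; `Alone` keeps a separator.
use proc_macro2::{Punct, Spacing, TokenStream, TokenTree};

fn punct_pair(first_spacing: Spacing) -> TokenStream {
    [
        TokenTree::Punct(Punct::new('<', first_spacing)),
        TokenTree::Punct(Punct::new('=', Spacing::Alone)),
    ]
    .into_iter()
    .collect()
}

fn main() {
    // A `Joint` punct fuses with its successor, forming the `<=` operator.
    assert_eq!(punct_pair(Spacing::Joint).to_string(), "<=");
    // An `Alone` punct keeps a separator, so no `<=` token is formed.
    assert_ne!(punct_pair(Spacing::Alone).to_string(), "<=");
}
```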
diff --git a/crates/hir-expand/src/fixup.rs b/crates/hir-expand/src/fixup.rs
index 893e6fe4b82..a4abe75626e 100644
--- a/crates/hir-expand/src/fixup.rs
+++ b/crates/hir-expand/src/fixup.rs
@@ -4,6 +4,7 @@ use std::mem;
 
 use mbe::{SyntheticToken, SyntheticTokenId, TokenMap};
 use rustc_hash::FxHashMap;
+use smallvec::SmallVec;
 use syntax::{
     ast::{self, AstNode, HasLoopBody},
     match_ast, SyntaxElement, SyntaxKind, SyntaxNode, TextRange,
@@ -292,25 +293,34 @@ pub(crate) fn reverse_fixups(
     token_map: &TokenMap,
     undo_info: &SyntaxFixupUndoInfo,
 ) {
-    tt.token_trees.retain(|tt| match tt {
-        tt::TokenTree::Leaf(leaf) => {
-            token_map.synthetic_token_id(leaf.id()).is_none()
-                || token_map.synthetic_token_id(leaf.id()) != Some(EMPTY_ID)
-        }
-        tt::TokenTree::Subtree(st) => st.delimiter.map_or(true, |d| {
-            token_map.synthetic_token_id(d.id).is_none()
-                || token_map.synthetic_token_id(d.id) != Some(EMPTY_ID)
-        }),
-    });
-    tt.token_trees.iter_mut().for_each(|tt| match tt {
-        tt::TokenTree::Subtree(tt) => reverse_fixups(tt, token_map, undo_info),
-        tt::TokenTree::Leaf(leaf) => {
-            if let Some(id) = token_map.synthetic_token_id(leaf.id()) {
-                let original = &undo_info.original[id.0 as usize];
-                *tt = tt::TokenTree::Subtree(original.clone());
+    let tts = std::mem::take(&mut tt.token_trees);
+    tt.token_trees = tts
+        .into_iter()
+        .filter(|tt| match tt {
+            tt::TokenTree::Leaf(leaf) => token_map.synthetic_token_id(leaf.id()) != Some(EMPTY_ID),
+            tt::TokenTree::Subtree(st) => {
+                st.delimiter.map_or(true, |d| token_map.synthetic_token_id(d.id) != Some(EMPTY_ID))
             }
-        }
-    });
+        })
+        .flat_map(|tt| match tt {
+            tt::TokenTree::Subtree(mut tt) => {
+                reverse_fixups(&mut tt, token_map, undo_info);
+                SmallVec::from_const([tt.into()])
+            }
+            tt::TokenTree::Leaf(leaf) => {
+                if let Some(id) = token_map.synthetic_token_id(leaf.id()) {
+                    let original = undo_info.original[id.0 as usize].clone();
+                    if original.delimiter.is_none() {
+                        original.token_trees.into()
+                    } else {
+                        SmallVec::from_const([original.into()])
+                    }
+                } else {
+                    SmallVec::from_const([leaf.into()])
+                }
+            }
+        })
+        .collect();
 }
 
 #[cfg(test)]
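The `reverse_fixups` rewrite above does two things: it drops synthetic tokens whose id is `EMPTY_ID` (the old double condition `is_none() || != Some(EMPTY_ID)` collapses to a single `!= Some(EMPTY_ID)`, which already lets the `None` case through), and it splices a recovered delimiter-less subtree back in place of the single leaf that replaced it, instead of nesting it one level too deep. A simplified standalone model of that filter + flat_map shape (toy `Tt` type, not rust-analyzer's `tt::TokenTree`; the real function also recurses into subtrees and looks leaves up in `undo_info`):

```rust
// `SmallVec` keeps the common single-element case allocation-free.
use smallvec::SmallVec;

#[derive(Clone, Debug, PartialEq)]
enum Tt {
    Leaf(u32),        // a token carrying an id
    Subtree(Vec<Tt>), // a delimiter-less subtree
}

fn splice(tts: Vec<Tt>, is_empty: impl Fn(u32) -> bool + Copy) -> Vec<Tt> {
    tts.into_iter()
        // Drop synthetic "empty" tokens entirely.
        .filter(|tt| match tt {
            Tt::Leaf(id) => !is_empty(*id),
            Tt::Subtree(_) => true,
        })
        // Splice a delimiter-less subtree in as its children, flat.
        .flat_map(|tt| -> SmallVec<[Tt; 1]> {
            match tt {
                Tt::Subtree(children) => children.into(),
                leaf => SmallVec::from_iter([leaf]),
            }
        })
        .collect()
}

fn main() {
    let input = vec![Tt::Leaf(0), Tt::Subtree(vec![Tt::Leaf(1), Tt::Leaf(2)]), Tt::Leaf(9)];
    // Treat id 9 as a synthetic empty token that must be dropped.
    assert_eq!(splice(input, |id| id == 9), vec![Tt::Leaf(0), Tt::Leaf(1), Tt::Leaf(2)]);
}
```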
@@ -319,6 +329,31 @@ mod tests {
 
     use super::reverse_fixups;
 
+    // The following three functions are only meant to check partial structural equivalence of
+    // `TokenTree`s, see the last assertion in `check()`.
+    fn check_leaf_eq(a: &tt::Leaf, b: &tt::Leaf) -> bool {
+        match (a, b) {
+            (tt::Leaf::Literal(a), tt::Leaf::Literal(b)) => a.text == b.text,
+            (tt::Leaf::Punct(a), tt::Leaf::Punct(b)) => a.char == b.char,
+            (tt::Leaf::Ident(a), tt::Leaf::Ident(b)) => a.text == b.text,
+            _ => false,
+        }
+    }
+
+    fn check_subtree_eq(a: &tt::Subtree, b: &tt::Subtree) -> bool {
+        a.delimiter.map(|it| it.kind) == b.delimiter.map(|it| it.kind)
+            && a.token_trees.len() == b.token_trees.len()
+            && a.token_trees.iter().zip(&b.token_trees).all(|(a, b)| check_tt_eq(a, b))
+    }
+
+    fn check_tt_eq(a: &tt::TokenTree, b: &tt::TokenTree) -> bool {
+        match (a, b) {
+            (tt::TokenTree::Leaf(a), tt::TokenTree::Leaf(b)) => check_leaf_eq(a, b),
+            (tt::TokenTree::Subtree(a), tt::TokenTree::Subtree(b)) => check_subtree_eq(a, b),
+            _ => false,
+        }
+    }
+
     #[track_caller]
     fn check(ra_fixture: &str, mut expect: Expect) {
         let parsed = syntax::SourceFile::parse(ra_fixture);
@@ -331,17 +366,15 @@ mod tests {
             fixups.append,
         );
 
-        let mut actual = tt.to_string();
-        actual.push('\n');
+        let actual = format!("{}\n", tt);
 
         expect.indent(false);
         expect.assert_eq(&actual);
 
         // the fixed-up tree should be syntactically valid
         let (parse, _) = mbe::token_tree_to_syntax_node(&tt, ::mbe::TopEntryPoint::MacroItems);
-        assert_eq!(
-            parse.errors(),
-            &[],
+        assert!(
+            parse.errors().is_empty(),
             "parse has syntax errors. parse tree:\n{:#?}",
             parse.syntax_node()
         );
@@ -349,9 +382,12 @@ mod tests {
         reverse_fixups(&mut tt, &tmap, &fixups.undo_info);
 
         // the fixed-up + reversed version should be equivalent to the original input
-        // (but token IDs don't matter)
+        // modulo token IDs and `Punct`s' spacing.
         let (original_as_tt, _) = mbe::syntax_node_to_token_tree(&parsed.syntax_node());
-        assert_eq!(tt.to_string(), original_as_tt.to_string());
+        assert!(
+            check_subtree_eq(&tt, &original_as_tt),
+            "different token tree: {tt:?}, {original_as_tt:?}"
+        );
     }
 
     #[test]
@@ -468,7 +504,7 @@ fn foo() {
 }
 "#,
             expect![[r#"
-fn foo () {a .__ra_fixup}
+fn foo () {a . __ra_fixup}
 "#]],
         )
     }
@@ -482,7 +518,7 @@ fn foo() {
 }
 "#,
             expect![[r#"
-fn foo () {a .__ra_fixup ;}
+fn foo () {a . __ra_fixup ;}
 "#]],
         )
     }
@@ -497,7 +533,7 @@ fn foo() {
 }
 "#,
             expect![[r#"
-fn foo () {a .__ra_fixup ; bar () ;}
+fn foo () {a . __ra_fixup ; bar () ;}
 "#]],
         )
     }
@@ -525,7 +561,7 @@ fn foo() {
 }
 "#,
             expect![[r#"
-fn foo () {let x = a .__ra_fixup ;}
+fn foo () {let x = a . __ra_fixup ;}
 "#]],
         )
     }
@@ -541,7 +577,7 @@ fn foo() {
 }
 "#,
             expect![[r#"
-fn foo () {a .b ; bar () ;}
+fn foo () {a . b ; bar () ;}
 "#]],
         )
     }
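The assertions now compare trees with `check_subtree_eq` rather than string equality, because after a fixup round trip the token ids and `Punct` spacing may legitimately differ from the original input. A toy model of that "partial structural equivalence" (hypothetical `Leaf` type, for illustration only):

```rust
// Token ids and punct spacing are ignored; only the visible text is compared.
#[derive(Debug)]
enum Leaf {
    Ident { text: String, id: u32 },
    Punct { ch: char, joint: bool, id: u32 },
}

fn leaf_eq(a: &Leaf, b: &Leaf) -> bool {
    match (a, b) {
        (Leaf::Ident { text: a, .. }, Leaf::Ident { text: b, .. }) => a == b,
        (Leaf::Punct { ch: a, .. }, Leaf::Punct { ch: b, .. }) => a == b,
        _ => false,
    }
}

fn main() {
    let a = Leaf::Punct { ch: '.', joint: true, id: 0 };
    let b = Leaf::Punct { ch: '.', joint: false, id: 42 };
    // Same char, so equivalent even though spacing and ids differ.
    assert!(leaf_eq(&a, &b));
    println!("{a:?} ~ {b:?}");
}
```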
diff --git a/crates/mbe/src/syntax_bridge.rs b/crates/mbe/src/syntax_bridge.rs
index e4c56565b92..cf53c16726b 100644
--- a/crates/mbe/src/syntax_bridge.rs
+++ b/crates/mbe/src/syntax_bridge.rs
@@ -12,6 +12,9 @@ use tt::buffer::{Cursor, TokenBuffer};
 
 use crate::{to_parser_input::to_parser_input, tt_iter::TtIter, TokenMap};
 
+#[cfg(test)]
+mod tests;
+
 /// Convert the syntax node to a `TokenTree` (what macro
 /// will consume).
 pub fn syntax_node_to_token_tree(node: &SyntaxNode) -> (tt::Subtree, TokenMap) {
@@ -35,7 +38,7 @@ pub fn syntax_node_to_token_tree_with_modifications(
     append: FxHashMap<SyntaxElement, Vec<SyntheticToken>>,
 ) -> (tt::Subtree, TokenMap, u32) {
     let global_offset = node.text_range().start();
-    let mut c = Convertor::new(node, global_offset, existing_token_map, next_id, replace, append);
+    let mut c = Converter::new(node, global_offset, existing_token_map, next_id, replace, append);
     let subtree = convert_tokens(&mut c);
     c.id_alloc.map.shrink_to_fit();
     always!(c.replace.is_empty(), "replace: {:?}", c.replace);
@@ -100,7 +103,7 @@ pub fn parse_to_token_tree(text: &str) -> Option<(tt::Subtree, TokenMap)> {
         return None;
     }
 
-    let mut conv = RawConvertor {
+    let mut conv = RawConverter {
         lexed,
         pos: 0,
         id_alloc: TokenIdAlloc {
@@ -148,7 +151,7 @@ pub fn parse_exprs_with_sep(tt: &tt::Subtree, sep: char) -> Vec<tt::Subtree> {
     res
 }
 
-fn convert_tokens<C: TokenConvertor>(conv: &mut C) -> tt::Subtree {
+fn convert_tokens<C: TokenConverter>(conv: &mut C) -> tt::Subtree {
     struct StackEntry {
         subtree: tt::Subtree,
         idx: usize,
@@ -228,7 +231,7 @@ fn convert_tokens<C: TokenConvertor>(conv: &mut C) -> tt::Subtree {
             }
 
             let spacing = match conv.peek().map(|next| next.kind(conv)) {
-                Some(kind) if !kind.is_trivia() => tt::Spacing::Joint,
+                Some(kind) if is_single_token_op(kind) => tt::Spacing::Joint,
                 _ => tt::Spacing::Alone,
             };
             let char = match token.to_char(conv) {
@@ -307,6 +310,35 @@ fn convert_tokens<C: TokenConvertor>(conv: &mut C) -> tt::Subtree {
     }
 }
 
+fn is_single_token_op(kind: SyntaxKind) -> bool {
+    matches!(
+        kind,
+        EQ | L_ANGLE
+            | R_ANGLE
+            | BANG
+            | AMP
+            | PIPE
+            | TILDE
+            | AT
+            | DOT
+            | COMMA
+            | SEMICOLON
+            | COLON
+            | POUND
+            | DOLLAR
+            | QUESTION
+            | PLUS
+            | MINUS
+            | STAR
+            | SLASH
+            | PERCENT
+            | CARET
+            // LIFETIME_IDENT will be split into a sequence of `'` (a single quote) and an
+            // identifier.
+            | LIFETIME_IDENT
+    )
+}
+
 /// Returns the textual content of a doc comment block as a quoted string
 /// That is, strips leading `///` (or `/**`, etc)
 /// and strips the ending `*/`
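This is the heart of the change: previously any non-trivia next token made a punct `Joint`; now a punct is `Joint` only when the next token is itself one of the single-character operator tokens above (or a lifetime, which later splits into `'` plus an identifier) — that is, only when the pair could still fuse into a multi-character operator. A standalone sketch of the rule (hand-rolled `Spacing` enum operating on token strings; the real code works on `SyntaxKind`):

```rust
#[derive(Debug, PartialEq)]
enum Spacing { Alone, Joint }

fn spacing_for(next_token: Option<&str>) -> Spacing {
    // Stand-in for `is_single_token_op`: the one-character operator tokens
    // (plus lifetime idents) listed in the diff above.
    let is_single_token_op = |tok: &str| {
        matches!(tok, "=" | "<" | ">" | "!" | "&" | "|" | "~" | "@" | "." | ","
            | ";" | ":" | "#" | "$" | "?" | "+" | "-" | "*" | "/" | "%" | "^")
            || tok.starts_with('\'')
    };
    match next_token {
        Some(tok) if is_single_token_op(tok) => Spacing::Joint,
        _ => Spacing::Alone,
    }
}

fn main() {
    // `<` then `=`: may fuse into `<=`, so the `<` is emitted Joint.
    assert_eq!(spacing_for(Some("=")), Spacing::Joint);
    // `<` then `(`: `(` is not a single-token op, so the `<` is Alone.
    assert_eq!(spacing_for(Some("(")), Spacing::Alone);
    // `<` then `'a`: lifetimes split into `'` + ident, so stay Joint.
    assert_eq!(spacing_for(Some("'a")), Spacing::Joint);
    assert_eq!(spacing_for(None), Spacing::Alone);
}
```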
@@ -425,8 +457,8 @@ impl TokenIdAlloc {
     }
 }
 
-/// A raw token (straight from lexer) convertor
-struct RawConvertor<'a> {
+/// A raw token (straight from lexer) converter
+struct RawConverter<'a> {
     lexed: parser::LexedStr<'a>,
     pos: usize,
     id_alloc: TokenIdAlloc,
@@ -442,7 +474,7 @@ trait SrcToken<Ctx>: std::fmt::Debug {
     fn synthetic_id(&self, ctx: &Ctx) -> Option<SyntheticTokenId>;
 }
 
-trait TokenConvertor: Sized {
+trait TokenConverter: Sized {
     type Token: SrcToken<Self>;
 
     fn convert_doc_comment(&self, token: &Self::Token) -> Option<Vec<tt::TokenTree>>;
@@ -454,25 +486,25 @@ trait TokenConvertor: Sized {
     fn id_alloc(&mut self) -> &mut TokenIdAlloc;
 }
 
-impl<'a> SrcToken<RawConvertor<'a>> for usize {
-    fn kind(&self, ctx: &RawConvertor<'a>) -> SyntaxKind {
+impl<'a> SrcToken<RawConverter<'a>> for usize {
+    fn kind(&self, ctx: &RawConverter<'a>) -> SyntaxKind {
         ctx.lexed.kind(*self)
     }
 
-    fn to_char(&self, ctx: &RawConvertor<'a>) -> Option<char> {
+    fn to_char(&self, ctx: &RawConverter<'a>) -> Option<char> {
         ctx.lexed.text(*self).chars().next()
     }
 
-    fn to_text(&self, ctx: &RawConvertor<'_>) -> SmolStr {
+    fn to_text(&self, ctx: &RawConverter<'_>) -> SmolStr {
         ctx.lexed.text(*self).into()
     }
 
-    fn synthetic_id(&self, _ctx: &RawConvertor<'a>) -> Option<SyntheticTokenId> {
+    fn synthetic_id(&self, _ctx: &RawConverter<'a>) -> Option<SyntheticTokenId> {
         None
     }
 }
 
-impl<'a> TokenConvertor for RawConvertor<'a> {
+impl<'a> TokenConverter for RawConverter<'a> {
     type Token = usize;
 
     fn convert_doc_comment(&self, &token: &usize) -> Option<Vec<tt::TokenTree>> {
@@ -504,7 +536,7 @@ impl<'a> TokenConvertor for RawConvertor<'a> {
     }
 }
 
-struct Convertor {
+struct Converter {
     id_alloc: TokenIdAlloc,
     current: Option<SyntaxToken>,
     current_synthetic: Vec<SyntheticToken>,
@@ -515,7 +547,7 @@ struct Convertor {
     punct_offset: Option<(SyntaxToken, TextSize)>,
 }
 
-impl Convertor {
+impl Converter {
     fn new(
         node: &SyntaxNode,
         global_offset: TextSize,
@@ -523,11 +555,11 @@ impl Convertor {
         next_id: u32,
         mut replace: FxHashMap<SyntaxElement, Vec<SyntheticToken>>,
         mut append: FxHashMap<SyntaxElement, Vec<SyntheticToken>>,
-    ) -> Convertor {
+    ) -> Converter {
         let range = node.text_range();
         let mut preorder = node.preorder_with_tokens();
         let (first, synthetic) = Self::next_token(&mut preorder, &mut replace, &mut append);
-        Convertor {
+        Converter {
             id_alloc: { TokenIdAlloc { map: existing_token_map, global_offset, next_id } },
             current: first,
             current_synthetic: synthetic,
@@ -590,15 +622,15 @@ impl SynToken {
     }
 }
 
-impl SrcToken<Convertor> for SynToken {
-    fn kind(&self, _ctx: &Convertor) -> SyntaxKind {
+impl SrcToken<Converter> for SynToken {
+    fn kind(&self, ctx: &Converter) -> SyntaxKind {
         match self {
             SynToken::Ordinary(token) => token.kind(),
-            SynToken::Punch(token, _) => token.kind(),
+            SynToken::Punch(..) => SyntaxKind::from_char(self.to_char(ctx).unwrap()).unwrap(),
             SynToken::Synthetic(token) => token.kind,
         }
     }
-    fn to_char(&self, _ctx: &Convertor) -> Option<char> {
+    fn to_char(&self, _ctx: &Converter) -> Option<char> {
         match self {
             SynToken::Ordinary(_) => None,
             SynToken::Punch(it, i) => it.text().chars().nth((*i).into()),
@@ -606,7 +638,7 @@ impl SrcToken<Convertor> for SynToken {
             SynToken::Synthetic(_) => None,
         }
     }
-    fn to_text(&self, _ctx: &Convertor) -> SmolStr {
+    fn to_text(&self, _ctx: &Converter) -> SmolStr {
         match self {
             SynToken::Ordinary(token) => token.text().into(),
             SynToken::Punch(token, _) => token.text().into(),
@@ -614,7 +646,7 @@ impl SrcToken<Convertor> for SynToken {
         }
     }
 
-    fn synthetic_id(&self, _ctx: &Convertor) -> Option<SyntheticTokenId> {
+    fn synthetic_id(&self, _ctx: &Converter) -> Option<SyntheticTokenId> {
         match self {
             SynToken::Synthetic(token) => Some(token.id),
             _ => None,
@@ -622,7 +654,7 @@ impl SrcToken<Convertor> for SynToken {
     }
 }
 
-impl TokenConvertor for Convertor {
+impl TokenConverter for Converter {
     type Token = SynToken;
     fn convert_doc_comment(&self, token: &Self::Token) -> Option<Vec<tt::TokenTree>> {
         convert_doc_comment(token.token()?)
@@ -651,7 +683,7 @@ impl TokenConvertor for Convertor {
         }
 
         let curr = self.current.clone()?;
-        if !&self.range.contains_range(curr.text_range()) {
+        if !self.range.contains_range(curr.text_range()) {
             return None;
         }
         let (new_current, new_synth) =
@@ -809,12 +841,15 @@ impl<'a> TtTreeSink<'a> {
         let next = last.bump();
         if let (
             Some(tt::buffer::TokenTreeRef::Leaf(tt::Leaf::Punct(curr), _)),
-            Some(tt::buffer::TokenTreeRef::Leaf(tt::Leaf::Punct(_), _)),
+            Some(tt::buffer::TokenTreeRef::Leaf(tt::Leaf::Punct(next), _)),
         ) = (last.token_tree(), next.token_tree())
         {
             // Note: We always assume the semi-colon would be the last token in
             // other parts of RA such that we don't add whitespace here.
-            if curr.spacing == tt::Spacing::Alone && curr.char != ';' {
+            //
+            // When `next` is a `Punct` of `'`, that's a part of a lifetime identifier so we don't
+            // need to add whitespace either.
+            if curr.spacing == tt::Spacing::Alone && curr.char != ';' && next.char != '\'' {
                 self.inner.token(WHITESPACE, " ");
                 self.text_pos += TextSize::of(' ');
             }
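The extra `next.char != '\''` guard keeps `TtTreeSink` from inserting a space before the `'` of a lifetime while re-printing, which matters now that `LIFETIME_IDENT` is split into `'` plus an identifier during conversion. A standalone sketch of the whitespace predicate (hypothetical helper, not rust-analyzer's API; it only applies between two adjacent puncts):

```rust
// A space follows an `Alone` punct unless it is `;` (assumed to end its line
// elsewhere in RA) or the next punct is `'`, which begins a lifetime such as
// `'a` in `Struct<'a>` and must stay glued to the preceding `<`.
fn needs_space_between(curr_char: char, curr_is_alone: bool, next_char: char) -> bool {
    curr_is_alone && curr_char != ';' && next_char != '\''
}

fn main() {
    // `,` Alone before `#`: insert a space.
    assert!(needs_space_between(',', true, '#'));
    // `<` Alone before `'` (as in `Struct<'a>`): no space.
    assert!(!needs_space_between('<', true, '\''));
    // Joint puncts never get a space (`<=` stays together).
    assert!(!needs_space_between('<', false, '='));
}
```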
diff --git a/crates/mbe/src/syntax_bridge/tests.rs b/crates/mbe/src/syntax_bridge/tests.rs
new file mode 100644
index 00000000000..4e04d2bc1c7
--- /dev/null
+++ b/crates/mbe/src/syntax_bridge/tests.rs
@@ -0,0 +1,93 @@
+use std::collections::HashMap;
+
+use syntax::{ast, AstNode};
+use test_utils::extract_annotations;
+use tt::{
+    buffer::{TokenBuffer, TokenTreeRef},
+    Leaf, Punct, Spacing,
+};
+
+use super::syntax_node_to_token_tree;
+
+fn check_punct_spacing(fixture: &str) {
+    let source_file = ast::SourceFile::parse(fixture).ok().unwrap();
+    let (subtree, token_map) = syntax_node_to_token_tree(source_file.syntax());
+    let mut annotations: HashMap<_, _> = extract_annotations(fixture)
+        .into_iter()
+        .map(|(range, annotation)| {
+            let token = token_map.token_by_range(range).expect("no token found");
+            let spacing = match annotation.as_str() {
+                "Alone" => Spacing::Alone,
+                "Joint" => Spacing::Joint,
+                a => panic!("unknown annotation: {}", a),
+            };
+            (token, spacing)
+        })
+        .collect();
+
+    let buf = TokenBuffer::from_subtree(&subtree);
+    let mut cursor = buf.begin();
+    while !cursor.eof() {
+        while let Some(token_tree) = cursor.token_tree() {
+            if let TokenTreeRef::Leaf(Leaf::Punct(Punct { spacing, id, .. }), _) = token_tree {
+                if let Some(expected) = annotations.remove(&id) {
+                    assert_eq!(expected, *spacing);
+                }
+            }
+            cursor = cursor.bump_subtree();
+        }
+        cursor = cursor.bump();
+    }
+
+    assert!(annotations.is_empty(), "unchecked annotations: {:?}", annotations);
+}
+
+#[test]
+fn punct_spacing() {
+    check_punct_spacing(
+        r#"
+fn main() {
+    0+0;
+   //^ Alone
+    0+(0);
+   //^ Alone
+    0<=0;
+   //^ Joint
+   // ^ Alone
+    0<=(0);
+   // ^ Alone
+    a=0;
+   //^ Alone
+    a=(0);
+   //^ Alone
+    a+=0;
+   //^ Joint
+   // ^ Alone
+    a+=(0);
+   // ^ Alone
+    a&&b;
+   //^ Joint
+   // ^ Alone
+    a&&(b);
+   // ^ Alone
+    foo::bar;
+   //  ^ Joint
+   //   ^ Alone
+    use foo::{bar,baz,};
+   //       ^ Alone
+   //            ^ Alone
+   //                ^ Alone
+    struct Struct<'a> {};
+   //            ^ Joint
+   //             ^ Joint
+    Struct::<0>;
+   //       ^ Alone
+    Struct::<{0}>;
+   //       ^ Alone
+    ;;
+  //^ Joint
+  // ^ Alone
+}
+        "#,
+    );
+}
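
The new fixture drives the test through `test_utils::extract_annotations`: each `//^ Label` comment marks the token the caret points at on the line above with its expected `Spacing`. For readers unfamiliar with that convention, a rough standalone approximation (a hypothetical simplification returning line/column pairs; the real helper returns `TextRange`s and supports more annotation forms):

```rust
// Return (line, column, label) triples where the caret in a `//^ Label`
// comment addresses a character on the previous line.
fn extract_annotations(fixture: &str) -> Vec<(usize, usize, String)> {
    let mut out = Vec::new();
    for (idx, line) in fixture.lines().enumerate() {
        if idx == 0 {
            continue;
        }
        if let Some(pos) = line.find("//") {
            let rest = &line[pos + 2..];
            if let Some(caret) = rest.find('^') {
                let label = rest[caret + 1..].trim().to_string();
                // The caret's column addresses the annotated token above.
                out.push((idx - 1, pos + 2 + caret, label));
            }
        }
    }
    out
}

fn main() {
    let fixture = "    0+0;\n   //^ Alone\n";
    // The caret lands on the `+` at line 0, column 5.
    assert_eq!(extract_annotations(fixture), vec![(0, 5, "Alone".to_string())]);
}
```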