about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/proc-macro-srv/src/server.rs55
-rw-r--r--crates/proc-macro-srv/src/server/rust_analyzer_span.rs74
-rw-r--r--crates/proc-macro-srv/src/server/token_id.rs69
3 files changed, 76 insertions, 122 deletions
diff --git a/crates/proc-macro-srv/src/server.rs b/crates/proc-macro-srv/src/server.rs
index ff8fd295d88..9339d226a2c 100644
--- a/crates/proc-macro-srv/src/server.rs
+++ b/crates/proc-macro-srv/src/server.rs
@@ -17,6 +17,7 @@ pub mod rust_analyzer_span;
 mod symbol;
 pub mod token_id;
 pub use symbol::*;
+use syntax::ast::{self, HasModuleItem, IsString};
 use tt::Spacing;
 
 fn delim_to_internal<S>(d: proc_macro::Delimiter, span: bridge::DelimSpan<S>) -> tt::Delimiter<S> {
@@ -54,6 +55,60 @@ fn spacing_to_external(spacing: Spacing) -> proc_macro::Spacing {
     }
 }
 
+fn literal_to_external(literal: ast::LiteralKind) -> Option<proc_macro::bridge::LitKind> {
+    Some(match lit.kind() {
+        ast::LiteralKind::String(data) => {
+            if data.is_raw() {
+                bridge::LitKind::StrRaw(raw_delimiter_count(data)?)
+            } else {
+                bridge::LitKind::Str
+            }
+        }
+        ast::LiteralKind::ByteString(data) => {
+            if data.is_raw() {
+                bridge::LitKind::ByteStrRaw(raw_delimiter_count(data)?)
+            } else {
+                bridge::LitKind::ByteStr
+            }
+        }
+        ast::LiteralKind::CString(data) => {
+            if data.is_raw() {
+                bridge::LitKind::CStrRaw(raw_delimiter_count(data)?)
+            } else {
+                bridge::LitKind::CStr
+            }
+        }
+        ast::LiteralKind::IntNumber(num) => bridge::LitKind::Integer,
+        ast::LiteralKind::FloatNumber(num) => bridge::LitKind::Float,
+        ast::LiteralKind::Char(_) => bridge::LitKind::Char,
+        ast::LiteralKind::Byte(_) => bridge::LitKind::Byte,
+        ast::LiteralKind::Bool(_) => unreachable!(),
+    })
+}
+
+fn raw_delimiter_count<S: IsString>(s: S) -> Option<u8> {
+    let text = s.text();
+    let quote_range = s.text_range_between_quotes()?;
+    let range_start = s.syntax().text_range().start();
+    text[TextRange::up_to((quote_range - range_start).start())].matches('#').count().try_into().ok()
+}
+
+fn str_to_lit_node(input: &str) -> Option<ast::Literal> {
+    let input = input.trim();
+    let source_code = format!("fn f() {{ let _ = {input}; }}");
+
+    let parse = ast::SourceFile::parse(&source_code);
+    let file = parse.tree();
+
+    let ast::Item::Fn(func) = file.items().next()? else { return None };
+    let ast::Stmt::LetStmt(stmt) = func.body()?.stmt_list()?.statements().next()? else {
+        return None;
+    };
+    let ast::Expr::Literal(lit) = stmt.initializer()? else { return None };
+
+    Some(lit)
+}
+
 struct LiteralFormatter<S>(bridge::Literal<S, Symbol>);
 
 impl<S> LiteralFormatter<S> {
diff --git a/crates/proc-macro-srv/src/server/rust_analyzer_span.rs b/crates/proc-macro-srv/src/server/rust_analyzer_span.rs
index f7bbbcc09d4..4a2ad40ad99 100644
--- a/crates/proc-macro-srv/src/server/rust_analyzer_span.rs
+++ b/crates/proc-macro-srv/src/server/rust_analyzer_span.rs
@@ -4,6 +4,7 @@
 //! It is an unfortunate result of how the proc-macro API works that we need to look into the
 //! concrete representation of the spans, and as such, RustRover cannot make use of this unless they
 //! change their representation to be compatible with rust-analyzer's.
+use core::num;
 use std::{
     collections::{HashMap, HashSet},
     iter,
@@ -13,11 +14,11 @@ use std::{
 use ::tt::{TextRange, TextSize};
 use proc_macro::bridge::{self, server};
 use span::{Span, FIXUP_ERASED_FILE_AST_ID_MARKER};
-use syntax::ast::{self, HasModuleItem, IsString};
+use syntax::ast;
 
 use crate::server::{
-    delim_to_external, delim_to_internal, token_stream::TokenStreamBuilder, LiteralFormatter,
-    Symbol, SymbolInternerRef, SYMBOL_INTERNER,
+    delim_to_external, delim_to_internal, literal_to_external, str_to_lit_node,
+    token_stream::TokenStreamBuilder, LiteralFormatter, Symbol, SymbolInternerRef, SYMBOL_INTERNER,
 };
 mod tt {
     pub use ::tt::*;
@@ -71,66 +72,15 @@ impl server::FreeFunctions for RaSpanServer {
         &mut self,
         s: &str,
     ) -> Result<bridge::Literal<Self::Span, Self::Symbol>, ()> {
-        let input = s.trim();
-        let source_code = format!("fn f() {{ let _ = {input}; }}");
-
-        let parse = ast::SourceFile::parse(&source_code);
-        let file = parse.tree();
-
-        let Some(ast::Item::Fn(func)) = file.items().next() else { return Err(()) };
-        let Some(ast::Stmt::LetStmt(stmt)) =
-            func.body().ok_or(Err(()))?.stmt_list().ok_or(Err(()))?.statements().next()
-        else {
-            return Err(());
-        };
-        let Some(ast::Expr::Literal(lit)) = stmt.initializer() else { return Err(()) };
-
-        fn raw_delimiter_count<S: IsString>(s: S) -> Option<u8> {
-            let text = s.text();
-            let quote_range = s.text_range_between_quotes()?;
-            let range_start = s.syntax().text_range().start();
-            text[TextRange::up_to((quote_range - range_start).start())]
-                .matches('#')
-                .count()
-                .try_into()
-                .ok()
-        }
+        let literal = str_to_lit_node(s).ok_or(Err(()))?;
 
-        let mut suffix = None;
-        let kind = match lit.kind() {
-            ast::LiteralKind::String(data) => {
-                if data.is_raw() {
-                    bridge::LitKind::StrRaw(raw_delimiter_count(data).ok_or(Err(()))?)
-                } else {
-                    bridge::LitKind::Str
-                }
-            }
-            ast::LiteralKind::ByteString(data) => {
-                if data.is_raw() {
-                    bridge::LitKind::ByteStrRaw(raw_delimiter_count(data).ok_or(Err(()))?)
-                } else {
-                    bridge::LitKind::ByteStr
-                }
-            }
-            ast::LiteralKind::CString(data) => {
-                if data.is_raw() {
-                    bridge::LitKind::CStrRaw(raw_delimiter_count(data).ok_or(Err(()))?)
-                } else {
-                    bridge::LitKind::CStr
-                }
-            }
-            ast::LiteralKind::IntNumber(num) => {
-                suffix = num.suffix();
-                bridge::LitKind::Integer
-            }
-            ast::LiteralKind::FloatNumber(num) => {
-                suffix = num.suffix();
-                bridge::LitKind::Float
-            }
-            ast::LiteralKind::Char(_) => bridge::LitKind::Char,
-            ast::LiteralKind::Byte(_) => bridge::LitKind::Byte,
-            ast::LiteralKind::Bool(_) => unreachable!(),
-        };
+        let kind = literal_to_external(literal.kind()).ok_or(Err(()))?;
+
+        let suffix = match literal.kind() {
+            ast::LiteralKind::FloatNumber(num) | ast::LiteralKind::IntNumber(num) => num.suffix(),
+            _ => None,
+        }
+        .map(|suffix| Symbol::intern(self.interner, suffix));
 
         Ok(bridge::Literal {
             kind,
diff --git a/crates/proc-macro-srv/src/server/token_id.rs b/crates/proc-macro-srv/src/server/token_id.rs
index 5c74c15b360..4dbda7e53c2 100644
--- a/crates/proc-macro-srv/src/server/token_id.rs
+++ b/crates/proc-macro-srv/src/server/token_id.rs
@@ -6,11 +6,11 @@ use std::{
 };
 
 use proc_macro::bridge::{self, server};
-use syntax::ast::{self, HasModuleItem, IsString};
+use syntax::ast;
 
 use crate::server::{
-    delim_to_external, delim_to_internal, token_stream::TokenStreamBuilder, LiteralFormatter,
-    Symbol, SymbolInternerRef, SYMBOL_INTERNER,
+    delim_to_external, delim_to_internal, literal_to_external, str_to_lit_node,
+    token_stream::TokenStreamBuilder, LiteralFormatter, Symbol, SymbolInternerRef, SYMBOL_INTERNER,
 };
 mod tt {
     pub use proc_macro_api::msg::TokenId;
@@ -63,66 +63,15 @@ impl server::FreeFunctions for TokenIdServer {
         &mut self,
         s: &str,
     ) -> Result<bridge::Literal<Self::Span, Self::Symbol>, ()> {
-        let input = s.trim();
-        let source_code = format!("fn f() {{ let _ = {input}; }}");
+        let literal = str_to_lit_node(s).ok_or(Err(()))?;
 
-        let parse = ast::SourceFile::parse(&source_code);
-        let file = parse.tree();
+        let kind = literal_to_external(literal.kind()).ok_or(Err(()))?;
 
-        let Some(ast::Item::Fn(func)) = file.items().next() else { return Err(()) };
-        let Some(ast::Stmt::LetStmt(stmt)) =
-            func.body().ok_or(Err(()))?.stmt_list().ok_or(Err(()))?.statements().next()
-        else {
-            return Err(());
-        };
-        let Some(ast::Expr::Literal(lit)) = stmt.initializer() else { return Err(()) };
-
-        fn raw_delimiter_count<S: IsString>(s: S) -> Option<u8> {
-            let text = s.text();
-            let quote_range = s.text_range_between_quotes()?;
-            let range_start = s.syntax().text_range().start();
-            text[TextRange::up_to((quote_range - range_start).start())]
-                .matches('#')
-                .count()
-                .try_into()
-                .ok()
+        let suffix = match literal.kind() {
+            ast::LiteralKind::FloatNumber(num) | ast::LiteralKind::IntNumber(num) => num.suffix(),
+            _ => None,
         }
-
-        let mut suffix = None;
-        let kind = match lit.kind() {
-            ast::LiteralKind::String(data) => {
-                if data.is_raw() {
-                    bridge::LitKind::StrRaw(raw_delimiter_count(data).ok_or(Err(()))?)
-                } else {
-                    bridge::LitKind::Str
-                }
-            }
-            ast::LiteralKind::ByteString(data) => {
-                if data.is_raw() {
-                    bridge::LitKind::ByteStrRaw(raw_delimiter_count(data).ok_or(Err(()))?)
-                } else {
-                    bridge::LitKind::ByteStr
-                }
-            }
-            ast::LiteralKind::CString(data) => {
-                if data.is_raw() {
-                    bridge::LitKind::CStrRaw(raw_delimiter_count(data).ok_or(Err(()))?)
-                } else {
-                    bridge::LitKind::CStr
-                }
-            }
-            ast::LiteralKind::IntNumber(num) => {
-                suffix = num.suffix();
-                bridge::LitKind::Integer
-            }
-            ast::LiteralKind::FloatNumber(num) => {
-                suffix = num.suffix();
-                bridge::LitKind::Float
-            }
-            ast::LiteralKind::Char(_) => bridge::LitKind::Char,
-            ast::LiteralKind::Byte(_) => bridge::LitKind::Byte,
-            ast::LiteralKind::Bool(_) => unreachable!(),
-        };
+        .map(|suffix| Symbol::intern(self.interner, suffix));
 
         Ok(bridge::Literal {
             kind,