about summary refs log tree commit diff
diff options
context:
space:
mode:
authorLukas Wirth <lukastw97@gmail.com>2023-07-03 14:42:27 +0200
committerLukas Wirth <lukastw97@gmail.com>2023-07-03 15:05:25 +0200
commitbdc3d0f5514d677f31eb362172f564f6470d3c36 (patch)
tree02d6f61a7431e7ad300d044ffa46e90fe57f8ff1
parent321e570d92d5ff32280247d0a67145c4ae4eab48 (diff)
downloadrust-bdc3d0f5514d677f31eb362172f564f6470d3c36.tar.gz
rust-bdc3d0f5514d677f31eb362172f564f6470d3c36.zip
Shuffle some proc_macro_expand query things around
-rw-r--r--crates/hir-expand/src/db.rs87
1 files changed, 54 insertions, 33 deletions
diff --git a/crates/hir-expand/src/db.rs b/crates/hir-expand/src/db.rs
index 78b2db7306b..de653d76b48 100644
--- a/crates/hir-expand/src/db.rs
+++ b/crates/hir-expand/src/db.rs
@@ -57,10 +57,7 @@ impl TokenExpander {
             TokenExpander::BuiltinAttr(it) => it.expand(db, id, tt),
             TokenExpander::BuiltinDerive(it) => it.expand(db, id, tt),
             TokenExpander::ProcMacro(_) => {
-                // We store the result in salsa db to prevent non-deterministic behavior in
-                // some proc-macro implementation
-                // See #4315 for details
-                db.expand_proc_macro(id)
+                unreachable!("ExpandDatabase::expand_proc_macro should be used for proc macros")
             }
         }
     }
@@ -141,8 +138,8 @@ pub trait ExpandDatabase: SourceDatabase {
     /// Special case of the previous query for procedural macros. We can't LRU
     /// proc macros, since they are not deterministic in general, and
     /// non-determinism breaks salsa in a very, very, very bad way.
-    /// @edwin0cheng heroically debugged this once!
-    fn expand_proc_macro(&self, call: MacroCallId) -> ExpandResult<tt::Subtree>;
+    /// @edwin0cheng heroically debugged this once! See #4315 for details
+    fn expand_proc_macro(&self, call: MacroCallId) -> ExpandResult<Arc<tt::Subtree>>;
     /// Firewall query that returns the errors from the `parse_macro_expansion` query.
     fn parse_macro_expansion_error(
         &self,
@@ -297,6 +294,14 @@ fn parse_macro_expansion(
     ExpandResult { value: (parse, Arc::new(rev_token_map)), err }
 }
 
+fn parse_macro_expansion_error(
+    db: &dyn ExpandDatabase,
+    macro_call_id: MacroCallId,
+) -> ExpandResult<Box<[SyntaxError]>> {
+    db.parse_macro_expansion(MacroFile { macro_call_id })
+        .map(|it| it.0.errors().to_vec().into_boxed_slice())
+}
+
 fn macro_arg(
     db: &dyn ExpandDatabase,
     id: MacroCallId,
@@ -445,6 +450,11 @@ fn macro_expand(db: &dyn ExpandDatabase, id: MacroCallId) -> ExpandResult<Arc<tt
         // This is an input expansion for an eager macro. These are already pre-expanded
         return ExpandResult { value: Arc::new(arg.0.clone()), err: error.clone() };
     }
+
+    if let MacroDefKind::ProcMacro(..) = loc.def.kind {
+        return db.expand_proc_macro(id);
+    }
+
     let expander = match db.macro_def(loc.def) {
         Ok(it) => it,
         // FIXME: We should make sure to enforce a variant that invalid macro
@@ -467,7 +477,7 @@ fn macro_expand(db: &dyn ExpandDatabase, id: MacroCallId) -> ExpandResult<Arc<tt
                     token_trees: Vec::new(),
                 },
             ),
-            // FIXME: We should make sure to enforce a variant that invalid macro
+            // FIXME: We should make sure to enforce an invariant that invalid macro
             // calls do not reach this call path!
             err: Some(ExpandError::other(
                 "invalid token tree"
@@ -483,19 +493,8 @@ fn macro_expand(db: &dyn ExpandDatabase, id: MacroCallId) -> ExpandResult<Arc<tt
     }
 
     // Set a hard limit for the expanded tt
-    let count = tt.count();
-    if TOKEN_LIMIT.check(count).is_err() {
-        return ExpandResult {
-            value: Arc::new(tt::Subtree {
-                delimiter: tt::Delimiter::UNSPECIFIED,
-                token_trees: vec![],
-            }),
-            err: Some(ExpandError::other(format!(
-                "macro invocation exceeds token limit: produced {} tokens, limit is {}",
-                count,
-                TOKEN_LIMIT.inner(),
-            ))),
-        };
+    if let Err(value) = check_tt_count(&tt) {
+        return value;
     }
 
     fixup::reverse_fixups(&mut tt, arg_tm, undo_info);
@@ -503,27 +502,20 @@ fn macro_expand(db: &dyn ExpandDatabase, id: MacroCallId) -> ExpandResult<Arc<tt
     ExpandResult { value: Arc::new(tt), err }
 }
 
-fn parse_macro_expansion_error(
-    db: &dyn ExpandDatabase,
-    macro_call_id: MacroCallId,
-) -> ExpandResult<Box<[SyntaxError]>> {
-    db.parse_macro_expansion(MacroFile { macro_call_id })
-        .map(|it| it.0.errors().to_vec().into_boxed_slice())
-}
-
-fn expand_proc_macro(db: &dyn ExpandDatabase, id: MacroCallId) -> ExpandResult<tt::Subtree> {
+fn expand_proc_macro(db: &dyn ExpandDatabase, id: MacroCallId) -> ExpandResult<Arc<tt::Subtree>> {
     let loc = db.lookup_intern_macro_call(id);
     let Some(macro_arg) = db.macro_arg(id) else {
         return ExpandResult {
-            value: tt::Subtree {
+            value: Arc::new(tt::Subtree {
                 delimiter: tt::Delimiter::UNSPECIFIED,
                 token_trees: Vec::new(),
-            },
+            }),
             err: Some(ExpandError::other(
                 "invalid token tree"
             )),
         };
     };
+    let (arg_tt, arg_tm, undo_info) = &*macro_arg;
 
     let expander = match loc.def.kind {
         MacroDefKind::ProcMacro(expander, ..) => expander,
@@ -533,13 +525,23 @@ fn expand_proc_macro(db: &dyn ExpandDatabase, id: MacroCallId) -> ExpandResult<t
     let attr_arg = match &loc.kind {
         MacroCallKind::Attr { attr_args, .. } => {
             let mut attr_args = attr_args.0.clone();
-            mbe::Shift::new(&macro_arg.0).shift_all(&mut attr_args);
+            mbe::Shift::new(arg_tt).shift_all(&mut attr_args);
             Some(attr_args)
         }
         _ => None,
     };
 
-    expander.expand(db, loc.def.krate, loc.krate, &macro_arg.0, attr_arg.as_ref())
+    let ExpandResult { value: mut tt, err } =
+        expander.expand(db, loc.def.krate, loc.krate, arg_tt, attr_arg.as_ref());
+
+    // Set a hard limit for the expanded tt
+    if let Err(value) = check_tt_count(&tt) {
+        return value;
+    }
+
+    fixup::reverse_fixups(&mut tt, arg_tm, undo_info);
+
+    ExpandResult { value: Arc::new(tt), err }
 }
 
 fn hygiene_frame(db: &dyn ExpandDatabase, file_id: HirFileId) -> Arc<HygieneFrame> {
@@ -563,3 +565,22 @@ fn token_tree_to_syntax_node(
     };
     mbe::token_tree_to_syntax_node(tt, entry_point)
 }
+
+fn check_tt_count(tt: &tt::Subtree) -> Result<(), ExpandResult<Arc<tt::Subtree>>> {
+    let count = tt.count();
+    if TOKEN_LIMIT.check(count).is_err() {
+        Err(ExpandResult {
+            value: Arc::new(tt::Subtree {
+                delimiter: tt::Delimiter::UNSPECIFIED,
+                token_trees: vec![],
+            }),
+            err: Some(ExpandError::other(format!(
+                "macro invocation exceeds token limit: produced {} tokens, limit is {}",
+                count,
+                TOKEN_LIMIT.inner(),
+            ))),
+        })
+    } else {
+        Ok(())
+    }
+}