about summary refs log tree commit diff
diff options
context:
space:
mode:
authorRyo Yoshida <low.ryoshida@gmail.com>2023-02-08 20:11:54 +0900
committerRyo Yoshida <low.ryoshida@gmail.com>2023-02-11 18:19:08 +0900
commitae7e62c50fb2974430cf401e24747bafef7f701a (patch)
treecc2205a555f43e195146b7f454378dad629c30e6
parent8011029d3a0f4014217e1ade75688c0f3c5305db (diff)
downloadrust-ae7e62c50fb2974430cf401e24747bafef7f701a.tar.gz
rust-ae7e62c50fb2974430cf401e24747bafef7f701a.zip
Don't expand macros in the same expansion tree after overflow
-rw-r--r--crates/hir-def/src/body.rs140
-rw-r--r--crates/hir-def/src/body/lower.rs6
-rw-r--r--crates/hir-def/src/body/tests.rs13
-rw-r--r--crates/hir-expand/src/lib.rs4
4 files changed, 114 insertions, 49 deletions
diff --git a/crates/hir-def/src/body.rs b/crates/hir-def/src/body.rs
index 9713256813e..8fd9255b8b1 100644
--- a/crates/hir-def/src/body.rs
+++ b/crates/hir-def/src/body.rs
@@ -19,7 +19,7 @@ use la_arena::{Arena, ArenaMap};
 use limit::Limit;
 use profile::Count;
 use rustc_hash::FxHashMap;
-use syntax::{ast, AstPtr, SyntaxNodePtr};
+use syntax::{ast, AstPtr, SyntaxNode, SyntaxNodePtr};
 
 use crate::{
     attr::Attrs,
@@ -51,7 +51,8 @@ pub struct Expander {
     def_map: Arc<DefMap>,
     current_file_id: HirFileId,
     module: LocalModuleId,
-    recursion_limit: usize,
+    /// `recursion_depth == usize::MAX` indicates that the recursion limit has been reached.
+    recursion_depth: usize,
 }
 
 impl CfgExpander {
@@ -84,7 +85,7 @@ impl Expander {
             def_map,
             current_file_id,
             module: module.local_id,
-            recursion_limit: 0,
+            recursion_depth: 0,
         }
     }
 
@@ -93,31 +94,37 @@ impl Expander {
         db: &dyn DefDatabase,
         macro_call: ast::MacroCall,
     ) -> Result<ExpandResult<Option<(Mark, T)>>, UnresolvedMacro> {
-        if self.recursion_limit(db).check(self.recursion_limit + 1).is_err() {
-            cov_mark::hit!(your_stack_belongs_to_me);
-            return Ok(ExpandResult::only_err(ExpandError::Other(
-                "reached recursion limit during macro expansion".into(),
-            )));
+        let mut unresolved_macro_err = None;
+
+        let result = self.within_limit(db, |this| {
+            let macro_call = InFile::new(this.current_file_id, &macro_call);
+
+            let resolver =
+                |path| this.resolve_path_as_macro(db, &path).map(|it| macro_id_to_def_id(db, it));
+
+            let mut err = None;
+            let call_id = match macro_call.as_call_id_with_errors(
+                db,
+                this.def_map.krate(),
+                resolver,
+                &mut |e| {
+                    err.get_or_insert(e);
+                },
+            ) {
+                Ok(call_id) => call_id,
+                Err(resolve_err) => {
+                    unresolved_macro_err = Some(resolve_err);
+                    return ExpandResult { value: None, err: None };
+                }
+            };
+            ExpandResult { value: call_id.ok(), err }
+        });
+
+        if let Some(err) = unresolved_macro_err {
+            Err(err)
+        } else {
+            Ok(result)
         }
-
-        let macro_call = InFile::new(self.current_file_id, &macro_call);
-
-        let resolver =
-            |path| self.resolve_path_as_macro(db, &path).map(|it| macro_id_to_def_id(db, it));
-
-        let mut err = None;
-        let call_id =
-            macro_call.as_call_id_with_errors(db, self.def_map.krate(), resolver, &mut |e| {
-                err.get_or_insert(e);
-            })?;
-        let call_id = match call_id {
-            Ok(it) => it,
-            Err(_) => {
-                return Ok(ExpandResult { value: None, err });
-            }
-        };
-
-        Ok(self.enter_expand_inner(db, call_id, err))
     }
 
     pub fn enter_expand_id<T: ast::AstNode>(
@@ -125,15 +132,14 @@ impl Expander {
         db: &dyn DefDatabase,
         call_id: MacroCallId,
     ) -> ExpandResult<Option<(Mark, T)>> {
-        self.enter_expand_inner(db, call_id, None)
+        self.within_limit(db, |_this| ExpandResult::ok(Some(call_id)))
     }
 
-    fn enter_expand_inner<T: ast::AstNode>(
-        &mut self,
+    fn enter_expand_inner(
         db: &dyn DefDatabase,
         call_id: MacroCallId,
         mut err: Option<ExpandError>,
-    ) -> ExpandResult<Option<(Mark, T)>> {
+    ) -> ExpandResult<Option<(HirFileId, SyntaxNode)>> {
         if err.is_none() {
             err = db.macro_expand_error(call_id);
         }
@@ -154,29 +160,21 @@ impl Expander {
             }
         };
 
-        let node = match T::cast(raw_node) {
-            Some(it) => it,
-            None => {
-                // This can happen without being an error, so only forward previous errors.
-                return ExpandResult { value: None, err };
-            }
-        };
-
-        tracing::debug!("macro expansion {:#?}", node.syntax());
-
-        self.recursion_limit += 1;
-        let mark =
-            Mark { file_id: self.current_file_id, bomb: DropBomb::new("expansion mark dropped") };
-        self.cfg_expander.hygiene = Hygiene::new(db.upcast(), file_id);
-        self.current_file_id = file_id;
-
-        ExpandResult { value: Some((mark, node)), err }
+        ExpandResult { value: Some((file_id, raw_node)), err }
     }
 
     pub fn exit(&mut self, db: &dyn DefDatabase, mut mark: Mark) {
         self.cfg_expander.hygiene = Hygiene::new(db.upcast(), mark.file_id);
         self.current_file_id = mark.file_id;
-        self.recursion_limit -= 1;
+        if self.recursion_depth == usize::MAX {
+            // Recursion limit has been reached somewhere in the macro expansion tree. Reset the
+            // depth only when we get out of the tree.
+            if !self.current_file_id.is_macro() {
+                self.recursion_depth = 0;
+            }
+        } else {
+            self.recursion_depth -= 1;
+        }
         mark.bomb.defuse();
     }
 
@@ -215,6 +213,50 @@ impl Expander {
         #[cfg(test)]
         return Limit::new(std::cmp::min(32, limit));
     }
+
+    fn within_limit<F, T: ast::AstNode>(
+        &mut self,
+        db: &dyn DefDatabase,
+        op: F,
+    ) -> ExpandResult<Option<(Mark, T)>>
+    where
+        F: FnOnce(&mut Self) -> ExpandResult<Option<MacroCallId>>,
+    {
+        if self.recursion_depth == usize::MAX {
+            // Recursion limit has been reached somewhere in the macro expansion tree. We should
+            // stop expanding other macro calls in this tree, or else this may result in
+            // exponential number of macro expansions, leading to a hang.
+            //
+            // The overflow error should have been reported when it occurred (see the next branch),
+            // so don't return overflow error here to avoid diagnostics duplication.
+            cov_mark::hit!(overflow_but_not_me);
+            return ExpandResult::only_err(ExpandError::RecursionOverflowPosioned);
+        } else if self.recursion_limit(db).check(self.recursion_depth + 1).is_err() {
+            self.recursion_depth = usize::MAX;
+            cov_mark::hit!(your_stack_belongs_to_me);
+            return ExpandResult::only_err(ExpandError::Other(
+                "reached recursion limit during macro expansion".into(),
+            ));
+        }
+
+        let ExpandResult { value, err } = op(self);
+        let Some(call_id) = value else {
+            return ExpandResult { value: None, err };
+        };
+
+        Self::enter_expand_inner(db, call_id, err).map(|value| {
+            value.and_then(|(new_file_id, node)| {
+                let node = T::cast(node)?;
+
+                self.recursion_depth += 1;
+                self.cfg_expander.hygiene = Hygiene::new(db.upcast(), new_file_id);
+                let old_file_id = std::mem::replace(&mut self.current_file_id, new_file_id);
+                let mark =
+                    Mark { file_id: old_file_id, bomb: DropBomb::new("expansion mark dropped") };
+                Some((mark, node))
+            })
+        })
+    }
 }
 
 #[derive(Debug)]
diff --git a/crates/hir-def/src/body/lower.rs b/crates/hir-def/src/body/lower.rs
index a78fa91f53b..04b1c4f01e2 100644
--- a/crates/hir-def/src/body/lower.rs
+++ b/crates/hir-def/src/body/lower.rs
@@ -624,6 +624,10 @@ impl ExprCollector<'_> {
                         krate: *krate,
                     });
                 }
+                Some(ExpandError::RecursionOverflowPosioned) => {
+                    // Recursion limit has been reached in the macro expansion tree, but not in
+                    // this very macro call. Don't add diagnostics to avoid duplication.
+                }
                 Some(err) => {
                     self.source_map.diagnostics.push(BodyDiagnostic::MacroError {
                         node: InFile::new(outer_file, syntax_ptr),
@@ -636,6 +640,8 @@ impl ExprCollector<'_> {
 
         match res.value {
             Some((mark, expansion)) => {
+                // Keep collecting even with expansion errors so we can provide completions and
+                // other services in incomplete macro expressions.
                 self.source_map.expansions.insert(macro_call_ptr, self.expander.current_file_id);
                 let prev_ast_id_map = mem::replace(
                     &mut self.ast_id_map,
diff --git a/crates/hir-def/src/body/tests.rs b/crates/hir-def/src/body/tests.rs
index c9601f85527..edee2c7ff96 100644
--- a/crates/hir-def/src/body/tests.rs
+++ b/crates/hir-def/src/body/tests.rs
@@ -62,6 +62,19 @@ fn main() { n_nuple!(1,2,3); }
 }
 
 #[test]
+fn your_stack_belongs_to_me2() {
+    cov_mark::check!(overflow_but_not_me);
+    lower(
+        r#"
+macro_rules! foo {
+    () => {{ foo!(); foo!(); }}
+}
+fn main() { foo!(); }
+"#,
+    );
+}
+
+#[test]
 fn recursion_limit() {
     cov_mark::check!(your_stack_belongs_to_me);
 
diff --git a/crates/hir-expand/src/lib.rs b/crates/hir-expand/src/lib.rs
index bc941b54172..a52716cc02c 100644
--- a/crates/hir-expand/src/lib.rs
+++ b/crates/hir-expand/src/lib.rs
@@ -55,6 +55,7 @@ pub type ExpandResult<T> = ValueResult<T, ExpandError>;
 pub enum ExpandError {
     UnresolvedProcMacro(CrateId),
     Mbe(mbe::ExpandError),
+    RecursionOverflowPosioned,
     Other(Box<str>),
 }
 
@@ -69,6 +70,9 @@ impl fmt::Display for ExpandError {
         match self {
             ExpandError::UnresolvedProcMacro(_) => f.write_str("unresolved proc-macro"),
             ExpandError::Mbe(it) => it.fmt(f),
+            ExpandError::RecursionOverflowPosioned => {
+                f.write_str("overflow expanding the original macro")
+            }
             ExpandError::Other(it) => f.write_str(it),
         }
     }