about summary refs log tree commit diff
diff options
context:
space:
mode:
authorcynecx <me@cynecx.net>2021-04-18 19:56:13 +0200
committercynecx <me@cynecx.net>2021-04-18 19:56:13 +0200
commit6ed2fd233b569d01169fc888f30c358dd289d260 (patch)
treed5ff73fef371ac596e7c98db1c39c9dd3112f571
parent3d39e77003c5fe5ed9f8f3ac00a170f3804f8337 (diff)
downloadrust-6ed2fd233b569d01169fc888f30c358dd289d260.tar.gz
rust-6ed2fd233b569d01169fc888f30c358dd289d260.zip
hir_ty: keep body::Expander in TyLoweringContext
-rw-r--r--crates/hir_def/src/body.rs18
-rw-r--r--crates/hir_def/src/type_ref.rs38
-rw-r--r--crates/hir_ty/src/lower.rs71
-rw-r--r--crates/hir_ty/src/tests/macros.rs26
4 files changed, 94 insertions, 59 deletions
diff --git a/crates/hir_def/src/body.rs b/crates/hir_def/src/body.rs
index 8a9b936ea40..131f424cc8b 100644
--- a/crates/hir_def/src/body.rs
+++ b/crates/hir_def/src/body.rs
@@ -37,13 +37,15 @@ use crate::{
 
 /// A subset of Expander that only deals with cfg attributes. We only need it to
 /// avoid cyclic queries in crate def map during enum processing.
+#[derive(Debug)]
 pub(crate) struct CfgExpander {
     cfg_options: CfgOptions,
     hygiene: Hygiene,
     krate: CrateId,
 }
 
-pub(crate) struct Expander {
+#[derive(Debug)]
+pub struct Expander {
     cfg_expander: CfgExpander,
     def_map: Arc<DefMap>,
     current_file_id: HirFileId,
@@ -80,11 +82,7 @@ impl CfgExpander {
 }
 
 impl Expander {
-    pub(crate) fn new(
-        db: &dyn DefDatabase,
-        current_file_id: HirFileId,
-        module: ModuleId,
-    ) -> Expander {
+    pub fn new(db: &dyn DefDatabase, current_file_id: HirFileId, module: ModuleId) -> Expander {
         let cfg_expander = CfgExpander::new(db, current_file_id, module.krate);
         let def_map = module.def_map(db);
         let ast_id_map = db.ast_id_map(current_file_id);
@@ -98,7 +96,7 @@ impl Expander {
         }
     }
 
-    pub(crate) fn enter_expand<T: ast::AstNode>(
+    pub fn enter_expand<T: ast::AstNode>(
         &mut self,
         db: &dyn DefDatabase,
         macro_call: ast::MacroCall,
@@ -170,7 +168,7 @@ impl Expander {
         Ok(ExpandResult { value: Some((mark, node)), err })
     }
 
-    pub(crate) fn exit(&mut self, db: &dyn DefDatabase, mut mark: Mark) {
+    pub fn exit(&mut self, db: &dyn DefDatabase, mut mark: Mark) {
         self.cfg_expander.hygiene = Hygiene::new(db.upcast(), mark.file_id);
         self.current_file_id = mark.file_id;
         self.ast_id_map = mem::take(&mut mark.ast_id_map);
@@ -190,7 +188,7 @@ impl Expander {
         &self.cfg_expander.cfg_options
     }
 
-    pub(crate) fn current_file_id(&self) -> HirFileId {
+    pub fn current_file_id(&self) -> HirFileId {
         self.current_file_id
     }
 
@@ -210,7 +208,7 @@ impl Expander {
 }
 
 #[derive(Debug)]
-pub(crate) struct Mark {
+pub struct Mark {
     file_id: HirFileId,
     ast_id_map: Arc<AstIdMap>,
     bomb: DropBomb,
diff --git a/crates/hir_def/src/type_ref.rs b/crates/hir_def/src/type_ref.rs
index e18712d2460..ea29da5daae 100644
--- a/crates/hir_def/src/type_ref.rs
+++ b/crates/hir_def/src/type_ref.rs
@@ -1,15 +1,10 @@
 //! HIR for references to types. Paths in these are not yet resolved. They can
 //! be directly created from an ast::TypeRef, without further queries.
 
-use hir_expand::{name::Name, AstId, ExpandResult, InFile};
+use hir_expand::{name::Name, AstId, InFile};
 use syntax::ast;
 
-use crate::{
-    body::{Expander, LowerCtx},
-    db::DefDatabase,
-    path::Path,
-    ModuleId,
-};
+use crate::{body::LowerCtx, path::Path};
 
 #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
 pub enum Mutability {
@@ -124,7 +119,7 @@ pub enum TypeBound {
 
 impl TypeRef {
     /// Converts an `ast::TypeRef` to a `hir::TypeRef`.
-    pub(crate) fn from_ast(ctx: &LowerCtx, node: ast::Type) -> Self {
+    pub fn from_ast(ctx: &LowerCtx, node: ast::Type) -> Self {
         match node {
             ast::Type::ParenType(inner) => TypeRef::from_ast_opt(&ctx, inner.ty()),
             ast::Type::TupleType(inner) => {
@@ -303,30 +298,3 @@ impl TypeBound {
         }
     }
 }
-
-pub fn expand_macro_type(
-    db: &dyn DefDatabase,
-    module_id: ModuleId,
-    macro_type: &TypeRef,
-) -> Option<TypeRef> {
-    let macro_call = match macro_type {
-        TypeRef::Macro(macro_call) => macro_call,
-        _ => panic!("expected TypeRef::Macro"),
-    };
-
-    let file_id = macro_call.file_id;
-    let macro_call = macro_call.to_node(db.upcast());
-
-    let mut expander = Expander::new(db, file_id, module_id);
-    let (file_id, expanded) = match expander.enter_expand::<ast::Type>(db, macro_call.clone()) {
-        Ok(ExpandResult { value: Some((mark, expanded)), .. }) => {
-            let file_id = expander.current_file_id();
-            expander.exit(db, mark);
-            (file_id, expanded)
-        }
-        _ => return None,
-    };
-
-    let ctx = LowerCtx::new(db, file_id);
-    return Some(TypeRef::from_ast(&ctx, expanded));
-}
diff --git a/crates/hir_ty/src/lower.rs b/crates/hir_ty/src/lower.rs
index e01b7aa9190..a883334afbb 100644
--- a/crates/hir_ty/src/lower.rs
+++ b/crates/hir_ty/src/lower.rs
@@ -5,25 +5,28 @@
 //!  - Building the type for an item: This happens through the `type_for_def` query.
 //!
 //! This usually involves resolving names, collecting generic arguments etc.
+use std::cell::{Cell, RefCell};
 use std::{iter, sync::Arc};
 
 use base_db::CrateId;
 use chalk_ir::{cast::Cast, fold::Shift, interner::HasInterner, Mutability, Safety};
 use hir_def::{
     adt::StructKind,
+    body::{Expander, LowerCtx},
     builtin_type::BuiltinType,
     generics::{TypeParamProvenance, WherePredicate, WherePredicateTypeTarget},
     path::{GenericArg, Path, PathSegment, PathSegments},
     resolver::{HasResolver, Resolver, TypeNs},
-    type_ref::{expand_macro_type, TraitRef as HirTraitRef, TypeBound, TypeRef},
+    type_ref::{TraitRef as HirTraitRef, TypeBound, TypeRef},
     AdtId, AssocContainerId, AssocItemId, ConstId, ConstParamId, EnumId, EnumVariantId, FunctionId,
     GenericDefId, HasModule, ImplId, LocalFieldId, Lookup, StaticId, StructId, TraitId,
     TypeAliasId, TypeParamId, UnionId, VariantId,
 };
-use hir_expand::name::Name;
+use hir_expand::{name::Name, ExpandResult};
 use la_arena::ArenaMap;
 use smallvec::SmallVec;
 use stdx::impl_from;
+use syntax::ast;
 
 use crate::{
     db::HirDatabase,
@@ -50,7 +53,7 @@ pub struct TyLoweringContext<'a> {
     /// possible currently, so this should be fine for now.
     pub type_param_mode: TypeParamLoweringMode,
     pub impl_trait_mode: ImplTraitLoweringMode,
-    impl_trait_counter: std::cell::Cell<u16>,
+    impl_trait_counter: Cell<u16>,
     /// When turning `impl Trait` into opaque types, we have to collect the
     /// bounds at the same time to get the IDs correct (without becoming too
     /// complicated). I don't like using interior mutability (as for the
@@ -59,16 +62,17 @@ pub struct TyLoweringContext<'a> {
     /// we're grouping the mutable data (the counter and this field) together
     /// with the immutable context (the references to the DB and resolver).
     /// Splitting this up would be a possible fix.
-    opaque_type_data: std::cell::RefCell<Vec<ReturnTypeImplTrait>>,
+    opaque_type_data: RefCell<Vec<ReturnTypeImplTrait>>,
+    expander: RefCell<Option<Expander>>,
 }
 
 impl<'a> TyLoweringContext<'a> {
     pub fn new(db: &'a dyn HirDatabase, resolver: &'a Resolver) -> Self {
-        let impl_trait_counter = std::cell::Cell::new(0);
+        let impl_trait_counter = Cell::new(0);
         let impl_trait_mode = ImplTraitLoweringMode::Disallowed;
         let type_param_mode = TypeParamLoweringMode::Placeholder;
         let in_binders = DebruijnIndex::INNERMOST;
-        let opaque_type_data = std::cell::RefCell::new(Vec::new());
+        let opaque_type_data = RefCell::new(Vec::new());
         Self {
             db,
             resolver,
@@ -77,6 +81,7 @@ impl<'a> TyLoweringContext<'a> {
             impl_trait_counter,
             type_param_mode,
             opaque_type_data,
+            expander: RefCell::new(None),
         }
     }
 
@@ -86,15 +91,18 @@ impl<'a> TyLoweringContext<'a> {
         f: impl FnOnce(&TyLoweringContext) -> T,
     ) -> T {
         let opaque_ty_data_vec = self.opaque_type_data.replace(Vec::new());
+        let expander = self.expander.replace(None);
         let new_ctx = Self {
             in_binders: debruijn,
-            impl_trait_counter: std::cell::Cell::new(self.impl_trait_counter.get()),
-            opaque_type_data: std::cell::RefCell::new(opaque_ty_data_vec),
+            impl_trait_counter: Cell::new(self.impl_trait_counter.get()),
+            opaque_type_data: RefCell::new(opaque_ty_data_vec),
+            expander: RefCell::new(expander),
             ..*self
         };
         let result = f(&new_ctx);
         self.impl_trait_counter.set(new_ctx.impl_trait_counter.get());
         self.opaque_type_data.replace(new_ctx.opaque_type_data.into_inner());
+        self.expander.replace(new_ctx.expander.into_inner());
         result
     }
 
@@ -287,15 +295,50 @@ impl<'a> TyLoweringContext<'a> {
                     }
                 }
             }
-            mt @ TypeRef::Macro(_) => {
-                if let Some(module_id) = self.resolver.module() {
-                    match expand_macro_type(self.db.upcast(), module_id, mt) {
-                        Some(type_ref) => self.lower_ty(&type_ref),
-                        None => TyKind::Error.intern(&Interner),
+            TypeRef::Macro(macro_call) => {
+                let (expander, recursion_start) = match self.expander.borrow_mut() {
+                    expander if expander.is_some() => (Some(expander), false),
+                    mut expander => {
+                        if let Some(module_id) = self.resolver.module() {
+                            *expander = Some(Expander::new(
+                                self.db.upcast(),
+                                macro_call.file_id,
+                                module_id,
+                            ));
+                            (Some(expander), true)
+                        } else {
+                            (None, false)
+                        }
+                    }
+                };
+                let ty = if let Some(mut expander) = expander {
+                    let expander_mut = expander.as_mut().unwrap();
+                    let macro_call = macro_call.to_node(self.db.upcast());
+                    match expander_mut.enter_expand::<ast::Type>(self.db.upcast(), macro_call) {
+                        Ok(ExpandResult { value: Some((mark, expanded)), .. }) => {
+                            let ctx =
+                                LowerCtx::new(self.db.upcast(), expander_mut.current_file_id());
+                            let type_ref = TypeRef::from_ast(&ctx, expanded);
+
+                            drop(expander);
+                            let ty = self.lower_ty(&type_ref);
+
+                            self.expander
+                                .borrow_mut()
+                                .as_mut()
+                                .unwrap()
+                                .exit(self.db.upcast(), mark);
+                            Some(ty)
+                        }
+                        _ => None,
                     }
                 } else {
-                    TyKind::Error.intern(&Interner)
+                    None
+                };
+                if recursion_start {
+                    *self.expander.borrow_mut() = None;
                 }
+                ty.unwrap_or_else(|| TyKind::Error.intern(&Interner))
             }
             TypeRef::Error => TyKind::Error.intern(&Interner),
         };
diff --git a/crates/hir_ty/src/tests/macros.rs b/crates/hir_ty/src/tests/macros.rs
index cbe05a5c166..8de1e229f0d 100644
--- a/crates/hir_ty/src/tests/macros.rs
+++ b/crates/hir_ty/src/tests/macros.rs
@@ -1243,3 +1243,29 @@ fn macros_in_type_generics() {
         "#]],
     );
 }
+
+#[test]
+fn infinitely_recursive_macro_type() {
+    check_infer(
+        r#"
+        struct Bar<T>(T);
+
+        macro_rules! Foo {
+            () => { Foo!() }
+        }
+
+        type A = Foo!();
+        type B = Bar<Foo!()>;
+
+        fn main() {
+            let a: A;
+            let b: B;
+        }
+        "#,
+        expect![[r#"
+            112..143 '{     ...: B; }': ()
+            122..123 'a': {unknown}
+            136..137 'b': Bar<{unknown}>
+        "#]],
+    );
+}