about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/hir-def/src/nameres/attr_resolution.rs2
-rw-r--r--crates/hir-def/src/nameres/collector.rs28
-rw-r--r--crates/hir-expand/src/cfg_process.rs10
-rw-r--r--crates/hir-expand/src/db.rs65
-rw-r--r--crates/hir-expand/src/lib.rs3
-rw-r--r--crates/hir/src/lib.rs4
6 files changed, 80 insertions, 32 deletions
diff --git a/crates/hir-def/src/nameres/attr_resolution.rs b/crates/hir-def/src/nameres/attr_resolution.rs
index 662c80edf32..eb7f4c05ae2 100644
--- a/crates/hir-def/src/nameres/attr_resolution.rs
+++ b/crates/hir-def/src/nameres/attr_resolution.rs
@@ -136,6 +136,7 @@ pub(super) fn derive_macro_as_call_id(
     call_site: SyntaxContextId,
     krate: CrateId,
     resolver: impl Fn(path::ModPath) -> Option<(MacroId, MacroDefId)>,
+    derive_macro_id: MacroCallId,
 ) -> Result<(MacroId, MacroDefId, MacroCallId), UnresolvedMacro> {
     let (macro_id, def_id) = resolver(item_attr.path.clone())
         .filter(|(_, def_id)| def_id.is_derive())
@@ -147,6 +148,7 @@ pub(super) fn derive_macro_as_call_id(
             ast_id: item_attr.ast_id,
             derive_index: derive_pos,
             derive_attr_index,
+            derive_macro_id,
         },
         call_site,
     );
diff --git a/crates/hir-def/src/nameres/collector.rs b/crates/hir-def/src/nameres/collector.rs
index 3cfbf6c89cc..ef10b3d2da1 100644
--- a/crates/hir-def/src/nameres/collector.rs
+++ b/crates/hir-def/src/nameres/collector.rs
@@ -237,6 +237,8 @@ enum MacroDirectiveKind {
         derive_attr: AttrId,
         derive_pos: usize,
         ctxt: SyntaxContextId,
+        /// The "parent" macro it is resolved to.
+        derive_macro_id: MacroCallId,
     },
     Attr {
         ast_id: AstIdWithPath<ast::Item>,
@@ -1146,7 +1148,13 @@ impl DefCollector<'_> {
                         return Resolved::Yes;
                     }
                 }
-                MacroDirectiveKind::Derive { ast_id, derive_attr, derive_pos, ctxt: call_site } => {
+                MacroDirectiveKind::Derive {
+                    ast_id,
+                    derive_attr,
+                    derive_pos,
+                    ctxt: call_site,
+                    derive_macro_id,
+                } => {
                     let id = derive_macro_as_call_id(
                         self.db,
                         ast_id,
@@ -1155,6 +1163,7 @@ impl DefCollector<'_> {
                         *call_site,
                         self.def_map.krate,
                         resolver,
+                        *derive_macro_id,
                     );
 
                     if let Ok((macro_id, def_id, call_id)) = id {
@@ -1224,6 +1233,8 @@ impl DefCollector<'_> {
                         _ => return Resolved::No,
                     };
 
+                    let call_id =
+                        attr_macro_as_call_id(self.db, file_ast_id, attr, self.def_map.krate, def);
                     if let MacroDefId {
                         kind:
                             MacroDefKind::BuiltInAttr(
@@ -1252,6 +1263,7 @@ impl DefCollector<'_> {
                                 return recollect_without(self);
                             }
                         };
+
                         let ast_id = ast_id.with_value(ast_adt_id);
 
                         match attr.parse_path_comma_token_tree(self.db.upcast()) {
@@ -1267,6 +1279,7 @@ impl DefCollector<'_> {
                                             derive_attr: attr.id,
                                             derive_pos: idx,
                                             ctxt: call_site.ctx,
+                                            derive_macro_id: call_id,
                                         },
                                         container: directive.container,
                                     });
@@ -1301,10 +1314,6 @@ impl DefCollector<'_> {
                         return recollect_without(self);
                     }
 
-                    // Not resolved to a derive helper or the derive attribute, so try to treat as a normal attribute.
-                    let call_id =
-                        attr_macro_as_call_id(self.db, file_ast_id, attr, self.def_map.krate, def);
-
                     // Skip #[test]/#[bench] expansion, which would merely result in more memory usage
                     // due to duplicating functions into macro expansions
                     if matches!(
@@ -1460,13 +1469,20 @@ impl DefCollector<'_> {
                         ));
                     }
                 }
-                MacroDirectiveKind::Derive { ast_id, derive_attr, derive_pos, ctxt: _ } => {
+                MacroDirectiveKind::Derive {
+                    ast_id,
+                    derive_attr,
+                    derive_pos,
+                    derive_macro_id,
+                    ..
+                } => {
                     self.def_map.diagnostics.push(DefDiagnostic::unresolved_macro_call(
                         directive.module_id,
                         MacroCallKind::Derive {
                             ast_id: ast_id.ast_id,
                             derive_attr_index: *derive_attr,
                             derive_index: *derive_pos as u32,
+                            derive_macro_id: *derive_macro_id,
                         },
                         ast_id.path.clone(),
                     ));
diff --git a/crates/hir-expand/src/cfg_process.rs b/crates/hir-expand/src/cfg_process.rs
index 68de05fafe2..db3558a84e9 100644
--- a/crates/hir-expand/src/cfg_process.rs
+++ b/crates/hir-expand/src/cfg_process.rs
@@ -10,7 +10,7 @@ use syntax::{
 use tracing::{debug, warn};
 use tt::SmolStr;
 
-use crate::{db::ExpandDatabase, MacroCallKind, MacroCallLoc};
+use crate::{db::ExpandDatabase, proc_macro::ProcMacroKind, MacroCallLoc, MacroDefKind};
 
 fn check_cfg_attr(attr: &Attr, loc: &MacroCallLoc, db: &dyn ExpandDatabase) -> Option<bool> {
     if !attr.simple_name().as_deref().map(|v| v == "cfg")? {
@@ -180,7 +180,13 @@ pub(crate) fn process_cfg_attrs(
     db: &dyn ExpandDatabase,
 ) -> Option<FxHashSet<SyntaxElement>> {
     // FIXME: #[cfg_eval] is not implemented. But it is not stable yet
-    if !matches!(loc.kind, MacroCallKind::Derive { .. }) {
+    let is_derive = match loc.def.kind {
+        MacroDefKind::BuiltInDerive(..)
+        | MacroDefKind::ProcMacro(_, ProcMacroKind::CustomDerive, _) => true,
+        MacroDefKind::BuiltInAttr(expander, _) => expander.is_derive(),
+        _ => false,
+    };
+    if !is_derive {
         return None;
     }
     let mut remove = FxHashSet::default();
diff --git a/crates/hir-expand/src/db.rs b/crates/hir-expand/src/db.rs
index ec68f2f96e5..5461c1c49a3 100644
--- a/crates/hir-expand/src/db.rs
+++ b/crates/hir-expand/src/db.rs
@@ -24,7 +24,8 @@ use crate::{
     HirFileId, HirFileIdRepr, MacroCallId, MacroCallKind, MacroCallLoc, MacroDefId, MacroDefKind,
     MacroFileId,
 };
-
+/// This is just to ensure the types of smart_macro_arg and macro_arg are the same
+type MacroArgResult = (Arc<tt::Subtree>, SyntaxFixupUndoInfo, Span);
 /// Total limit on the number of tokens produced by any macro invocation.
 ///
 /// If an invocation produces more tokens than this limit, it will not be stored in the database and
@@ -98,7 +99,13 @@ pub trait ExpandDatabase: SourceDatabase {
     /// Lowers syntactic macro call to a token tree representation. That's a firewall
     /// query, only typing in the macro call itself changes the returned
     /// subtree.
-    fn macro_arg(&self, id: MacroCallId) -> (Arc<tt::Subtree>, SyntaxFixupUndoInfo, Span);
+    fn macro_arg(&self, id: MacroCallId) -> MacroArgResult;
+    #[salsa::transparent]
+    fn macro_arg_considering_derives(
+        &self,
+        id: MacroCallId,
+        kind: &MacroCallKind,
+    ) -> MacroArgResult;
     /// Fetches the expander for this macro.
     #[salsa::transparent]
     #[salsa::invoke(TokenExpander::macro_expander)]
@@ -144,7 +151,7 @@ pub fn expand_speculative(
     let span_map = RealSpanMap::absolute(FileId::BOGUS);
     let span_map = SpanMapRef::RealSpanMap(&span_map);
 
-    let (_, _, span) = db.macro_arg(actual_macro_call);
+    let (_, _, span) = db.macro_arg_considering_derives(actual_macro_call, &loc.kind);
 
     // Build the subtree and token mapping for the speculative args
     let (mut tt, undo_info) = match loc.kind {
@@ -339,12 +346,24 @@ pub(crate) fn parse_with_map(
     }
 }
 
-// FIXME: for derive attributes, this will return separate copies of the same structures! Though
-// they may differ in spans due to differing call sites...
-fn macro_arg(
+/// This resolves the [MacroCallId] to check if it is a derive macro if so get the [macro_arg] for the derive.
+/// Other wise return the [macro_arg] for the macro_call_id.
+///
+/// This is not connected to the database so it does not cached the result. However, the inner [macro_arg] query is
+fn macro_arg_considering_derives(
     db: &dyn ExpandDatabase,
     id: MacroCallId,
-) -> (Arc<tt::Subtree>, SyntaxFixupUndoInfo, Span) {
+    kind: &MacroCallKind,
+) -> MacroArgResult {
+    match kind {
+        // Get the macro arg for the derive macro
+        MacroCallKind::Derive { derive_macro_id, .. } => db.macro_arg(*derive_macro_id),
+        // Normal macro arg
+        _ => db.macro_arg(id),
+    }
+}
+
+fn macro_arg(db: &dyn ExpandDatabase, id: MacroCallId) -> MacroArgResult {
     let loc = db.lookup_intern_macro_call(id);
 
     if let MacroCallLoc {
@@ -414,29 +433,30 @@ fn macro_arg(
             }
             return (Arc::new(tt), SyntaxFixupUndoInfo::NONE, span);
         }
-        MacroCallKind::Derive { ast_id, derive_attr_index, .. } => {
-            let node = ast_id.to_ptr(db).to_node(&root);
-            let censor_derive_input = censor_derive_input(derive_attr_index, &node);
-            let item_node = node.into();
-            let attr_source = attr_source(derive_attr_index, &item_node);
-            // FIXME: This is wrong, this should point to the path of the derive attribute`
-            let span =
-                map.span_for_range(attr_source.as_ref().and_then(|it| it.path()).map_or_else(
-                    || item_node.syntax().text_range(),
-                    |it| it.syntax().text_range(),
-                ));
-            (censor_derive_input, item_node, span)
+        // MacroCallKind::Derive should not be here. As we are getting the argument for the derive macro
+        MacroCallKind::Derive { .. } => {
+            unreachable!("`ExpandDatabase::macro_arg` called with `MacroCallKind::Derive`")
         }
         MacroCallKind::Attr { ast_id, invoc_attr_index, .. } => {
             let node = ast_id.to_ptr(db).to_node(&root);
             let attr_source = attr_source(invoc_attr_index, &node);
+
             let span = map.span_for_range(
                 attr_source
                     .as_ref()
                     .and_then(|it| it.path())
                     .map_or_else(|| node.syntax().text_range(), |it| it.syntax().text_range()),
             );
-            (attr_source.into_iter().map(|it| it.syntax().clone().into()).collect(), node, span)
+            // If derive attribute we need to censor the derive input
+            if matches!(loc.def.kind, MacroDefKind::BuiltInAttr(expander, ..) if expander.is_derive())
+                && ast::Adt::can_cast(node.syntax().kind())
+            {
+                let adt = ast::Adt::cast(node.syntax().clone()).unwrap();
+                let censor_derive_input = censor_derive_input(invoc_attr_index, &adt);
+                (censor_derive_input, node, span)
+            } else {
+                (attr_source.into_iter().map(|it| it.syntax().clone().into()).collect(), node, span)
+            }
         }
     };
 
@@ -526,7 +546,8 @@ fn macro_expand(
     let (ExpandResult { value: tt, err }, span) = match loc.def.kind {
         MacroDefKind::ProcMacro(..) => return db.expand_proc_macro(macro_call_id).map(CowArc::Arc),
         _ => {
-            let (macro_arg, undo_info, span) = db.macro_arg(macro_call_id);
+            let (macro_arg, undo_info, span) =
+                db.macro_arg_considering_derives(macro_call_id, &loc.kind);
 
             let arg = &*macro_arg;
             let res =
@@ -603,7 +624,7 @@ fn proc_macro_span(db: &dyn ExpandDatabase, ast: AstId<ast::Fn>) -> Span {
 
 fn expand_proc_macro(db: &dyn ExpandDatabase, id: MacroCallId) -> ExpandResult<Arc<tt::Subtree>> {
     let loc = db.lookup_intern_macro_call(id);
-    let (macro_arg, undo_info, span) = db.macro_arg(id);
+    let (macro_arg, undo_info, span) = db.macro_arg_considering_derives(id, &loc.kind);
 
     let (expander, ast) = match loc.def.kind {
         MacroDefKind::ProcMacro(expander, _, ast) => (expander, ast),
diff --git a/crates/hir-expand/src/lib.rs b/crates/hir-expand/src/lib.rs
index 2233ae51fa8..e7a313c68cd 100644
--- a/crates/hir-expand/src/lib.rs
+++ b/crates/hir-expand/src/lib.rs
@@ -224,6 +224,9 @@ pub enum MacroCallKind {
         derive_attr_index: AttrId,
         /// Index of the derive macro in the derive attribute
         derive_index: u32,
+        /// The "parent" macro call.
+        /// We will resolve the same token tree for all derive macros in the same derive attribute.
+        derive_macro_id: MacroCallId,
     },
     Attr {
         ast_id: AstId<ast::Item>,
diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs
index 321ff1eb395..13949fa78fb 100644
--- a/crates/hir/src/lib.rs
+++ b/crates/hir/src/lib.rs
@@ -972,7 +972,7 @@ fn precise_macro_call_location(
                 MacroKind::ProcMacro,
             )
         }
-        MacroCallKind::Derive { ast_id, derive_attr_index, derive_index } => {
+        MacroCallKind::Derive { ast_id, derive_attr_index, derive_index, .. } => {
             let node = ast_id.to_node(db.upcast());
             // Compute the precise location of the macro name's token in the derive
             // list.
@@ -3708,7 +3708,7 @@ impl Impl {
         let macro_file = src.file_id.macro_file()?;
         let loc = macro_file.macro_call_id.lookup(db.upcast());
         let (derive_attr, derive_index) = match loc.kind {
-            MacroCallKind::Derive { ast_id, derive_attr_index, derive_index } => {
+            MacroCallKind::Derive { ast_id, derive_attr_index, derive_index, .. } => {
                 let module_id = self.id.lookup(db.upcast()).container;
                 (
                     db.crate_def_map(module_id.krate())[module_id.local_id]