about summary refs log tree commit diff
path: root/compiler/rustc_expand/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_expand/src')
-rw-r--r--compiler/rustc_expand/src/base.rs40
-rw-r--r--compiler/rustc_expand/src/expand.rs11
-rw-r--r--compiler/rustc_expand/src/mbe/macro_check.rs11
-rw-r--r--compiler/rustc_expand/src/mbe/macro_rules.rs86
4 files changed, 102 insertions, 46 deletions
diff --git a/compiler/rustc_expand/src/base.rs b/compiler/rustc_expand/src/base.rs
index 3956125bace..810a5a21a05 100644
--- a/compiler/rustc_expand/src/base.rs
+++ b/compiler/rustc_expand/src/base.rs
@@ -10,7 +10,7 @@ use rustc_ast::attr::{AttributeExt, MarkedAttrs};
 use rustc_ast::token::MetaVarKind;
 use rustc_ast::tokenstream::TokenStream;
 use rustc_ast::visit::{AssocCtxt, Visitor};
-use rustc_ast::{self as ast, AttrVec, Attribute, HasAttrs, Item, NodeId, PatKind};
+use rustc_ast::{self as ast, AttrVec, Attribute, HasAttrs, Item, NodeId, PatKind, Safety};
 use rustc_data_structures::fx::{FxHashMap, FxIndexMap};
 use rustc_data_structures::sync;
 use rustc_errors::{BufferedEarlyLint, DiagCtxtHandle, ErrorGuaranteed, PResult};
@@ -324,16 +324,16 @@ pub trait BangProcMacro {
 
 impl<F> BangProcMacro for F
 where
-    F: Fn(TokenStream) -> TokenStream,
+    F: Fn(&mut ExtCtxt<'_>, Span, TokenStream) -> Result<TokenStream, ErrorGuaranteed>,
 {
     fn expand<'cx>(
         &self,
-        _ecx: &'cx mut ExtCtxt<'_>,
-        _span: Span,
+        ecx: &'cx mut ExtCtxt<'_>,
+        span: Span,
         ts: TokenStream,
     ) -> Result<TokenStream, ErrorGuaranteed> {
         // FIXME setup implicit context in TLS before calling self.
-        Ok(self(ts))
+        self(ecx, span, ts)
     }
 }
 
@@ -345,6 +345,21 @@ pub trait AttrProcMacro {
         annotation: TokenStream,
         annotated: TokenStream,
     ) -> Result<TokenStream, ErrorGuaranteed>;
+
+    // Default implementation for safe attributes; override if the attribute can be unsafe.
+    fn expand_with_safety<'cx>(
+        &self,
+        ecx: &'cx mut ExtCtxt<'_>,
+        safety: Safety,
+        span: Span,
+        annotation: TokenStream,
+        annotated: TokenStream,
+    ) -> Result<TokenStream, ErrorGuaranteed> {
+        if let Safety::Unsafe(span) = safety {
+            ecx.dcx().span_err(span, "unnecessary `unsafe` on safe attribute");
+        }
+        self.expand(ecx, span, annotation, annotated)
+    }
 }
 
 impl<F> AttrProcMacro for F
@@ -999,17 +1014,14 @@ impl SyntaxExtension {
 
     /// A dummy bang macro `foo!()`.
     pub fn dummy_bang(edition: Edition) -> SyntaxExtension {
-        fn expander<'cx>(
-            cx: &'cx mut ExtCtxt<'_>,
+        fn expand(
+            ecx: &mut ExtCtxt<'_>,
             span: Span,
-            _: TokenStream,
-        ) -> MacroExpanderResult<'cx> {
-            ExpandResult::Ready(DummyResult::any(
-                span,
-                cx.dcx().span_delayed_bug(span, "expanded a dummy bang macro"),
-            ))
+            _ts: TokenStream,
+        ) -> Result<TokenStream, ErrorGuaranteed> {
+            Err(ecx.dcx().span_delayed_bug(span, "expanded a dummy bang macro"))
         }
-        SyntaxExtension::default(SyntaxExtensionKind::LegacyBang(Arc::new(expander)), edition)
+        SyntaxExtension::default(SyntaxExtensionKind::Bang(Arc::new(expand)), edition)
     }
 
     /// A dummy derive macro `#[derive(Foo)]`.
diff --git a/compiler/rustc_expand/src/expand.rs b/compiler/rustc_expand/src/expand.rs
index 4c0e0bbfe26..3dfa3cdcc35 100644
--- a/compiler/rustc_expand/src/expand.rs
+++ b/compiler/rustc_expand/src/expand.rs
@@ -812,11 +812,12 @@ impl<'a, 'b> MacroExpander<'a, 'b> {
                         _ => item.to_tokens(),
                     };
                     let attr_item = attr.get_normal_item();
+                    let safety = attr_item.unsafety;
                     if let AttrArgs::Eq { .. } = attr_item.args {
                         self.cx.dcx().emit_err(UnsupportedKeyValue { span });
                     }
                     let inner_tokens = attr_item.args.inner_tokens();
-                    match expander.expand(self.cx, span, inner_tokens, tokens) {
+                    match expander.expand_with_safety(self.cx, safety, span, inner_tokens, tokens) {
                         Ok(tok_result) => {
                             let fragment = self.parse_ast_fragment(
                                 tok_result,
@@ -840,6 +841,9 @@ impl<'a, 'b> MacroExpander<'a, 'b> {
                         Err(guar) => return ExpandResult::Ready(fragment_kind.dummy(span, guar)),
                     }
                 } else if let SyntaxExtensionKind::LegacyAttr(expander) = ext {
+                    // `LegacyAttr` is only used for builtin attribute macros, which have their
+                    // safety checked by `check_builtin_meta_item`, so we don't need to check
+                    // `unsafety` here.
                     match validate_attr::parse_meta(&self.cx.sess.psess, &attr) {
                         Ok(meta) => {
                             let item_clone = macro_stats.then(|| item.clone());
@@ -882,6 +886,9 @@ impl<'a, 'b> MacroExpander<'a, 'b> {
                         }
                     }
                 } else if let SyntaxExtensionKind::NonMacroAttr = ext {
+                    if let ast::Safety::Unsafe(span) = attr.get_normal_item().unsafety {
+                        self.cx.dcx().span_err(span, "unnecessary `unsafe` on safe attribute");
+                    }
                     // `-Zmacro-stats` ignores these because they don't do any real expansion.
                     self.cx.expanded_inert_attrs.mark(&attr);
                     item.visit_attrs(|attrs| attrs.insert(pos, attr));
@@ -971,7 +978,7 @@ impl<'a, 'b> MacroExpander<'a, 'b> {
                             });
                         }
                     },
-                    SyntaxExtensionKind::LegacyBang(..) => {
+                    SyntaxExtensionKind::Bang(..) => {
                         let msg = "expanded a dummy glob delegation";
                         let guar = self.cx.dcx().span_delayed_bug(span, msg);
                         return ExpandResult::Ready(fragment_kind.dummy(span, guar));
diff --git a/compiler/rustc_expand/src/mbe/macro_check.rs b/compiler/rustc_expand/src/mbe/macro_check.rs
index ebd6e887f7d..0eae44a05e7 100644
--- a/compiler/rustc_expand/src/mbe/macro_check.rs
+++ b/compiler/rustc_expand/src/mbe/macro_check.rs
@@ -210,8 +210,7 @@ pub(super) fn check_meta_variables(
     guar.map_or(Ok(()), Err)
 }
 
-/// Checks `lhs` as part of the LHS of a macro definition, extends `binders` with new binders, and
-/// sets `valid` to false in case of errors.
+/// Checks `lhs` as part of the LHS of a macro definition.
 ///
 /// Arguments:
 /// - `psess` is used to emit diagnostics and lints
@@ -306,8 +305,7 @@ fn get_binder_info<'a>(
     binders.get(&name).or_else(|| macros.find_map(|state| state.binders.get(&name)))
 }
 
-/// Checks `rhs` as part of the RHS of a macro definition and sets `valid` to false in case of
-/// errors.
+/// Checks `rhs` as part of the RHS of a macro definition.
 ///
 /// Arguments:
 /// - `psess` is used to emit diagnostics and lints
@@ -372,7 +370,7 @@ enum NestedMacroState {
 }
 
 /// Checks `tts` as part of the RHS of a macro definition, tries to recognize nested macro
-/// definitions, and sets `valid` to false in case of errors.
+/// definitions.
 ///
 /// Arguments:
 /// - `psess` is used to emit diagnostics and lints
@@ -491,8 +489,7 @@ fn check_nested_occurrences(
     }
 }
 
-/// Checks the body of nested macro, returns where the check stopped, and sets `valid` to false in
-/// case of errors.
+/// Checks the body of nested macro, returns where the check stopped.
 ///
 /// The token trees are checked as long as they look like a list of (LHS) => {RHS} token trees. This
 /// check is a best-effort to detect a macro definition. It returns the position in `tts` where we
diff --git a/compiler/rustc_expand/src/mbe/macro_rules.rs b/compiler/rustc_expand/src/mbe/macro_rules.rs
index 1d147a0385c..c548cea537f 100644
--- a/compiler/rustc_expand/src/mbe/macro_rules.rs
+++ b/compiler/rustc_expand/src/mbe/macro_rules.rs
@@ -8,7 +8,7 @@ use rustc_ast::token::NtPatKind::*;
 use rustc_ast::token::TokenKind::*;
 use rustc_ast::token::{self, Delimiter, NonterminalKind, Token, TokenKind};
 use rustc_ast::tokenstream::{self, DelimSpan, TokenStream};
-use rustc_ast::{self as ast, DUMMY_NODE_ID, NodeId};
+use rustc_ast::{self as ast, DUMMY_NODE_ID, NodeId, Safety};
 use rustc_ast_pretty::pprust;
 use rustc_data_structures::fx::{FxHashMap, FxIndexMap};
 use rustc_errors::{Applicability, Diag, ErrorGuaranteed, MultiSpan};
@@ -33,8 +33,8 @@ use super::diagnostics::{FailedMacro, failed_to_match_macro};
 use super::macro_parser::{NamedMatches, NamedParseResult};
 use super::{SequenceRepetition, diagnostics};
 use crate::base::{
-    AttrProcMacro, DummyResult, ExpandResult, ExtCtxt, MacResult, MacroExpanderResult,
-    SyntaxExtension, SyntaxExtensionKind, TTMacroExpander,
+    AttrProcMacro, BangProcMacro, DummyResult, ExpandResult, ExtCtxt, MacResult,
+    MacroExpanderResult, SyntaxExtension, SyntaxExtensionKind, TTMacroExpander,
 };
 use crate::errors;
 use crate::expand::{AstFragment, AstFragmentKind, ensure_complete_parse, parse_ast_fragment};
@@ -131,6 +131,7 @@ pub(super) enum MacroRule {
     Func { lhs: Vec<MatcherLoc>, lhs_span: Span, rhs: mbe::TokenTree },
     /// An attr rule, for use with `#[m]`
     Attr {
+        unsafe_rule: bool,
         args: Vec<MatcherLoc>,
         args_span: Span,
         body: Vec<MatcherLoc>,
@@ -248,7 +249,18 @@ impl TTMacroExpander for MacroRulesMacroExpander {
 impl AttrProcMacro for MacroRulesMacroExpander {
     fn expand(
         &self,
+        _cx: &mut ExtCtxt<'_>,
+        _sp: Span,
+        _args: TokenStream,
+        _body: TokenStream,
+    ) -> Result<TokenStream, ErrorGuaranteed> {
+        unreachable!("`expand` called on `MacroRulesMacroExpander`, expected `expand_with_safety`")
+    }
+
+    fn expand_with_safety(
+        &self,
         cx: &mut ExtCtxt<'_>,
+        safety: Safety,
         sp: Span,
         args: TokenStream,
         body: TokenStream,
@@ -260,6 +272,7 @@ impl AttrProcMacro for MacroRulesMacroExpander {
             self.node_id,
             self.name,
             self.transparency,
+            safety,
             args,
             body,
             &self.rules,
@@ -267,16 +280,16 @@ impl AttrProcMacro for MacroRulesMacroExpander {
     }
 }
 
-struct DummyExpander(ErrorGuaranteed);
+struct DummyBang(ErrorGuaranteed);
 
-impl TTMacroExpander for DummyExpander {
+impl BangProcMacro for DummyBang {
     fn expand<'cx>(
         &self,
         _: &'cx mut ExtCtxt<'_>,
-        span: Span,
+        _: Span,
         _: TokenStream,
-    ) -> ExpandResult<Box<dyn MacResult + 'cx>, ()> {
-        ExpandResult::Ready(DummyResult::any(span, self.0))
+    ) -> Result<TokenStream, ErrorGuaranteed> {
+        Err(self.0)
     }
 }
 
@@ -408,6 +421,7 @@ fn expand_macro_attr(
     node_id: NodeId,
     name: Ident,
     transparency: Transparency,
+    safety: Safety,
     args: TokenStream,
     body: TokenStream,
     rules: &[MacroRule],
@@ -429,13 +443,26 @@ fn expand_macro_attr(
     // Track nothing for the best performance.
     match try_match_macro_attr(psess, name, &args, &body, rules, &mut NoopTracker) {
         Ok((i, rule, named_matches)) => {
-            let MacroRule::Attr { rhs, .. } = rule else {
+            let MacroRule::Attr { rhs, unsafe_rule, .. } = rule else {
                 panic!("try_macro_match_attr returned non-attr rule");
             };
             let mbe::TokenTree::Delimited(rhs_span, _, rhs) = rhs else {
                 cx.dcx().span_bug(sp, "malformed macro rhs");
             };
 
+            match (safety, unsafe_rule) {
+                (Safety::Default, false) | (Safety::Unsafe(_), true) => {}
+                (Safety::Default, true) => {
+                    cx.dcx().span_err(sp, "unsafe attribute invocation requires `unsafe`");
+                }
+                (Safety::Unsafe(span), false) => {
+                    cx.dcx().span_err(span, "unnecessary `unsafe` on safe attribute invocation");
+                }
+                (Safety::Safe(span), _) => {
+                    cx.dcx().span_bug(span, "unexpected `safe` keyword");
+                }
+            }
+
             let id = cx.current_expansion.id;
             let tts = transcribe(psess, &named_matches, rhs, *rhs_span, transparency, id)
                 .map_err(|e| e.emit())?;
@@ -664,7 +691,7 @@ pub fn compile_declarative_macro(
         SyntaxExtension::new(sess, kind, span, Vec::new(), edition, ident.name, attrs, is_local)
     };
     let dummy_syn_ext =
-        |guar| (mk_syn_ext(SyntaxExtensionKind::LegacyBang(Arc::new(DummyExpander(guar)))), 0);
+        |guar| (mk_syn_ext(SyntaxExtensionKind::Bang(Arc::new(DummyBang(guar)))), 0);
 
     let macro_rules = macro_def.macro_rules;
     let exp_sep = if macro_rules { exp!(Semi) } else { exp!(Comma) };
@@ -681,6 +708,11 @@ pub fn compile_declarative_macro(
     let mut rules = Vec::new();
 
     while p.token != token::Eof {
+        let unsafe_rule = p.eat_keyword_noexpect(kw::Unsafe);
+        let unsafe_keyword_span = p.prev_token.span;
+        if unsafe_rule && let Some(guar) = check_no_eof(sess, &p, "expected `attr`") {
+            return dummy_syn_ext(guar);
+        }
         let (args, is_derive) = if p.eat_keyword_noexpect(sym::attr) {
             kinds |= MacroKinds::ATTR;
             if !features.macro_attr() {
@@ -705,6 +737,10 @@ pub fn compile_declarative_macro(
                 feature_err(sess, sym::macro_derive, span, "`macro_rules!` derives are unstable")
                     .emit();
             }
+            if unsafe_rule {
+                sess.dcx()
+                    .span_err(unsafe_keyword_span, "`unsafe` is only supported on `attr` rules");
+            }
             if let Some(guar) = check_no_eof(sess, &p, "expected `()` after `derive`") {
                 return dummy_syn_ext(guar);
             }
@@ -730,6 +766,10 @@ pub fn compile_declarative_macro(
             (None, true)
         } else {
             kinds |= MacroKinds::BANG;
+            if unsafe_rule {
+                sess.dcx()
+                    .span_err(unsafe_keyword_span, "`unsafe` is only supported on `attr` rules");
+            }
             (None, false)
         };
         let lhs_tt = p.parse_token_tree();
@@ -741,10 +781,10 @@ pub fn compile_declarative_macro(
         if let Some(guar) = check_no_eof(sess, &p, "expected right-hand side of macro rule") {
             return dummy_syn_ext(guar);
         }
-        let rhs_tt = p.parse_token_tree();
-        let rhs_tt = parse_one_tt(rhs_tt, RulePart::Body, sess, node_id, features, edition);
-        check_emission(check_rhs(sess, &rhs_tt));
-        check_emission(check_meta_variables(&sess.psess, node_id, args.as_ref(), &lhs_tt, &rhs_tt));
+        let rhs = p.parse_token_tree();
+        let rhs = parse_one_tt(rhs, RulePart::Body, sess, node_id, features, edition);
+        check_emission(check_rhs(sess, &rhs));
+        check_emission(check_meta_variables(&sess.psess, node_id, args.as_ref(), &lhs_tt, &rhs));
         let lhs_span = lhs_tt.span();
         // Convert the lhs into `MatcherLoc` form, which is better for doing the
         // actual matching.
@@ -760,11 +800,11 @@ pub fn compile_declarative_macro(
             };
             let args = mbe::macro_parser::compute_locs(&delimited.tts);
             let body_span = lhs_span;
-            rules.push(MacroRule::Attr { args, args_span, body: lhs, body_span, rhs: rhs_tt });
+            rules.push(MacroRule::Attr { unsafe_rule, args, args_span, body: lhs, body_span, rhs });
         } else if is_derive {
-            rules.push(MacroRule::Derive { body: lhs, body_span: lhs_span, rhs: rhs_tt });
+            rules.push(MacroRule::Derive { body: lhs, body_span: lhs_span, rhs });
         } else {
-            rules.push(MacroRule::Func { lhs, lhs_span, rhs: rhs_tt });
+            rules.push(MacroRule::Func { lhs, lhs_span, rhs });
         }
         if p.token == token::Eof {
             break;
@@ -894,12 +934,12 @@ fn check_redundant_vis_repetition(
     seq: &SequenceRepetition,
     span: &DelimSpan,
 ) {
-    let is_zero_or_one: bool = seq.kleene.op == KleeneOp::ZeroOrOne;
-    let is_vis = seq.tts.first().map_or(false, |tt| {
-        matches!(tt, mbe::TokenTree::MetaVarDecl { kind: NonterminalKind::Vis, .. })
-    });
-
-    if is_vis && is_zero_or_one {
+    if seq.kleene.op == KleeneOp::ZeroOrOne
+        && matches!(
+            seq.tts.first(),
+            Some(mbe::TokenTree::MetaVarDecl { kind: NonterminalKind::Vis, .. })
+        )
+    {
         err.note("a `vis` fragment can already be empty");
         err.multipart_suggestion(
             "remove the `$(` and `)?`",