about summary refs log tree commit diff
path: root/compiler/rustc_macros/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_macros/src')
-rw-r--r--compiler/rustc_macros/src/diagnostics/diagnostic_builder.rs39
-rw-r--r--compiler/rustc_macros/src/extension.rs154
-rw-r--r--compiler/rustc_macros/src/lib.rs6
3 files changed, 161 insertions, 38 deletions
diff --git a/compiler/rustc_macros/src/diagnostics/diagnostic_builder.rs b/compiler/rustc_macros/src/diagnostics/diagnostic_builder.rs
index 85bb9584a05..ae481efb263 100644
--- a/compiler/rustc_macros/src/diagnostics/diagnostic_builder.rs
+++ b/compiler/rustc_macros/src/diagnostics/diagnostic_builder.rs
@@ -331,44 +331,7 @@ impl DiagnosticDeriveVariantBuilder {
                 }
             }
             (Meta::Path(_), "subdiagnostic") => {
-                if FieldInnerTy::from_type(&info.binding.ast().ty).will_iterate() {
-                    let DiagnosticDeriveKind::Diagnostic = self.kind else {
-                        // No eager translation for lints.
-                        return Ok(quote! { diag.subdiagnostic(#binding); });
-                    };
-                    return Ok(quote! { diag.eager_subdiagnostic(dcx, #binding); });
-                } else {
-                    return Ok(quote! { diag.subdiagnostic(#binding); });
-                }
-            }
-            (Meta::List(meta_list), "subdiagnostic") => {
-                let err = || {
-                    span_err(
-                        meta_list.span().unwrap(),
-                        "`eager` is the only supported nested attribute for `subdiagnostic`",
-                    )
-                    .emit();
-                };
-
-                let Ok(p): Result<Path, _> = meta_list.parse_args() else {
-                    err();
-                    return Ok(quote! {});
-                };
-
-                if !p.is_ident("eager") {
-                    err();
-                    return Ok(quote! {});
-                }
-
-                match &self.kind {
-                    DiagnosticDeriveKind::Diagnostic => {}
-                    DiagnosticDeriveKind::LintDiagnostic => {
-                        throw_invalid_attr!(attr, |diag| {
-                            diag.help("eager subdiagnostics are not supported on lints")
-                        })
-                    }
-                };
-                return Ok(quote! { diag.eager_subdiagnostic(dcx, #binding); });
+                return Ok(quote! { diag.subdiagnostic(diag.dcx, #binding); });
             }
             _ => (),
         }
diff --git a/compiler/rustc_macros/src/extension.rs b/compiler/rustc_macros/src/extension.rs
new file mode 100644
index 00000000000..5377bbdfeab
--- /dev/null
+++ b/compiler/rustc_macros/src/extension.rs
@@ -0,0 +1,154 @@
+use proc_macro2::Ident;
+use quote::quote;
+use syn::parse::{Parse, ParseStream};
+use syn::punctuated::Punctuated;
+use syn::spanned::Spanned;
+use syn::{
+    braced, parse_macro_input, Attribute, Generics, ImplItem, Pat, PatIdent, Path, Signature,
+    Token, TraitItem, TraitItemConst, TraitItemFn, TraitItemMacro, TraitItemType, Type, Visibility,
+};
+
+pub(crate) fn extension(
+    attr: proc_macro::TokenStream,
+    input: proc_macro::TokenStream,
+) -> proc_macro::TokenStream {
+    let ExtensionAttr { vis, trait_ } = parse_macro_input!(attr as ExtensionAttr);
+    let Impl { attrs, generics, self_ty, items } = parse_macro_input!(input as Impl);
+    let headers: Vec<_> = items
+        .iter()
+        .map(|item| match item {
+            ImplItem::Fn(f) => TraitItem::Fn(TraitItemFn {
+                attrs: scrub_attrs(&f.attrs),
+                sig: scrub_header(f.sig.clone()),
+                default: None,
+                semi_token: Some(Token![;](f.block.span())),
+            }),
+            ImplItem::Const(ct) => TraitItem::Const(TraitItemConst {
+                attrs: scrub_attrs(&ct.attrs),
+                const_token: ct.const_token,
+                ident: ct.ident.clone(),
+                generics: ct.generics.clone(),
+                colon_token: ct.colon_token,
+                ty: ct.ty.clone(),
+                default: None,
+                semi_token: ct.semi_token,
+            }),
+            ImplItem::Type(ty) => TraitItem::Type(TraitItemType {
+                attrs: scrub_attrs(&ty.attrs),
+                type_token: ty.type_token,
+                ident: ty.ident.clone(),
+                generics: ty.generics.clone(),
+                colon_token: None,
+                bounds: Punctuated::new(),
+                default: None,
+                semi_token: ty.semi_token,
+            }),
+            ImplItem::Macro(mac) => TraitItem::Macro(TraitItemMacro {
+                attrs: scrub_attrs(&mac.attrs),
+                mac: mac.mac.clone(),
+                semi_token: mac.semi_token,
+            }),
+            ImplItem::Verbatim(stream) => TraitItem::Verbatim(stream.clone()),
+            _ => unimplemented!(),
+        })
+        .collect();
+
+    quote! {
+        #(#attrs)*
+        #vis trait #trait_ {
+            #(#headers)*
+        }
+
+        impl #generics #trait_ for #self_ty {
+            #(#items)*
+        }
+    }
+    .into()
+}
+
+/// Only keep `#[doc]` attrs.
+fn scrub_attrs(attrs: &[Attribute]) -> Vec<Attribute> {
+    attrs
+        .into_iter()
+        .cloned()
+        .filter(|attr| {
+            let ident = &attr.path().segments[0].ident;
+            ident == "doc" || ident == "must_use"
+        })
+        .collect()
+}
+
+/// Scrub arguments so that they're valid for trait signatures.
+fn scrub_header(mut sig: Signature) -> Signature {
+    for (idx, input) in sig.inputs.iter_mut().enumerate() {
+        match input {
+            syn::FnArg::Receiver(rcvr) => {
+                // `mut self` -> `self`
+                if rcvr.reference.is_none() {
+                    rcvr.mutability.take();
+                }
+            }
+            syn::FnArg::Typed(arg) => match &mut *arg.pat {
+                Pat::Ident(arg) => {
+                    // `ref mut ident @ pat` -> `ident`
+                    arg.by_ref.take();
+                    arg.mutability.take();
+                    arg.subpat.take();
+                }
+                _ => {
+                    // `pat` -> `__arg0`
+                    arg.pat = Box::new(
+                        PatIdent {
+                            attrs: vec![],
+                            by_ref: None,
+                            mutability: None,
+                            ident: Ident::new(&format!("__arg{idx}"), arg.pat.span()),
+                            subpat: None,
+                        }
+                        .into(),
+                    )
+                }
+            },
+        }
+    }
+    sig
+}
+
+struct ExtensionAttr {
+    vis: Visibility,
+    trait_: Path,
+}
+
+impl Parse for ExtensionAttr {
+    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
+        let vis = input.parse()?;
+        let _: Token![trait] = input.parse()?;
+        let trait_ = input.parse()?;
+        Ok(ExtensionAttr { vis, trait_ })
+    }
+}
+
+struct Impl {
+    attrs: Vec<Attribute>,
+    generics: Generics,
+    self_ty: Type,
+    items: Vec<ImplItem>,
+}
+
+impl Parse for Impl {
+    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
+        let attrs = input.call(Attribute::parse_outer)?;
+        let _: Token![impl] = input.parse()?;
+        let generics = input.parse()?;
+        let self_ty = input.parse()?;
+
+        let content;
+        let _brace_token = braced!(content in input);
+        let mut items = Vec::new();
+        while !content.is_empty() {
+            items.push(content.parse()?);
+        }
+
+        Ok(Impl { attrs, generics, self_ty, items })
+    }
+}
diff --git a/compiler/rustc_macros/src/lib.rs b/compiler/rustc_macros/src/lib.rs
index af65c908ee6..619f93c8a53 100644
--- a/compiler/rustc_macros/src/lib.rs
+++ b/compiler/rustc_macros/src/lib.rs
@@ -14,6 +14,7 @@ use proc_macro::TokenStream;
 
 mod current_version;
 mod diagnostics;
+mod extension;
 mod hash_stable;
 mod lift;
 mod query;
@@ -40,6 +41,11 @@ pub fn symbols(input: TokenStream) -> TokenStream {
     symbols::symbols(input.into()).into()
 }
 
+#[proc_macro_attribute]
+pub fn extension(attr: TokenStream, input: TokenStream) -> TokenStream {
+    extension::extension(attr, input)
+}
+
 decl_derive!([HashStable, attributes(stable_hasher)] => hash_stable::hash_stable_derive);
 decl_derive!(
     [HashStable_Generic, attributes(stable_hasher)] =>