about summary refs log tree commit diff
path: root/compiler/rustc_macros
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-02-13 23:49:28 +0000
committerMichael Goulet <michael@errs.io>2024-02-16 15:07:37 +0000
commit3250e953050be1a6867aac3745742f86ab406361 (patch)
tree6b2572c184c33e88b8d15dcde4cf94eb0db07e8a /compiler/rustc_macros
parentc9a7db6e20c8892f770b94dd6d5a16a03721b658 (diff)
downloadrust-3250e953050be1a6867aac3745742f86ab406361.tar.gz
rust-3250e953050be1a6867aac3745742f86ab406361.zip
Add a simple extension trait derive
Diffstat (limited to 'compiler/rustc_macros')
-rw-r--r--compiler/rustc_macros/src/extension.rs136
-rw-r--r--compiler/rustc_macros/src/lib.rs6
2 files changed, 142 insertions, 0 deletions
diff --git a/compiler/rustc_macros/src/extension.rs b/compiler/rustc_macros/src/extension.rs
new file mode 100644
index 00000000000..7bb07285ae2
--- /dev/null
+++ b/compiler/rustc_macros/src/extension.rs
@@ -0,0 +1,136 @@
+use proc_macro2::Ident;
+use quote::quote;
+use syn::parse::{Parse, ParseStream};
+use syn::punctuated::Punctuated;
+use syn::spanned::Spanned;
+use syn::{
+    braced, parse_macro_input, Attribute, Generics, ImplItem, Pat, PatIdent, Path, Signature,
+    Token, TraitItem, TraitItemConst, TraitItemFn, TraitItemMacro, TraitItemType, Type, Visibility,
+};
+
+pub(crate) fn extension(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+    // Parse the input tokens into a syntax tree
+    let Extension { attrs, generics, vis, trait_, self_ty, items } =
+        parse_macro_input!(input as Extension);
+    let headers: Vec<_> = items
+        .iter()
+        .map(|item| match item {
+            ImplItem::Fn(f) => TraitItem::Fn(TraitItemFn {
+                attrs: scrub_attrs(&f.attrs),
+                sig: scrub_header(f.sig.clone()),
+                default: None,
+                semi_token: Some(Token![;](f.block.span())),
+            }),
+            ImplItem::Const(ct) => TraitItem::Const(TraitItemConst {
+                attrs: scrub_attrs(&ct.attrs),
+                const_token: ct.const_token,
+                ident: ct.ident.clone(),
+                generics: ct.generics.clone(),
+                colon_token: ct.colon_token,
+                ty: ct.ty.clone(),
+                default: None,
+                semi_token: ct.semi_token,
+            }),
+            ImplItem::Type(ty) => TraitItem::Type(TraitItemType {
+                attrs: scrub_attrs(&ty.attrs),
+                type_token: ty.type_token,
+                ident: ty.ident.clone(),
+                generics: ty.generics.clone(),
+                colon_token: None,
+                bounds: Punctuated::new(),
+                default: None,
+                semi_token: ty.semi_token,
+            }),
+            ImplItem::Macro(mac) => TraitItem::Macro(TraitItemMacro {
+                attrs: scrub_attrs(&mac.attrs),
+                mac: mac.mac.clone(),
+                semi_token: mac.semi_token,
+            }),
+            ImplItem::Verbatim(stream) => TraitItem::Verbatim(stream.clone()),
+            _ => unimplemented!(),
+        })
+        .collect();
+
+    quote! {
+        #(#attrs)*
+        #vis trait #trait_ {
+            #(#headers)*
+        }
+
+        impl #generics #trait_ for #self_ty {
+            #(#items)*
+        }
+    }
+    .into()
+}
+
+/// Only keep `#[doc]` attrs.
+fn scrub_attrs(attrs: &[Attribute]) -> Vec<Attribute> {
+    attrs.into_iter().cloned().filter(|attr| attr.path().segments[0].ident == "doc").collect()
+}
+
+/// Scrub arguments so that they're valid for trait signatures.
+fn scrub_header(mut sig: Signature) -> Signature {
+    for (idx, input) in sig.inputs.iter_mut().enumerate() {
+        match input {
+            syn::FnArg::Receiver(rcvr) => {
+                // `mut self` -> `self`
+                if rcvr.reference.is_none() {
+                    rcvr.mutability.take();
+                }
+            }
+            syn::FnArg::Typed(arg) => match &mut *arg.pat {
+                Pat::Ident(arg) => {
+                    // `ref mut ident @ pat` -> `ident`
+                    arg.by_ref.take();
+                    arg.mutability.take();
+                    arg.subpat.take();
+                }
+                _ => {
+                    // `pat` -> `__arg0`
+                    arg.pat = Box::new(
+                        PatIdent {
+                            attrs: vec![],
+                            by_ref: None,
+                            mutability: None,
+                            ident: Ident::new(&format!("__arg{idx}"), arg.pat.span()),
+                            subpat: None,
+                        }
+                        .into(),
+                    )
+                }
+            },
+        }
+    }
+    sig
+}
+
+struct Extension {
+    attrs: Vec<Attribute>,
+    vis: Visibility,
+    generics: Generics,
+    trait_: Path,
+    self_ty: Type,
+    items: Vec<ImplItem>,
+}
+
+impl Parse for Extension {
+    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
+        let attrs = input.call(Attribute::parse_outer)?;
+        let vis = input.parse()?;
+        let _: Token![impl] = input.parse()?;
+        let generics = input.parse()?;
+        let trait_ = input.parse()?;
+        let _: Token![for] = input.parse()?;
+        let self_ty = input.parse()?;
+
+        let content;
+        let _brace_token = braced!(content in input);
+        let mut items = Vec::new();
+        while !content.is_empty() {
+            items.push(content.parse()?);
+        }
+
+        Ok(Extension { attrs, generics, vis, trait_, self_ty, items })
+    }
+}
diff --git a/compiler/rustc_macros/src/lib.rs b/compiler/rustc_macros/src/lib.rs
index af65c908ee6..841f5c06126 100644
--- a/compiler/rustc_macros/src/lib.rs
+++ b/compiler/rustc_macros/src/lib.rs
@@ -14,6 +14,7 @@ use proc_macro::TokenStream;
 
 mod current_version;
 mod diagnostics;
+mod extension;
 mod hash_stable;
 mod lift;
 mod query;
@@ -40,6 +41,11 @@ pub fn symbols(input: TokenStream) -> TokenStream {
     symbols::symbols(input.into()).into()
 }
 
+#[proc_macro_attribute]
+pub fn extension(_attr: TokenStream, input: TokenStream) -> TokenStream {
+    extension::extension(input)
+}
+
 decl_derive!([HashStable, attributes(stable_hasher)] => hash_stable::hash_stable_derive);
 decl_derive!(
     [HashStable_Generic, attributes(stable_hasher)] =>