about summary refs log tree commit diff
path: root/compiler/rustc_macros/src
diff options
context:
space:
mode:
authorAlan Egerton <eggyal@gmail.com>2022-06-17 10:53:29 +0100
committerAlan Egerton <eggyal@gmail.com>2022-07-05 22:25:15 +0100
commite4b9625b87c4d3b5845ea5660cf84739de224192 (patch)
tree932111ee5ad67cb91e90508f290f5649b62cd9cf /compiler/rustc_macros/src
parentbca894909cdb49b510dc6f60edbfa1465a948d09 (diff)
downloadrust-e4b9625b87c4d3b5845ea5660cf84739de224192.tar.gz
rust-e4b9625b87c4d3b5845ea5660cf84739de224192.zip
Add #[derive(TypeVisitable)]
Diffstat (limited to 'compiler/rustc_macros/src')
-rw-r--r--compiler/rustc_macros/src/lib.rs2
-rw-r--r--compiler/rustc_macros/src/type_foldable.rs13
-rw-r--r--compiler/rustc_macros/src/type_visitable.rs33
3 files changed, 35 insertions, 13 deletions
diff --git a/compiler/rustc_macros/src/lib.rs b/compiler/rustc_macros/src/lib.rs
index 7c8e3c6d140..ab1d6a439cd 100644
--- a/compiler/rustc_macros/src/lib.rs
+++ b/compiler/rustc_macros/src/lib.rs
@@ -18,6 +18,7 @@ mod query;
 mod serialize;
 mod symbols;
 mod type_foldable;
+mod type_visitable;
 
 #[proc_macro]
 pub fn rustc_queries(input: TokenStream) -> TokenStream {
@@ -121,6 +122,7 @@ decl_derive!([TyEncodable] => serialize::type_encodable_derive);
 decl_derive!([MetadataDecodable] => serialize::meta_decodable_derive);
 decl_derive!([MetadataEncodable] => serialize::meta_encodable_derive);
 decl_derive!([TypeFoldable, attributes(type_foldable)] => type_foldable::type_foldable_derive);
+decl_derive!([TypeVisitable, attributes(type_visitable)] => type_visitable::type_visitable_derive);
 decl_derive!([Lift, attributes(lift)] => lift::lift_derive);
 decl_derive!(
     [SessionDiagnostic, attributes(
diff --git a/compiler/rustc_macros/src/type_foldable.rs b/compiler/rustc_macros/src/type_foldable.rs
index 9e834d3ba1c..23e619221aa 100644
--- a/compiler/rustc_macros/src/type_foldable.rs
+++ b/compiler/rustc_macros/src/type_foldable.rs
@@ -11,11 +11,6 @@ pub fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::
     }
 
     s.add_bounds(synstructure::AddBounds::Generics);
-    let body_visit = s.each(|bind| {
-        quote! {
-            ::rustc_middle::ty::fold::TypeFoldable::visit_with(#bind, __folder)?;
-        }
-    });
     s.bind_with(|_| synstructure::BindStyle::Move);
     let body_fold = s.each_variant(|vi| {
         let bindings = vi.bindings();
@@ -36,14 +31,6 @@ pub fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::
             ) -> Result<Self, __F::Error> {
                 Ok(match self { #body_fold })
             }
-
-            fn visit_with<__F: ::rustc_middle::ty::fold::TypeVisitor<'tcx>>(
-                &self,
-                __folder: &mut __F
-            ) -> ::std::ops::ControlFlow<__F::BreakTy> {
-                match *self { #body_visit }
-                ::std::ops::ControlFlow::CONTINUE
-            }
         },
     )
 }
diff --git a/compiler/rustc_macros/src/type_visitable.rs b/compiler/rustc_macros/src/type_visitable.rs
new file mode 100644
index 00000000000..14e6aa6e0c1
--- /dev/null
+++ b/compiler/rustc_macros/src/type_visitable.rs
@@ -0,0 +1,33 @@
+use quote::quote;
+use syn::parse_quote;
+
+pub fn type_visitable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
+    if let syn::Data::Union(_) = s.ast().data {
+        panic!("cannot derive on union")
+    }
+
+    if !s.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") {
+        s.add_impl_generic(parse_quote! { 'tcx });
+    }
+
+    s.add_bounds(synstructure::AddBounds::Generics);
+    let body_visit = s.each(|bind| {
+        quote! {
+            ::rustc_middle::ty::visit::TypeVisitable::visit_with(#bind, __visitor)?;
+        }
+    });
+    s.bind_with(|_| synstructure::BindStyle::Move);
+
+    s.bound_impl(
+        quote!(::rustc_middle::ty::visit::TypeVisitable<'tcx>),
+        quote! {
+            fn visit_with<__V: ::rustc_middle::ty::visit::TypeVisitor<'tcx>>(
+                &self,
+                __visitor: &mut __V
+            ) -> ::std::ops::ControlFlow<__V::BreakTy> {
+                match *self { #body_visit }
+                ::std::ops::ControlFlow::CONTINUE
+            }
+        },
+    )
+}