diff options
| author | bors <bors@rust-lang.org> | 2015-03-26 13:38:41 +0000 |
|---|---|---|
| committer | bors <bors@rust-lang.org> | 2015-03-26 13:38:41 +0000 |
| commit | 557d4346a26266d2eb13f6b0adf106b8873b0da1 (patch) | |
| tree | d090541009a2400121f5bac3bca1002598eade28 /src/libsyntax/ext | |
| parent | 1501f33e76f6f9621aa08fb0cbbc5f85a5ac7f0f (diff) | |
| parent | 9cabe273d3adb06a19f63460deda96ae224b28bf (diff) | |
| download | rust-557d4346a26266d2eb13f6b0adf106b8873b0da1.tar.gz rust-557d4346a26266d2eb13f6b0adf106b8873b0da1.zip | |
Auto merge of #21237 - erickt:derive-assoc-types, r=erickt
This PR adds support for associated types to the `#[derive(...)]` syntax extension. In order to do this, it switches over to using where predicates to apply the type constraints. So now this:
```rust
type Trait {
type Type;
}
#[derive(Clone)]
struct Foo<A> where A: Trait {
a: A,
b: <A as Trait>::Type,
}
```
Gets expended into this impl:
```rust
impl<A: Clone> Clone for Foo<A> where
A: Trait,
<A as Trait>::Type: Clone,
{
fn clone(&self) -> Foo<T> {
Foo {
a: self.a.clone(),
b: self.b.clone(),
}
}
}
```
Diffstat (limited to 'src/libsyntax/ext')
| -rw-r--r-- | src/libsyntax/ext/deriving/generic/mod.rs | 124 |
1 files changed, 118 insertions, 6 deletions
diff --git a/src/libsyntax/ext/deriving/generic/mod.rs b/src/libsyntax/ext/deriving/generic/mod.rs index 58b6d96607d..0c5e4d67642 100644 --- a/src/libsyntax/ext/deriving/generic/mod.rs +++ b/src/libsyntax/ext/deriving/generic/mod.rs @@ -332,6 +332,46 @@ pub fn combine_substructure<'a>(f: CombineSubstructureFunc<'a>) RefCell::new(f) } +/// This method helps to extract all the type parameters referenced from a +/// type. For a type parameter `<T>`, it looks for either a `TyPath` that +/// is not global and starts with `T`, or a `TyQPath`. +fn find_type_parameters(ty: &ast::Ty, ty_param_names: &[ast::Name]) -> Vec<P<ast::Ty>> { + use visit; + + struct Visitor<'a> { + ty_param_names: &'a [ast::Name], + types: Vec<P<ast::Ty>>, + } + + impl<'a> visit::Visitor<'a> for Visitor<'a> { + fn visit_ty(&mut self, ty: &'a ast::Ty) { + match ty.node { + ast::TyPath(_, ref path) if !path.global => { + match path.segments.first() { + Some(segment) => { + if self.ty_param_names.contains(&segment.identifier.name) { + self.types.push(P(ty.clone())); + } + } + None => {} + } + } + _ => {} + } + + visit::walk_ty(self, ty) + } + } + + let mut visitor = Visitor { + ty_param_names: ty_param_names, + types: Vec::new(), + }; + + visit::Visitor::visit_ty(&mut visitor, ty); + + visitor.types +} impl<'a> TraitDef<'a> { pub fn expand<F>(&self, @@ -374,18 +414,42 @@ impl<'a> TraitDef<'a> { })) } - /// Given that we are deriving a trait `Tr` for a type `T<'a, ..., - /// 'z, A, ..., Z>`, creates an impl like: + /// Given that we are deriving a trait `DerivedTrait` for a type like: /// /// ```ignore - /// impl<'a, ..., 'z, A:Tr B1 B2, ..., Z: Tr B1 B2> Tr for T<A, ..., Z> { ... } + /// struct Struct<'a, ..., 'z, A, B: DeclaredTrait, C, ..., Z> where C: WhereTrait { + /// a: A, + /// b: B::Item, + /// b1: <B as DeclaredTrait>::Item, + /// c1: <C as WhereTrait>::Item, + /// c2: Option<<C as WhereTrait>::Item>, + /// ... + /// } + /// ``` + /// + /// create an impl like: + /// + /// ```ignore + /// impl<'a, ..., 'z, A, B: DeclaredTrait, C, ... Z> where + /// C: WhereTrait, + /// A: DerivedTrait + B1 + ... + BN, + /// B: DerivedTrait + B1 + ... + BN, + /// C: DerivedTrait + B1 + ... + BN, + /// B::Item: DerivedTrait + B1 + ... + BN, + /// <C as WhereTrait>::Item: DerivedTrait + B1 + ... + BN, + /// ... + /// { + /// ... + /// } /// ``` /// - /// where B1, B2, ... are the bounds given by `bounds_paths`.' + /// where B1, ..., BN are the bounds given by `bounds_paths`.'. Z is a phantom type, and + /// therefore does not get bound by the derived trait. fn create_derived_impl(&self, cx: &mut ExtCtxt, type_ident: Ident, generics: &Generics, + field_tys: Vec<P<ast::Ty>>, methods: Vec<P<ast::ImplItem>>) -> P<ast::Item> { let trait_path = self.path.to_path(cx, self.span, type_ident, generics); @@ -466,6 +530,35 @@ impl<'a> TraitDef<'a> { } })); + if !ty_params.is_empty() { + let ty_param_names: Vec<ast::Name> = ty_params.iter() + .map(|ty_param| ty_param.ident.name) + .collect(); + + for field_ty in field_tys.into_iter() { + let tys = find_type_parameters(&*field_ty, &ty_param_names); + + for ty in tys.into_iter() { + let mut bounds: Vec<_> = self.additional_bounds.iter().map(|p| { + cx.typarambound(p.to_path(cx, self.span, type_ident, generics)) + }).collect(); + + // require the current trait + bounds.push(cx.typarambound(trait_path.clone())); + + let predicate = ast::WhereBoundPredicate { + span: self.span, + bound_lifetimes: vec![], + bounded_ty: ty, + bounds: OwnedSlice::from_vec(bounds), + }; + + let predicate = ast::WherePredicate::BoundPredicate(predicate); + where_clause.predicates.push(predicate); + } + } + } + let trait_generics = Generics { lifetimes: lifetimes, ty_params: OwnedSlice::from_vec(ty_params), @@ -518,6 +611,10 @@ impl<'a> TraitDef<'a> { struct_def: &StructDef, type_ident: Ident, generics: &Generics) -> P<ast::Item> { + let field_tys: Vec<P<ast::Ty>> = struct_def.fields.iter() + .map(|field| field.node.ty.clone()) + .collect(); + let methods = self.methods.iter().map(|method_def| { let (explicit_self, self_args, nonself_args, tys) = method_def.split_self_nonself_args( @@ -550,7 +647,7 @@ impl<'a> TraitDef<'a> { body) }).collect(); - self.create_derived_impl(cx, type_ident, generics, methods) + self.create_derived_impl(cx, type_ident, generics, field_tys, methods) } fn expand_enum_def(&self, @@ -558,6 +655,21 @@ impl<'a> TraitDef<'a> { enum_def: &EnumDef, type_ident: Ident, generics: &Generics) -> P<ast::Item> { + let mut field_tys = Vec::new(); + + for variant in enum_def.variants.iter() { + match variant.node.kind { + ast::VariantKind::TupleVariantKind(ref args) => { + field_tys.extend(args.iter() + .map(|arg| arg.ty.clone())); + } + ast::VariantKind::StructVariantKind(ref args) => { + field_tys.extend(args.fields.iter() + .map(|field| field.node.ty.clone())); + } + } + } + let methods = self.methods.iter().map(|method_def| { let (explicit_self, self_args, nonself_args, tys) = method_def.split_self_nonself_args(cx, self, @@ -590,7 +702,7 @@ impl<'a> TraitDef<'a> { body) }).collect(); - self.create_derived_impl(cx, type_ident, generics, methods) + self.create_derived_impl(cx, type_ident, generics, field_tys, methods) } } |
