about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
authorJacob Pratt <jacob@jhpratt.dev>2021-06-29 20:22:52 -0400
committerJacob Pratt <jacob@jhpratt.dev>2021-07-27 15:47:47 -0400
commitc70147fd66e08962ab06adf12eb6a41bc1ea7f08 (patch)
tree3efabfe7f35fb460fdee7dbf4ec35c8d4b5305d3 /compiler
parentfd853c00e255559255885aadff9e93a1760c8728 (diff)
downloadrust-c70147fd66e08962ab06adf12eb6a41bc1ea7f08.tar.gz
rust-c70147fd66e08962ab06adf12eb6a41bc1ea7f08.zip
Permit deriving default on enums with `#[default]`
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_builtin_macros/src/deriving/default.rs197
-rw-r--r--compiler/rustc_feature/src/active.rs3
-rw-r--r--compiler/rustc_span/src/symbol.rs1
3 files changed, 170 insertions, 31 deletions
diff --git a/compiler/rustc_builtin_macros/src/deriving/default.rs b/compiler/rustc_builtin_macros/src/deriving/default.rs
index 980be3a0050..7c1dc5e5365 100644
--- a/compiler/rustc_builtin_macros/src/deriving/default.rs
+++ b/compiler/rustc_builtin_macros/src/deriving/default.rs
@@ -2,11 +2,15 @@ use crate::deriving::generic::ty::*;
 use crate::deriving::generic::*;
 
 use rustc_ast::ptr::P;
+use rustc_ast::EnumDef;
+use rustc_ast::VariantData;
 use rustc_ast::{Expr, MetaItem};
-use rustc_errors::struct_span_err;
+use rustc_errors::Applicability;
 use rustc_expand::base::{Annotatable, DummyResult, ExtCtxt};
+use rustc_span::symbol::Ident;
 use rustc_span::symbol::{kw, sym};
 use rustc_span::Span;
+use smallvec::SmallVec;
 
 pub fn expand_deriving_default(
     cx: &mut ExtCtxt<'_>,
@@ -34,8 +38,25 @@ pub fn expand_deriving_default(
             attributes: attrs,
             is_unsafe: false,
             unify_fieldless_variants: false,
-            combine_substructure: combine_substructure(Box::new(|a, b, c| {
-                default_substructure(a, b, c)
+            combine_substructure: combine_substructure(Box::new(|cx, trait_span, substr| {
+                match substr.fields {
+                    StaticStruct(_, fields) => {
+                        default_struct_substructure(cx, trait_span, substr, fields)
+                    }
+                    StaticEnum(enum_def, _) => {
+                        if !cx.sess.features_untracked().derive_default_enum {
+                            rustc_session::parse::feature_err(
+                                cx.parse_sess(),
+                                sym::derive_default_enum,
+                                span,
+                                "deriving `Default` on enums is experimental",
+                            )
+                            .emit();
+                        }
+                        default_enum_substructure(cx, trait_span, enum_def)
+                    }
+                    _ => cx.span_bug(trait_span, "method in `derive(Default)`"),
+                }
             })),
         }],
         associated_types: Vec::new(),
@@ -43,44 +64,158 @@ pub fn expand_deriving_default(
     trait_def.expand(cx, mitem, item, push)
 }
 
-fn default_substructure(
+fn default_struct_substructure(
     cx: &mut ExtCtxt<'_>,
     trait_span: Span,
     substr: &Substructure<'_>,
+    summary: &StaticFields,
 ) -> P<Expr> {
     // Note that `kw::Default` is "default" and `sym::Default` is "Default"!
     let default_ident = cx.std_path(&[kw::Default, sym::Default, kw::Default]);
     let default_call = |span| cx.expr_call_global(span, default_ident.clone(), Vec::new());
 
-    match *substr.fields {
-        StaticStruct(_, ref summary) => match *summary {
-            Unnamed(ref fields, is_tuple) => {
-                if !is_tuple {
-                    cx.expr_ident(trait_span, substr.type_ident)
-                } else {
-                    let exprs = fields.iter().map(|sp| default_call(*sp)).collect();
-                    cx.expr_call_ident(trait_span, substr.type_ident, exprs)
-                }
-            }
-            Named(ref fields) => {
-                let default_fields = fields
-                    .iter()
-                    .map(|&(ident, span)| cx.field_imm(span, ident, default_call(span)))
-                    .collect();
-                cx.expr_struct_ident(trait_span, substr.type_ident, default_fields)
+    match summary {
+        Unnamed(ref fields, is_tuple) => {
+            if !is_tuple {
+                cx.expr_ident(trait_span, substr.type_ident)
+            } else {
+                let exprs = fields.iter().map(|sp| default_call(*sp)).collect();
+                cx.expr_call_ident(trait_span, substr.type_ident, exprs)
             }
-        },
-        StaticEnum(..) => {
-            struct_span_err!(
-                &cx.sess.parse_sess.span_diagnostic,
-                trait_span,
-                E0665,
-                "`Default` cannot be derived for enums, only structs"
-            )
+        }
+        Named(ref fields) => {
+            let default_fields = fields
+                .iter()
+                .map(|&(ident, span)| cx.field_imm(span, ident, default_call(span)))
+                .collect();
+            cx.expr_struct_ident(trait_span, substr.type_ident, default_fields)
+        }
+    }
+}
+
+fn default_enum_substructure(
+    cx: &mut ExtCtxt<'_>,
+    trait_span: Span,
+    enum_def: &EnumDef,
+) -> P<Expr> {
+    let default_variant = match extract_default_variant(cx, enum_def, trait_span) {
+        Ok(value) => value,
+        Err(()) => return DummyResult::raw_expr(trait_span, true),
+    };
+
+    // At this point, we know that there is exactly one variant with a `#[default]` attribute. The
+    // attribute hasn't yet been validated.
+
+    if let Err(()) = validate_default_attribute(cx, default_variant) {
+        return DummyResult::raw_expr(trait_span, true);
+    }
+
+    // We now know there is exactly one unit variant with exactly one `#[default]` attribute.
+
+    cx.expr_path(cx.path(
+        default_variant.span,
+        vec![Ident::new(kw::SelfUpper, default_variant.span), default_variant.ident],
+    ))
+}
+
+fn extract_default_variant<'a>(
+    cx: &mut ExtCtxt<'_>,
+    enum_def: &'a EnumDef,
+    trait_span: Span,
+) -> Result<&'a rustc_ast::Variant, ()> {
+    let default_variants: SmallVec<[_; 1]> = enum_def
+        .variants
+        .iter()
+        .filter(|variant| cx.sess.contains_name(&variant.attrs, kw::Default))
+        .collect();
+
+    let variant = match default_variants.as_slice() {
+        [variant] => variant,
+        [] => {
+            cx.struct_span_err(trait_span, "no default declared")
+                .help("make a unit variant default by placing `#[default]` above it")
+                .emit();
+
+            return Err(());
+        }
+        [first, rest @ ..] => {
+            cx.struct_span_err(trait_span, "multiple declared defaults")
+                .span_label(first.span, "first default")
+                .span_labels(rest.iter().map(|variant| variant.span), "additional default")
+                .note("only one variant can be default")
+                .emit();
+
+            return Err(());
+        }
+    };
+
+    if !matches!(variant.data, VariantData::Unit(..)) {
+        cx.struct_span_err(variant.ident.span, "`#[default]` may only be used on unit variants")
+            .help("consider a manual implementation of `Default`")
+            .emit();
+
+        return Err(());
+    }
+
+    if let Some(non_exhaustive_attr) = cx.sess.find_by_name(&variant.attrs, sym::non_exhaustive) {
+        cx.struct_span_err(variant.ident.span, "default variant must be exhaustive")
+            .span_label(non_exhaustive_attr.span, "declared `#[non_exhaustive]` here")
+            .help("consider a manual implementation of `Default`")
             .emit();
-            // let compilation continue
-            DummyResult::raw_expr(trait_span, true)
+
+        return Err(());
+    }
+
+    Ok(variant)
+}
+
+fn validate_default_attribute(
+    cx: &mut ExtCtxt<'_>,
+    default_variant: &rustc_ast::Variant,
+) -> Result<(), ()> {
+    let attrs: SmallVec<[_; 1]> =
+        cx.sess.filter_by_name(&default_variant.attrs, kw::Default).collect();
+
+    let attr = match attrs.as_slice() {
+        [attr] => attr,
+        [] => cx.bug(
+            "this method must only be called with a variant that has a `#[default]` attribute",
+        ),
+        [first, rest @ ..] => {
+            // FIXME(jhpratt) Do we want to perform this check? It doesn't exist
+            // for `#[inline]`, `#[non_exhaustive]`, and presumably others.
+
+            let suggestion_text =
+                if rest.len() == 1 { "try removing this" } else { "try removing these" };
+
+            cx.struct_span_err(default_variant.ident.span, "multiple `#[default]` attributes")
+                .note("only one `#[default]` attribute is needed")
+                .span_label(first.span, "`#[default]` used here")
+                .span_label(rest[0].span, "`#[default]` used again here")
+                .span_help(rest.iter().map(|attr| attr.span).collect::<Vec<_>>(), suggestion_text)
+                // This would otherwise display the empty replacement, hence the otherwise
+                // repetitive `.span_help` call above.
+                .tool_only_multipart_suggestion(
+                    suggestion_text,
+                    rest.iter().map(|attr| (attr.span, String::new())).collect(),
+                    Applicability::MachineApplicable,
+                )
+                .emit();
+
+            return Err(());
         }
-        _ => cx.span_bug(trait_span, "method in `derive(Default)`"),
+    };
+    if !attr.is_word() {
+        cx.struct_span_err(attr.span, "`#[default]` attribute does not accept a value")
+            .span_suggestion_hidden(
+                attr.span,
+                "try using `#[default]`",
+                "#[default]".into(),
+                Applicability::MaybeIncorrect,
+            )
+            .emit();
+
+        return Err(());
     }
+    Ok(())
 }
diff --git a/compiler/rustc_feature/src/active.rs b/compiler/rustc_feature/src/active.rs
index 803e4a2e59d..41351c1f938 100644
--- a/compiler/rustc_feature/src/active.rs
+++ b/compiler/rustc_feature/src/active.rs
@@ -683,6 +683,9 @@ declare_features! (
     /// Infer generic args for both consts and types.
     (active, generic_arg_infer, "1.55.0", Some(85077), None),
 
+    /// Allows `#[derive(Default)]` and `#[default]` on enums.
+    (active, derive_default_enum, "1.56.0", Some(86985), None),
+
     // -------------------------------------------------------------------------
     // feature-group-end: actual feature gates
     // -------------------------------------------------------------------------
diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs
index 536ebdef426..5fc773e431c 100644
--- a/compiler/rustc_span/src/symbol.rs
+++ b/compiler/rustc_span/src/symbol.rs
@@ -489,6 +489,7 @@ symbols! {
         deref_mut,
         deref_target,
         derive,
+        derive_default_enum,
         destructuring_assignment,
         diagnostic,
         direct,