about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/librustc/mir/mod.rs9
-rw-r--r--src/librustc/mir/tcx.rs12
-rw-r--r--src/librustc/ty/layout.rs81
-rw-r--r--src/librustc/ty/sty.rs28
-rw-r--r--src/librustc_codegen_llvm/debuginfo/metadata.rs6
-rw-r--r--src/librustc_codegen_llvm/type_of.rs5
-rw-r--r--src/librustc_codegen_ssa/mir/mod.rs37
-rw-r--r--src/librustc_codegen_ssa/mir/place.rs12
-rw-r--r--src/librustc_mir/borrow_check/nll/type_check/mod.rs77
-rw-r--r--src/librustc_mir/build/expr/as_rvalue.rs19
-rw-r--r--src/librustc_mir/transform/deaggregator.rs20
-rw-r--r--src/librustc_mir/transform/generator.rs127
-rw-r--r--src/test/mir-opt/generator-drop-cleanup.rs3
13 files changed, 281 insertions, 155 deletions
diff --git a/src/librustc/mir/mod.rs b/src/librustc/mir/mod.rs
index 5372f868aa9..79e143b5c24 100644
--- a/src/librustc/mir/mod.rs
+++ b/src/librustc/mir/mod.rs
@@ -2028,6 +2028,10 @@ impl<'tcx> Place<'tcx> {
             variant_index))
     }
 
+    pub fn downcast_unnamed(self, variant_index: VariantIdx) -> Place<'tcx> {
+        self.elem(ProjectionElem::Downcast(None, variant_index))
+    }
+
     pub fn index(self, index: Local) -> Place<'tcx> {
         self.elem(ProjectionElem::Index(index))
     }
@@ -2553,11 +2557,6 @@ impl<'tcx> Debug for Rvalue<'tcx> {
                                     let var_name = tcx.hir().name_by_hir_id(freevar.var_id());
                                     struct_fmt.field(&var_name.as_str(), place);
                                 }
-                                struct_fmt.field("$state", &places[freevars.len()]);
-                                for i in (freevars.len() + 1)..places.len() {
-                                    struct_fmt
-                                        .field(&format!("${}", i - freevars.len() - 1), &places[i]);
-                                }
                             });
 
                             struct_fmt.finish()
diff --git a/src/librustc/mir/tcx.rs b/src/librustc/mir/tcx.rs
index 04b763f773d..199ee3e04b3 100644
--- a/src/librustc/mir/tcx.rs
+++ b/src/librustc/mir/tcx.rs
@@ -177,11 +177,13 @@ impl<'tcx> Rvalue<'tcx> {
             }
             Rvalue::Discriminant(ref place) => {
                 let ty = place.ty(local_decls, tcx).ty;
-                if let ty::Adt(adt_def, _) = ty.sty {
-                    adt_def.repr.discr_type().to_ty(tcx)
-                } else {
-                    // This can only be `0`, for now, so `u8` will suffice.
-                    tcx.types.u8
+                match ty.sty {
+                    ty::Adt(adt_def, _) => adt_def.repr.discr_type().to_ty(tcx),
+                    ty::Generator(_, substs, _) => substs.discr_ty(tcx),
+                    _ => {
+                        // This can only be `0`, for now, so `u8` will suffice.
+                        tcx.types.u8
+                    }
                 }
             }
             Rvalue::NullaryOp(NullOp::Box, t) => tcx.mk_box(t),
diff --git a/src/librustc/ty/layout.rs b/src/librustc/ty/layout.rs
index fd1d3a91ede..10afaebc91a 100644
--- a/src/librustc/ty/layout.rs
+++ b/src/librustc/ty/layout.rs
@@ -604,12 +604,57 @@ impl<'a, 'tcx> LayoutCx<'tcx, TyCtxt<'a, 'tcx, 'tcx>> {
                 tcx.intern_layout(unit)
             }
 
-            // Tuples, generators and closures.
             ty::Generator(def_id, ref substs, _) => {
-                let tys = substs.field_tys(def_id, tcx);
-                univariant(&tys.map(|ty| self.layout_of(ty)).collect::<Result<Vec<_>, _>>()?,
+                let discr_index = substs.prefix_tys(def_id, tcx).count();
+                let prefix_tys = substs.prefix_tys(def_id, tcx)
+                    .chain(iter::once(substs.discr_ty(tcx)));
+                let prefix = univariant_uninterned(
+                    &prefix_tys.map(|ty| self.layout_of(ty)).collect::<Result<Vec<_>, _>>()?,
                     &ReprOptions::default(),
-                    StructKind::AlwaysSized)?
+                    StructKind::AlwaysSized)?;
+
+                let mut size = prefix.size;
+                let mut align = prefix.align;
+                let variants_tys = substs.state_tys(def_id, tcx);
+                let variants = variants_tys.enumerate().map(|(i, variant_tys)| {
+                    let mut variant = univariant_uninterned(
+                        &variant_tys.map(|ty| self.layout_of(ty)).collect::<Result<Vec<_>, _>>()?,
+                        &ReprOptions::default(),
+                        StructKind::Prefixed(prefix.size, prefix.align.abi))?;
+
+                    variant.variants = Variants::Single { index: VariantIdx::new(i) };
+
+                    size = size.max(variant.size);
+                    align = align.max(variant.align);
+
+                    Ok(variant)
+                }).collect::<Result<IndexVec<VariantIdx, _>, _>>()?;
+
+                let abi = if prefix.abi.is_uninhabited() ||
+                             variants.iter().all(|v| v.abi.is_uninhabited()) {
+                    Abi::Uninhabited
+                } else {
+                    Abi::Aggregate { sized: true }
+                };
+                let discr = match &self.layout_of(substs.discr_ty(tcx))?.abi {
+                    Abi::Scalar(s) => s.clone(),
+                    _ => bug!(),
+                };
+
+                let layout = tcx.intern_layout(LayoutDetails {
+                    variants: Variants::Multiple {
+                        discr,
+                        discr_kind: DiscriminantKind::Tag,
+                        discr_index,
+                        variants,
+                    },
+                    fields: prefix.fields,
+                    abi,
+                    size,
+                    align,
+                });
+                debug!("generator layout: {:#?}", layout);
+                layout
             }
 
             ty::Closure(def_id, ref substs) => {
@@ -1646,6 +1691,14 @@ impl<'a, 'tcx, C> TyLayoutMethods<'tcx, C> for Ty<'tcx>
 
     fn field(this: TyLayout<'tcx>, cx: &C, i: usize) -> C::TyLayout {
         let tcx = cx.tcx();
+        let handle_discriminant = |discr: &Scalar| -> C::TyLayout {
+            let layout = LayoutDetails::scalar(cx, discr.clone());
+            MaybeResult::from_ok(TyLayout {
+                details: tcx.intern_layout(layout),
+                ty: discr.value.to_ty(tcx)
+            })
+        };
+
         cx.layout_of(match this.ty.sty {
             ty::Bool |
             ty::Char |
@@ -1720,7 +1773,19 @@ impl<'a, 'tcx, C> TyLayoutMethods<'tcx, C> for Ty<'tcx>
             }
 
             ty::Generator(def_id, ref substs, _) => {
-                substs.field_tys(def_id, tcx).nth(i).unwrap()
+                match this.variants {
+                    Variants::Single { index } => {
+                        substs.state_tys(def_id, tcx)
+                            .nth(index.as_usize()).unwrap()
+                            .nth(i).unwrap()
+                    }
+                    Variants::Multiple { ref discr, discr_index, .. } => {
+                        if i == discr_index {
+                            return handle_discriminant(discr);
+                        }
+                        substs.prefix_tys(def_id, tcx).nth(i).unwrap()
+                    }
+                }
             }
 
             ty::Tuple(tys) => tys[i],
@@ -1740,11 +1805,7 @@ impl<'a, 'tcx, C> TyLayoutMethods<'tcx, C> for Ty<'tcx>
                     // Discriminant field for enums (where applicable).
                     Variants::Multiple { ref discr, .. } => {
                         assert_eq!(i, 0);
-                        let layout = LayoutDetails::scalar(cx, discr.clone());
-                        return MaybeResult::from_ok(TyLayout {
-                            details: tcx.intern_layout(layout),
-                            ty: discr.value.to_ty(tcx)
-                        });
+                        return handle_discriminant(discr);
                     }
                 }
             }
diff --git a/src/librustc/ty/sty.rs b/src/librustc/ty/sty.rs
index edd6014618e..842e49dfc08 100644
--- a/src/librustc/ty/sty.rs
+++ b/src/librustc/ty/sty.rs
@@ -15,7 +15,6 @@ use crate::util::captures::Captures;
 use crate::mir::interpret::{Scalar, Pointer};
 
 use smallvec::SmallVec;
-use std::iter;
 use std::cmp::Ordering;
 use std::marker::PhantomData;
 use rustc_target::spec::abi;
@@ -475,30 +474,23 @@ impl<'a, 'gcx, 'tcx> GeneratorSubsts<'tcx> {
     /// This returns the types of the MIR locals which had to be stored across suspension points.
     /// It is calculated in rustc_mir::transform::generator::StateTransform.
     /// All the types here must be in the tuple in GeneratorInterior.
+    ///
+    /// The locals are grouped by their variant number. Note that some locals may
+    /// be repeated in multiple variants.
     pub fn state_tys(self, def_id: DefId, tcx: TyCtxt<'a, 'gcx, 'tcx>) ->
-        impl Iterator<Item=Ty<'tcx>> + Captures<'gcx> + 'a
+        impl Iterator<Item=impl Iterator<Item=Ty<'tcx>> + Captures<'gcx> + 'a>
     {
-        // TODO remove so we can handle variants properly
         tcx.generator_layout(def_id)
-            .variant_fields[0].iter()
-            .map(move |d| d.ty.subst(tcx, self.substs))
+            .variant_fields.iter()
+            .map(move |v| v.iter().map(move |d| d.ty.subst(tcx, self.substs)))
     }
 
-    /// This is the types of the fields of a generator which
-    /// is available before the generator transformation.
-    /// It includes the upvars and the state discriminant.
-    pub fn pre_transforms_tys(self, def_id: DefId, tcx: TyCtxt<'a, 'gcx, 'tcx>) ->
+    /// This is the types of the fields of a generator which are not stored in a
+    /// variant.
+    pub fn prefix_tys(self, def_id: DefId, tcx: TyCtxt<'a, 'gcx, 'tcx>) ->
         impl Iterator<Item=Ty<'tcx>> + 'a
     {
-        self.upvar_tys(def_id, tcx).chain(iter::once(self.discr_ty(tcx)))
-    }
-
-    /// This is the types of all the fields stored in a generator.
-    /// It includes the upvars, state types and the state discriminant.
-    pub fn field_tys(self, def_id: DefId, tcx: TyCtxt<'a, 'gcx, 'tcx>) ->
-        impl Iterator<Item=Ty<'tcx>> + Captures<'gcx> + 'a
-    {
-        self.pre_transforms_tys(def_id, tcx).chain(self.state_tys(def_id, tcx))
+        self.upvar_tys(def_id, tcx)
     }
 }
 
diff --git a/src/librustc_codegen_llvm/debuginfo/metadata.rs b/src/librustc_codegen_llvm/debuginfo/metadata.rs
index 31348b99c5a..bbcd3c220d6 100644
--- a/src/librustc_codegen_llvm/debuginfo/metadata.rs
+++ b/src/librustc_codegen_llvm/debuginfo/metadata.rs
@@ -691,9 +691,12 @@ pub fn type_metadata(
                                    usage_site_span).finalize(cx)
         }
         ty::Generator(def_id, substs,  _) => {
-            let upvar_tys : Vec<_> = substs.field_tys(def_id, cx.tcx).map(|t| {
+            // TODO handle variant fields
+            let upvar_tys : Vec<_> = substs.prefix_tys(def_id, cx.tcx).map(|t| {
                 cx.tcx.normalize_erasing_regions(ParamEnv::reveal_all(), t)
             }).collect();
+            // TODO use prepare_enum_metadata and update it to handle multiple
+            // fields in the outer layout.
             prepare_tuple_metadata(cx,
                                    t,
                                    &upvar_tys,
@@ -1818,6 +1821,7 @@ fn prepare_enum_metadata(
     };
 
     // The variant part must be wrapped in a struct according to DWARF.
+    // TODO create remaining fields here, if any.
     let type_array = create_DIArray(DIB(cx), &[Some(variant_part)]);
     let struct_wrapper = unsafe {
         llvm::LLVMRustDIBuilderCreateStructType(
diff --git a/src/librustc_codegen_llvm/type_of.rs b/src/librustc_codegen_llvm/type_of.rs
index d42fa829161..080f78ff112 100644
--- a/src/librustc_codegen_llvm/type_of.rs
+++ b/src/librustc_codegen_llvm/type_of.rs
@@ -63,6 +63,11 @@ fn uncached_llvm_type<'a, 'tcx>(cx: &CodegenCx<'a, 'tcx>,
                     write!(&mut name, "::{}", def.variants[index].ident).unwrap();
                 }
             }
+            if let (&ty::Generator(..), &layout::Variants::Single { index })
+                 = (&layout.ty.sty, &layout.variants)
+            {
+                write!(&mut name, "::variant#{:?}", index).unwrap();
+            }
             Some(name)
         }
         _ => None
diff --git a/src/librustc_codegen_ssa/mir/mod.rs b/src/librustc_codegen_ssa/mir/mod.rs
index 4387d77a925..52429294852 100644
--- a/src/librustc_codegen_ssa/mir/mod.rs
+++ b/src/librustc_codegen_ssa/mir/mod.rs
@@ -4,6 +4,7 @@ use rustc::mir::{self, Mir};
 use rustc::session::config::DebugInfo;
 use rustc_mir::monomorphize::Instance;
 use rustc_target::abi::call::{FnType, PassMode, IgnoreMode};
+use rustc_target::abi::{Variants, VariantIdx};
 use crate::base;
 use crate::debuginfo::{self, VariableAccess, VariableKind, FunctionDebugContext};
 use crate::traits::*;
@@ -648,7 +649,7 @@ fn arg_local_refs<'a, 'tcx: 'a, Bx: BuilderMethods<'a, 'tcx>>(
                     .iter()
                     .zip(upvar_tys)
                     .enumerate()
-                    .map(|(i, (upvar, ty))| (i, upvar.debug_name, upvar.by_ref, ty));
+                    .map(|(i, (upvar, ty))| (None, i, upvar.debug_name, upvar.by_ref, ty));
 
                 let generator_fields = mir.generator_layout.as_ref().map(|generator_layout| {
                     let (def_id, gen_substs) = match closure_layout.ty.sty {
@@ -658,23 +659,39 @@ fn arg_local_refs<'a, 'tcx: 'a, Bx: BuilderMethods<'a, 'tcx>>(
                     // TODO handle variant scopes here
                     let state_tys = gen_substs.state_tys(def_id, tcx);
 
-                    // TODO remove assumption of only one variant
-                    let upvar_count = mir.upvar_decls.len();
-                    generator_layout.variant_fields[0]
-                        .iter()
+                    generator_layout.variant_fields.iter()
                         .zip(state_tys)
                         .enumerate()
-                        .filter_map(move |(i, (decl, ty))| {
-                            let ty = fx.monomorphize(&ty);
-                            decl.name.map(|name| (i + upvar_count + 1, name, false, ty))
+                        .flat_map(move |(variant_idx, (decls, tys))| {
+                            let variant_idx = Some(VariantIdx::from(variant_idx));
+                            decls.iter()
+                                .zip(tys)
+                                .enumerate()
+                                .filter_map(move |(i, (decl, ty))| {
+                                    let ty = fx.monomorphize(&ty);
+                                    decl.name.map(|name| {
+                                        (variant_idx, i, name, false, ty)
+                                })
+                            })
                         })
                 }).into_iter().flatten();
 
                 upvars.chain(generator_fields)
             };
 
-            for (field, name, by_ref, ty) in extra_locals {
-                let byte_offset_of_var_in_env = closure_layout.fields.offset(field).bytes();
+            for (variant_idx, field, name, by_ref, ty) in extra_locals {
+                let fields = match variant_idx {
+                    Some(variant_idx) => {
+                        match &closure_layout.variants {
+                            Variants::Multiple { variants, .. } => {
+                                &variants[variant_idx].fields
+                            },
+                            _ => bug!("variant index on univariant layout"),
+                        }
+                    }
+                    None => &closure_layout.fields,
+                };
+                let byte_offset_of_var_in_env = fields.offset(field).bytes();
 
                 let ops = bx.debuginfo_upvar_ops_sequence(byte_offset_of_var_in_env);
 
diff --git a/src/librustc_codegen_ssa/mir/place.rs b/src/librustc_codegen_ssa/mir/place.rs
index 1134707f96c..2875468127e 100644
--- a/src/librustc_codegen_ssa/mir/place.rs
+++ b/src/librustc_codegen_ssa/mir/place.rs
@@ -296,9 +296,15 @@ impl<'a, 'tcx: 'a, V: CodegenObject> PlaceRef<'tcx, V> {
                 ..
             } => {
                 let ptr = self.project_field(bx, discr_index);
-                let to = self.layout.ty.ty_adt_def().unwrap()
-                    .discriminant_for_variant(bx.tcx(), variant_index)
-                    .val;
+                let to = match self.layout.ty.sty {
+                    ty::TyKind::Adt(adt_def, _) => adt_def
+                        .discriminant_for_variant(bx.tcx(), variant_index)
+                        .val,
+                    // Generators don't support explicit discriminant values, so
+                    // they are the same as the variant index.
+                    ty::TyKind::Generator(..) => variant_index.as_u32() as u128,
+                    _ => bug!(),
+                };
                 bx.store(
                     bx.cx().const_uint_big(bx.cx().backend_type(ptr.layout), to),
                     ptr.llval,
diff --git a/src/librustc_mir/borrow_check/nll/type_check/mod.rs b/src/librustc_mir/borrow_check/nll/type_check/mod.rs
index 0dee64db727..94900b98a52 100644
--- a/src/librustc_mir/borrow_check/nll/type_check/mod.rs
+++ b/src/librustc_mir/borrow_check/nll/type_check/mod.rs
@@ -684,6 +684,25 @@ impl<'a, 'b, 'gcx, 'tcx> TypeVerifier<'a, 'b, 'gcx, 'tcx> {
                         }
                     }
                 }
+                ty::Generator(def_id, substs, _) => {
+                    let variants = substs.state_tys(def_id, tcx).count();
+                    if index.as_usize() >= variants {
+                        PlaceTy::from_ty(
+                            span_mirbug_and_err!(
+                                self,
+                                place,
+                                "cast to variant #{:?} but generator only has {:?}",
+                                index,
+                                variants
+                            ),
+                        )
+                    } else {
+                        PlaceTy {
+                            ty: base_ty,
+                            variant_index: Some(index),
+                        }
+                    }
+                }
                 _ => {
                     let ty = if let Some(name) = maybe_name {
                         span_mirbug_and_err!(
@@ -745,11 +764,26 @@ impl<'a, 'b, 'gcx, 'tcx> TypeVerifier<'a, 'b, 'gcx, 'tcx> {
         let tcx = self.tcx();
 
         let (variant, substs) = match base_ty {
-            PlaceTy { ty, variant_index: Some(variant_index) } => {
-                match ty.sty {
-                    ty::Adt(adt_def, substs) => (&adt_def.variants[variant_index], substs),
-                    _ => bug!("can't have downcast of non-adt type"),
+            PlaceTy { ty, variant_index: Some(variant_index) } => match ty.sty {
+                ty::Adt(adt_def, substs) => (&adt_def.variants[variant_index], substs),
+                ty::Generator(def_id, substs, _) => {
+                    let mut variants = substs.state_tys(def_id, tcx);
+                    let mut variant = match variants.nth(variant_index.into()) {
+                        Some(v) => v,
+                        None => {
+                            bug!("variant_index of generator out of range: {:?}/{:?}",
+                                 variant_index,
+                                 substs.state_tys(def_id, tcx).count())
+                        }
+                    };
+                    return match variant.nth(field.index()) {
+                        Some(ty) => Ok(ty),
+                        None => Err(FieldAccessError::OutOfRange {
+                            field_count: variant.count(),
+                        }),
+                    }
                 }
+                _ => bug!("can't have downcast of non-adt non-generator type"),
             }
             PlaceTy { ty, variant_index: None } => match ty.sty {
                 ty::Adt(adt_def, substs) if !adt_def.is_enum() =>
@@ -763,19 +797,14 @@ impl<'a, 'b, 'gcx, 'tcx> TypeVerifier<'a, 'b, 'gcx, 'tcx> {
                     }
                 }
                 ty::Generator(def_id, substs, _) => {
-                    // Try pre-transform fields first (upvars and current state)
-                    if let Some(ty) = substs.pre_transforms_tys(def_id, tcx).nth(field.index()) {
-                        return Ok(ty);
-                    }
-
-                    // Then try `field_tys` which contains all the fields, but it
-                    // requires the final optimized MIR.
-                    return match substs.field_tys(def_id, tcx).nth(field.index()) {
+                    // Only prefix fields (upvars and current state) are
+                    // accessible without a variant index.
+                    return match substs.prefix_tys(def_id, tcx).nth(field.index()) {
                         Some(ty) => Ok(ty),
                         None => Err(FieldAccessError::OutOfRange {
-                            field_count: substs.field_tys(def_id, tcx).count(),
+                            field_count: substs.prefix_tys(def_id, tcx).count(),
                         }),
-                    };
+                    }
                 }
                 ty::Tuple(tys) => {
                     return match tys.get(field.index()) {
@@ -1908,18 +1937,14 @@ impl<'a, 'gcx, 'tcx> TypeChecker<'a, 'gcx, 'tcx> {
                 }
             }
             AggregateKind::Generator(def_id, substs, _) => {
-                // Try pre-transform fields first (upvars and current state)
-                if let Some(ty) = substs.pre_transforms_tys(def_id, tcx).nth(field_index) {
-                    Ok(ty)
-                } else {
-                    // Then try `field_tys` which contains all the fields, but it
-                    // requires the final optimized MIR.
-                    match substs.field_tys(def_id, tcx).nth(field_index) {
-                        Some(ty) => Ok(ty),
-                        None => Err(FieldAccessError::OutOfRange {
-                            field_count: substs.field_tys(def_id, tcx).count(),
-                        }),
-                    }
+                // It doesn't make sense to look at a field beyond the prefix;
+                // these require a variant index, and are not initialized in
+                // aggregate rvalues.
+                match substs.prefix_tys(def_id, tcx).nth(field_index) {
+                    Some(ty) => Ok(ty),
+                    None => Err(FieldAccessError::OutOfRange {
+                        field_count: substs.prefix_tys(def_id, tcx).count(),
+                    }),
                 }
             }
             AggregateKind::Array(ty) => Ok(ty),
diff --git a/src/librustc_mir/build/expr/as_rvalue.rs b/src/librustc_mir/build/expr/as_rvalue.rs
index 0871217c524..201bc4a43e4 100644
--- a/src/librustc_mir/build/expr/as_rvalue.rs
+++ b/src/librustc_mir/build/expr/as_rvalue.rs
@@ -211,7 +211,7 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
                 movability,
             } => {
                 // see (*) above
-                let mut operands: Vec<_> = upvars
+                let operands: Vec<_> = upvars
                     .into_iter()
                     .map(|upvar| {
                         let upvar = this.hir.mirror(upvar);
@@ -252,22 +252,9 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
                     }).collect();
                 let result = match substs {
                     UpvarSubsts::Generator(substs) => {
+                        // We implicitly set the discriminant to 0. See
+                        // librustc_mir/transform/deaggregator.rs for details.
                         let movability = movability.unwrap();
-                        // Add the state operand since it follows the upvars in the generator
-                        // struct. See librustc_mir/transform/generator.rs for more details.
-                        let discr_ty = substs.discr_ty(this.hir.tcx());
-                        operands.push(Operand::Constant(box Constant {
-                            span: expr_span,
-                            ty: discr_ty,
-                            user_ty: None,
-                            literal: this.hir.tcx().mk_const(
-                                ty::Const::from_bits(
-                                    this.hir.tcx(),
-                                    0,
-                                    ty::ParamEnv::empty().and(discr_ty),
-                                ),
-                            ),
-                        }));
                         box AggregateKind::Generator(closure_id, substs, movability)
                     }
                     UpvarSubsts::Closure(substs) => box AggregateKind::Closure(closure_id, substs),
diff --git a/src/librustc_mir/transform/deaggregator.rs b/src/librustc_mir/transform/deaggregator.rs
index 9061dfff76f..9f8d40bf4cd 100644
--- a/src/librustc_mir/transform/deaggregator.rs
+++ b/src/librustc_mir/transform/deaggregator.rs
@@ -1,5 +1,6 @@
-use rustc::ty::TyCtxt;
 use rustc::mir::*;
+use rustc::ty::TyCtxt;
+use rustc::ty::layout::VariantIdx;
 use rustc_data_structures::indexed_vec::Idx;
 use crate::transform::{MirPass, MirSource};
 
@@ -55,6 +56,23 @@ impl MirPass for Deaggregator {
                         }
                         active_field_index
                     }
+                    AggregateKind::Generator(..) => {
+                        // Right now we only support initializing generators to
+                        // variant#0.
+                        let variant_index = VariantIdx::new(0);
+                        set_discriminant = Some(Statement {
+                            kind: StatementKind::SetDiscriminant {
+                                place: lhs.clone(),
+                                variant_index,
+                            },
+                            source_info,
+                        });
+
+                        // Operands are upvars stored on the base place, so no
+                        // downcast is necessary.
+
+                        None
+                    }
                     _ => None
                 };
 
diff --git a/src/librustc_mir/transform/generator.rs b/src/librustc_mir/transform/generator.rs
index 253038fd030..b7c4bfd5126 100644
--- a/src/librustc_mir/transform/generator.rs
+++ b/src/librustc_mir/transform/generator.rs
@@ -60,7 +60,7 @@ use rustc_data_structures::fx::FxHashMap;
 use rustc_data_structures::indexed_vec::Idx;
 use rustc_data_structures::bit_set::BitSet;
 use std::borrow::Cow;
-use std::iter::once;
+use std::iter;
 use std::mem;
 use crate::transform::{MirPass, MirSource};
 use crate::transform::simplify;
@@ -145,14 +145,14 @@ fn self_arg() -> Local {
 }
 
 /// Generator have not been resumed yet
-const UNRESUMED: u32 = 0;
+const UNRESUMED: usize = 0;
 /// Generator has returned / is completed
-const RETURNED: u32 = 1;
+const RETURNED: usize = 1;
 /// Generator has been poisoned
-const POISONED: u32 = 2;
+const POISONED: usize = 2;
 
 struct SuspensionPoint {
-    state: u32,
+    state: usize,
     resume: BasicBlock,
     drop: Option<BasicBlock>,
     storage_liveness: liveness::LiveVarSet,
@@ -163,15 +163,12 @@ struct TransformVisitor<'a, 'tcx: 'a> {
     state_adt_ref: &'tcx AdtDef,
     state_substs: SubstsRef<'tcx>,
 
-    // The index of the generator state in the generator struct
-    state_field: usize,
-
-    // The type of the generator state in the generator struct
+    // The type of the discriminant in the generator struct
     discr_ty: Ty<'tcx>,
 
     // Mapping from Local to (type of local, generator struct index)
     // FIXME(eddyb) This should use `IndexVec<Local, Option<_>>`.
-    remap: FxHashMap<Local, (Ty<'tcx>, usize)>,
+    remap: FxHashMap<Local, (Ty<'tcx>, VariantIdx, usize)>,
 
     // A map from a suspension point in a block to the locals which have live storage at that point
     // FIXME(eddyb) This should use `IndexVec<BasicBlock, Option<_>>`.
@@ -192,8 +189,9 @@ impl<'a, 'tcx> TransformVisitor<'a, 'tcx> {
     }
 
     // Create a Place referencing a generator struct field
-    fn make_field(&self, idx: usize, ty: Ty<'tcx>) -> Place<'tcx> {
-        let base = Place::Base(PlaceBase::Local(self_arg()));
+    fn make_field(&self, variant_index: VariantIdx, idx: usize, ty: Ty<'tcx>) -> Place<'tcx> {
+        let self_place = Place::Base(PlaceBase::Local(self_arg()));
+        let base = self_place.downcast_unnamed(variant_index);
         let field = Projection {
             base: base,
             elem: ProjectionElem::Field(Field::new(idx), ty),
@@ -201,24 +199,28 @@ impl<'a, 'tcx> TransformVisitor<'a, 'tcx> {
         Place::Projection(Box::new(field))
     }
 
-    // Create a statement which changes the generator state
-    fn set_state(&self, state_disc: u32, source_info: SourceInfo) -> Statement<'tcx> {
-        let state = self.make_field(self.state_field, self.discr_ty);
-        let val = Operand::Constant(box Constant {
-            span: source_info.span,
-            ty: self.discr_ty,
-            user_ty: None,
-            literal: self.tcx.mk_const(ty::Const::from_bits(
-                self.tcx,
-                state_disc.into(),
-                ty::ParamEnv::empty().and(self.discr_ty)
-            )),
-        });
+    // Create a statement which changes the discriminant
+    fn set_discr(&self, state_disc: VariantIdx, source_info: SourceInfo) -> Statement<'tcx> {
+        let self_place = Place::Base(PlaceBase::Local(self_arg()));
         Statement {
             source_info,
-            kind: StatementKind::Assign(state, box Rvalue::Use(val)),
+            kind: StatementKind::SetDiscriminant { place: self_place, variant_index: state_disc },
         }
     }
+
+    // Create a statement which reads the discriminant into a temporary
+    fn get_discr(&self, mir: &mut Mir<'tcx>) -> (Statement<'tcx>, Place<'tcx>) {
+        let temp_decl = LocalDecl::new_internal(self.tcx.types.isize, mir.span);
+        let temp = Place::Base(PlaceBase::Local(Local::new(mir.local_decls.len())));
+        mir.local_decls.push(temp_decl);
+
+        let self_place = Place::Base(PlaceBase::Local(self_arg()));
+        let assign = Statement {
+            source_info: source_info(mir),
+            kind: StatementKind::Assign(temp.clone(), box Rvalue::Discriminant(self_place)),
+        };
+        (assign, temp)
+    }
 }
 
 impl<'a, 'tcx> MutVisitor<'tcx> for TransformVisitor<'a, 'tcx> {
@@ -235,8 +237,8 @@ impl<'a, 'tcx> MutVisitor<'tcx> for TransformVisitor<'a, 'tcx> {
                     location: Location) {
         if let Place::Base(PlaceBase::Local(l)) = *place {
             // Replace an Local in the remap with a generator struct access
-            if let Some(&(ty, idx)) = self.remap.get(&l) {
-                *place = self.make_field(idx, ty);
+            if let Some(&(ty, variant_index, idx)) = self.remap.get(&l) {
+                *place = self.make_field(variant_index, idx, ty);
             }
         } else {
             self.super_place(place, context, location);
@@ -277,7 +279,7 @@ impl<'a, 'tcx> MutVisitor<'tcx> for TransformVisitor<'a, 'tcx> {
                                             box self.make_state(state_idx, v)),
             });
             let state = if let Some(resume) = resume { // Yield
-                let state = 3 + self.suspension_points.len() as u32;
+                let state = 3 + self.suspension_points.len();
 
                 self.suspension_points.push(SuspensionPoint {
                     state,
@@ -286,11 +288,11 @@ impl<'a, 'tcx> MutVisitor<'tcx> for TransformVisitor<'a, 'tcx> {
                     storage_liveness: self.storage_liveness.get(&block).unwrap().clone(),
                 });
 
-                state
+                VariantIdx::new(state)
             } else { // Return
-                RETURNED // state for returned
+                VariantIdx::new(RETURNED) // state for returned
             };
-            data.statements.push(self.set_state(state, source_info));
+            data.statements.push(self.set_discr(state, source_info));
             data.terminator.as_mut().unwrap().kind = TerminatorKind::Return;
         }
 
@@ -391,6 +393,7 @@ fn locals_live_across_suspend_points(
 ) -> (
     liveness::LiveVarSet,
     FxHashMap<BasicBlock, liveness::LiveVarSet>,
+    BitSet<BasicBlock>,
 ) {
     let dead_unwinds = BitSet::new_empty(mir.basic_blocks().len());
     let def_id = source.def_id();
@@ -435,8 +438,12 @@ fn locals_live_across_suspend_points(
 
     let mut storage_liveness_map = FxHashMap::default();
 
+    let mut suspending_blocks = BitSet::new_empty(mir.basic_blocks().len());
+
     for (block, data) in mir.basic_blocks().iter_enumerated() {
         if let TerminatorKind::Yield { .. } = data.terminator().kind {
+            suspending_blocks.insert(block);
+
             let loc = Location {
                 block: block,
                 statement_index: data.statements.len(),
@@ -488,7 +495,7 @@ fn locals_live_across_suspend_points(
     // The generator argument is ignored
     set.remove(self_arg());
 
-    (set, storage_liveness_map)
+    (set, storage_liveness_map, suspending_blocks)
 }
 
 fn compute_layout<'a, 'tcx>(tcx: TyCtxt<'a, 'tcx, 'tcx>,
@@ -497,15 +504,14 @@ fn compute_layout<'a, 'tcx>(tcx: TyCtxt<'a, 'tcx, 'tcx>,
                             interior: Ty<'tcx>,
                             movable: bool,
                             mir: &mut Mir<'tcx>)
-    -> (FxHashMap<Local, (Ty<'tcx>, usize)>,
+    -> (FxHashMap<Local, (Ty<'tcx>, VariantIdx, usize)>,
         GeneratorLayout<'tcx>,
         FxHashMap<BasicBlock, liveness::LiveVarSet>)
 {
     // Use a liveness analysis to compute locals which are live across a suspension point
-    let (live_locals, storage_liveness) = locals_live_across_suspend_points(tcx,
-                                                                            mir,
-                                                                            source,
-                                                                            movable);
+    let (live_locals, storage_liveness, suspending_blocks) =
+        locals_live_across_suspend_points(tcx, mir, source, movable);
+
     // Erase regions from the types passed in from typeck so we can compare them with
     // MIR types
     let allowed_upvars = tcx.erase_regions(upvars);
@@ -531,7 +537,6 @@ fn compute_layout<'a, 'tcx>(tcx: TyCtxt<'a, 'tcx, 'tcx>,
         }
     }
 
-    let upvar_len = upvars.len();
     let dummy_local = LocalDecl::new_internal(tcx.mk_unit(), mir.span);
 
     // Gather live locals and their indices replacing values in mir.local_decls with a dummy
@@ -541,38 +546,44 @@ fn compute_layout<'a, 'tcx>(tcx: TyCtxt<'a, 'tcx, 'tcx>,
         (local, var)
     });
 
+    // For now we will access everything via variant #3, leaving empty variants
+    // for the UNRESUMED, RETURNED, and POISONED states.
+    // If there were a yield-less generator without a variant #3, it would not
+    // have any vars to remap, so we would never use this.
+    let variant_index = VariantIdx::new(3);
+
     // Create a map from local indices to generator struct indices.
-    // These are offset by (upvar_len + 1) because of fields which comes before locals.
     // We also create a vector of the LocalDecls of these locals.
     let (remap, vars) = live_decls.enumerate().map(|(idx, (local, var))| {
-        ((local, (var.ty, upvar_len + 1 + idx)), var)
+        ((local, (var.ty, variant_index, idx)), var)
     }).unzip();
 
+    // Put every var in each variant, for now.
+    let empty_variants = iter::repeat(vec![]).take(3);
+    let state_variants = iter::repeat(vars).take(suspending_blocks.count());
     let layout = GeneratorLayout {
-        // Put everything in one variant, for now.
-        variant_fields: vec![vars]
+        variant_fields: empty_variants.chain(state_variants).collect()
     };
 
     (remap, layout, storage_liveness)
 }
 
-fn insert_switch<'a, 'tcx>(tcx: TyCtxt<'a, 'tcx, 'tcx>,
-                           mir: &mut Mir<'tcx>,
-                           cases: Vec<(u32, BasicBlock)>,
+fn insert_switch<'a, 'tcx>(mir: &mut Mir<'tcx>,
+                           cases: Vec<(usize, BasicBlock)>,
                            transform: &TransformVisitor<'a, 'tcx>,
                            default: TerminatorKind<'tcx>) {
     let default_block = insert_term_block(mir, default);
-
+    let (assign, discr) = transform.get_discr(mir);
     let switch = TerminatorKind::SwitchInt {
-        discr: Operand::Copy(transform.make_field(transform.state_field, tcx.types.u32)),
-        switch_ty: tcx.types.u32,
-        values: Cow::from(cases.iter().map(|&(i, _)| i.into()).collect::<Vec<_>>()),
-        targets: cases.iter().map(|&(_, d)| d).chain(once(default_block)).collect(),
+        discr: Operand::Move(discr),
+        switch_ty: transform.discr_ty,
+        values: Cow::from(cases.iter().map(|&(i, _)| i as u128).collect::<Vec<_>>()),
+        targets: cases.iter().map(|&(_, d)| d).chain(iter::once(default_block)).collect(),
     };
 
     let source_info = source_info(mir);
     mir.basic_blocks_mut().raw.insert(0, BasicBlockData {
-        statements: Vec::new(),
+        statements: vec![assign],
         terminator: Some(Terminator {
             source_info,
             kind: switch,
@@ -657,7 +668,7 @@ fn create_generator_drop_shim<'a, 'tcx>(
     // The returned state and the poisoned state fall through to the default
     // case which is just to return
 
-    insert_switch(tcx, &mut mir, cases, &transform, TerminatorKind::Return);
+    insert_switch(&mut mir, cases, &transform, TerminatorKind::Return);
 
     for block in mir.basic_blocks_mut() {
         let kind = &mut block.terminator_mut().kind;
@@ -771,7 +782,8 @@ fn create_generator_resume_function<'a, 'tcx>(
     for block in mir.basic_blocks_mut() {
         let source_info = block.terminator().source_info;
         if let &TerminatorKind::Resume = &block.terminator().kind {
-            block.statements.push(transform.set_state(POISONED, source_info));
+            block.statements.push(
+                transform.set_discr(VariantIdx::new(POISONED), source_info));
         }
     }
 
@@ -789,7 +801,7 @@ fn create_generator_resume_function<'a, 'tcx>(
     // Panic when resumed on the poisoned state
     cases.insert(2, (POISONED, insert_panic_block(tcx, mir, GeneratorResumedAfterPanic)));
 
-    insert_switch(tcx, mir, cases, &transform, TerminatorKind::Unreachable);
+    insert_switch(mir, cases, &transform, TerminatorKind::Unreachable);
 
     make_generator_state_argument_indirect(tcx, def_id, mir);
     make_generator_state_argument_pinned(tcx, mir);
@@ -835,7 +847,7 @@ fn insert_clean_drop<'a, 'tcx>(mir: &mut Mir<'tcx>) -> BasicBlock {
 
 fn create_cases<'a, 'tcx, F>(mir: &mut Mir<'tcx>,
                           transform: &TransformVisitor<'a, 'tcx>,
-                          target: F) -> Vec<(u32, BasicBlock)>
+                          target: F) -> Vec<(usize, BasicBlock)>
     where F: Fn(&SuspensionPoint) -> Option<BasicBlock> {
     let source_info = source_info(mir);
 
@@ -927,8 +939,6 @@ impl MirPass for StateTransform {
             movable,
             mir);
 
-        let state_field = upvars.len();
-
         // Run the transformation which converts Places from Local to generator struct
         // accesses for locals in `remap`.
         // It also rewrites `return x` and `yield y` as writing a new generator state and returning
@@ -941,7 +951,6 @@ impl MirPass for StateTransform {
             storage_liveness,
             suspension_points: Vec::new(),
             new_ret_local,
-            state_field,
             discr_ty,
         };
         transform.visit_mir(mir);
diff --git a/src/test/mir-opt/generator-drop-cleanup.rs b/src/test/mir-opt/generator-drop-cleanup.rs
index 48398691271..9cc4272fafa 100644
--- a/src/test/mir-opt/generator-drop-cleanup.rs
+++ b/src/test/mir-opt/generator-drop-cleanup.rs
@@ -13,7 +13,8 @@ fn main() {
 
 // START rustc.main-{{closure}}.generator_drop.0.mir
 // bb0: {
-//     switchInt(((*_1).0: u32)) -> [0u32: bb4, 3u32: bb7, otherwise: bb8];
+//     _5 = discriminant((*_1));
+//     switchInt(move _5) -> [0u32: bb4, 3u32: bb7, otherwise: bb8];
 // }
 // bb1: {
 //     goto -> bb5;