about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2016-08-13 01:20:46 -0700
committerGitHub <noreply@github.com>2016-08-13 01:20:46 -0700
commite64f68817d850ccbe642d7f067083bc655115d84 (patch)
tree8676503b26f5871dd3d83a5276fc5f2ce35b15fe /src
parentd3c3de8abe63f738113874267dad3b92a1965ecd (diff)
parentd77a136437a38535522fb3636d165edd4ed49df0 (diff)
downloadrust-e64f68817d850ccbe642d7f067083bc655115d84.tar.gz
rust-e64f68817d850ccbe642d7f067083bc655115d84.zip
Auto merge of #35348 - scottcarr:discriminant2, r=nikomatsakis
[MIR] Add explicit SetDiscriminant StatementKind for deaggregating enums

cc #35186

To deaggregate enums, we need to be able to explicitly set the discriminant.  This PR implements a new StatementKind that does that.

I think some of the places that have `panics!` now could maybe do something smarter.
Diffstat (limited to 'src')
-rw-r--r--src/librustc/mir/repr.rs6
-rw-r--r--src/librustc/mir/visit.rs3
-rw-r--r--src/librustc_borrowck/borrowck/mir/dataflow/impls.rs3
-rw-r--r--src/librustc_borrowck/borrowck/mir/dataflow/sanity_check.rs3
-rw-r--r--src/librustc_borrowck/borrowck/mir/gather_moves.rs4
-rw-r--r--src/librustc_borrowck/borrowck/mir/mod.rs3
-rw-r--r--src/librustc_mir/transform/deaggregator.rs45
-rw-r--r--src/librustc_mir/transform/promote_consts.rs22
-rw-r--r--src/librustc_mir/transform/type_check.rs23
-rw-r--r--src/librustc_trans/mir/constant.rs3
-rw-r--r--src/librustc_trans/mir/statement.rs14
-rw-r--r--src/test/mir-opt/deaggregator_test_enum.rs45
12 files changed, 156 insertions, 18 deletions
diff --git a/src/librustc/mir/repr.rs b/src/librustc/mir/repr.rs
index 93507246241..08614ca253b 100644
--- a/src/librustc/mir/repr.rs
+++ b/src/librustc/mir/repr.rs
@@ -689,13 +689,17 @@ pub struct Statement<'tcx> {
 #[derive(Clone, Debug, RustcEncodable, RustcDecodable)]
 pub enum StatementKind<'tcx> {
     Assign(Lvalue<'tcx>, Rvalue<'tcx>),
+    SetDiscriminant{ lvalue: Lvalue<'tcx>, variant_index: usize },
 }
 
 impl<'tcx> Debug for Statement<'tcx> {
     fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
         use self::StatementKind::*;
         match self.kind {
-            Assign(ref lv, ref rv) => write!(fmt, "{:?} = {:?}", lv, rv)
+            Assign(ref lv, ref rv) => write!(fmt, "{:?} = {:?}", lv, rv),
+            SetDiscriminant{lvalue: ref lv, variant_index: index} => {
+                write!(fmt, "discriminant({:?}) = {:?}", lv, index)
+            }
         }
     }
 }
diff --git a/src/librustc/mir/visit.rs b/src/librustc/mir/visit.rs
index 3f714ff4d51..d44f00ed2cb 100644
--- a/src/librustc/mir/visit.rs
+++ b/src/librustc/mir/visit.rs
@@ -323,6 +323,9 @@ macro_rules! make_mir_visitor {
                                           ref $($mutability)* rvalue) => {
                         self.visit_assign(block, lvalue, rvalue);
                     }
+                    StatementKind::SetDiscriminant{ ref $($mutability)* lvalue, .. } => {
+                        self.visit_lvalue(lvalue, LvalueContext::Store);
+                    }
                 }
             }
 
diff --git a/src/librustc_borrowck/borrowck/mir/dataflow/impls.rs b/src/librustc_borrowck/borrowck/mir/dataflow/impls.rs
index 932b7485201..57b335bd5ee 100644
--- a/src/librustc_borrowck/borrowck/mir/dataflow/impls.rs
+++ b/src/librustc_borrowck/borrowck/mir/dataflow/impls.rs
@@ -442,6 +442,9 @@ impl<'a, 'tcx> BitDenotation for MovingOutStatements<'a, 'tcx> {
         }
         let bits_per_block = self.bits_per_block(ctxt);
         match stmt.kind {
+            repr::StatementKind::SetDiscriminant { .. } => {
+                span_bug!(stmt.source_info.span, "SetDiscriminant should not exist in borrowck");
+            }
             repr::StatementKind::Assign(ref lvalue, _) => {
                 // assigning into this `lvalue` kills all
                 // MoveOuts from it, and *also* all MoveOuts
diff --git a/src/librustc_borrowck/borrowck/mir/dataflow/sanity_check.rs b/src/librustc_borrowck/borrowck/mir/dataflow/sanity_check.rs
index d59bdf93f32..ccde429a171 100644
--- a/src/librustc_borrowck/borrowck/mir/dataflow/sanity_check.rs
+++ b/src/librustc_borrowck/borrowck/mir/dataflow/sanity_check.rs
@@ -104,6 +104,9 @@ fn each_block<'a, 'tcx, O>(tcx: TyCtxt<'a, 'tcx, 'tcx>,
             repr::StatementKind::Assign(ref lvalue, ref rvalue) => {
                 (lvalue, rvalue)
             }
+            repr::StatementKind::SetDiscriminant{ .. } =>
+                span_bug!(stmt.source_info.span,
+                          "sanity_check should run before Deaggregator inserts SetDiscriminant"),
         };
 
         if lvalue == peek_arg_lval {
diff --git a/src/librustc_borrowck/borrowck/mir/gather_moves.rs b/src/librustc_borrowck/borrowck/mir/gather_moves.rs
index 05412216d48..e965dcc169c 100644
--- a/src/librustc_borrowck/borrowck/mir/gather_moves.rs
+++ b/src/librustc_borrowck/borrowck/mir/gather_moves.rs
@@ -616,6 +616,10 @@ fn gather_moves<'a, 'tcx>(mir: &Mir<'tcx>, tcx: TyCtxt<'a, 'tcx, 'tcx>) -> MoveD
                         Rvalue::InlineAsm { .. } => {}
                     }
                 }
+                StatementKind::SetDiscriminant{ .. } => {
+                    span_bug!(stmt.source_info.span,
+                              "SetDiscriminant should not exist during borrowck");
+                }
             }
         }
 
diff --git a/src/librustc_borrowck/borrowck/mir/mod.rs b/src/librustc_borrowck/borrowck/mir/mod.rs
index 7c912e8bac6..c563fdb8f44 100644
--- a/src/librustc_borrowck/borrowck/mir/mod.rs
+++ b/src/librustc_borrowck/borrowck/mir/mod.rs
@@ -369,6 +369,9 @@ fn drop_flag_effects_for_location<'a, 'tcx, F>(
     let block = &mir[loc.block];
     match block.statements.get(loc.index) {
         Some(stmt) => match stmt.kind {
+            repr::StatementKind::SetDiscriminant{ .. } => {
+                span_bug!(stmt.source_info.span, "SetDiscrimant should not exist during borrowck");
+            }
             repr::StatementKind::Assign(ref lvalue, _) => {
                 debug!("drop_flag_effects: assignment {:?}", stmt);
                  on_all_children_bits(tcx, mir, move_data,
diff --git a/src/librustc_mir/transform/deaggregator.rs b/src/librustc_mir/transform/deaggregator.rs
index fccd4a607fd..cd6f0ed9cba 100644
--- a/src/librustc_mir/transform/deaggregator.rs
+++ b/src/librustc_mir/transform/deaggregator.rs
@@ -39,7 +39,7 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
 
         let mut curr: usize = 0;
         for bb in mir.basic_blocks_mut() {
-            let idx = match get_aggregate_statement(curr, &bb.statements) {
+            let idx = match get_aggregate_statement_index(curr, &bb.statements) {
                 Some(idx) => idx,
                 None => continue,
             };
@@ -48,7 +48,11 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
             let src_info = bb.statements[idx].source_info;
             let suffix_stmts = bb.statements.split_off(idx+1);
             let orig_stmt = bb.statements.pop().unwrap();
-            let StatementKind::Assign(ref lhs, ref rhs) = orig_stmt.kind;
+            let (lhs, rhs) = match orig_stmt.kind {
+                StatementKind::Assign(ref lhs, ref rhs) => (lhs, rhs),
+                StatementKind::SetDiscriminant{ .. } =>
+                    span_bug!(src_info.span, "expected aggregate, not {:?}", orig_stmt.kind),
+            };
             let (agg_kind, operands) = match rhs {
                 &Rvalue::Aggregate(ref agg_kind, ref operands) => (agg_kind, operands),
                 _ => span_bug!(src_info.span, "expected aggregate, not {:?}", rhs),
@@ -64,10 +68,14 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
                 let ty = variant_def.fields[i].ty(tcx, substs);
                 let rhs = Rvalue::Use(op.clone());
 
-                // since we don't handle enums, we don't need a cast
-                let lhs_cast = lhs.clone();
-
-                // FIXME we cannot deaggregate enums issue: #35186
+                let lhs_cast = if adt_def.variants.len() > 1 {
+                    Lvalue::Projection(Box::new(LvalueProjection {
+                        base: lhs.clone(),
+                        elem: ProjectionElem::Downcast(adt_def, variant),
+                    }))
+                } else {
+                    lhs.clone()
+                };
 
                 let lhs_proj = Lvalue::Projection(Box::new(LvalueProjection {
                     base: lhs_cast,
@@ -80,18 +88,34 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
                 debug!("inserting: {:?} @ {:?}", new_statement, idx + i);
                 bb.statements.push(new_statement);
             }
+
+            // if the aggregate was an enum, we need to set the discriminant
+            if adt_def.variants.len() > 1 {
+                let set_discriminant = Statement {
+                    kind: StatementKind::SetDiscriminant {
+                        lvalue: lhs.clone(),
+                        variant_index: variant,
+                    },
+                    source_info: src_info,
+                };
+                bb.statements.push(set_discriminant);
+            };
+
             curr = bb.statements.len();
             bb.statements.extend(suffix_stmts);
         }
     }
 }
 
-fn get_aggregate_statement<'a, 'tcx, 'b>(curr: usize,
+fn get_aggregate_statement_index<'a, 'tcx, 'b>(start: usize,
                                          statements: &Vec<Statement<'tcx>>)
                                          -> Option<usize> {
-    for i in curr..statements.len() {
+    for i in start..statements.len() {
         let ref statement = statements[i];
-        let StatementKind::Assign(_, ref rhs) = statement.kind;
+        let rhs = match statement.kind {
+            StatementKind::Assign(_, ref rhs) => rhs,
+            StatementKind::SetDiscriminant{ .. } => continue,
+        };
         let (kind, operands) = match rhs {
             &Rvalue::Aggregate(ref kind, ref operands) => (kind, operands),
             _ => continue,
@@ -100,9 +124,8 @@ fn get_aggregate_statement<'a, 'tcx, 'b>(curr: usize,
             &AggregateKind::Adt(adt_def, variant, _) => (adt_def, variant),
             _ => continue,
         };
-        if operands.len() == 0 || adt_def.variants.len() > 1 {
+        if operands.len() == 0 {
             // don't deaggregate ()
-            // don't deaggregate enums ... for now
             continue;
         }
         debug!("getting variant {:?}", variant);
diff --git a/src/librustc_mir/transform/promote_consts.rs b/src/librustc_mir/transform/promote_consts.rs
index fa3490cbcf3..eb0d8697f15 100644
--- a/src/librustc_mir/transform/promote_consts.rs
+++ b/src/librustc_mir/transform/promote_consts.rs
@@ -219,7 +219,13 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
         let (mut rvalue, mut call) = (None, None);
         let source_info = if stmt_idx < no_stmts {
             let statement = &mut self.source[bb].statements[stmt_idx];
-            let StatementKind::Assign(_, ref mut rhs) = statement.kind;
+            let mut rhs = match statement.kind {
+                StatementKind::Assign(_, ref mut rhs) => rhs,
+                StatementKind::SetDiscriminant{ .. } =>
+                    span_bug!(statement.source_info.span,
+                              "cannot promote SetDiscriminant {:?}",
+                              statement),
+            };
             if self.keep_original {
                 rvalue = Some(rhs.clone());
             } else {
@@ -300,10 +306,16 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
         });
         let mut rvalue = match candidate {
             Candidate::Ref(Location { block: bb, statement_index: stmt_idx }) => {
-                match self.source[bb].statements[stmt_idx].kind {
+                let ref mut statement = self.source[bb].statements[stmt_idx];
+                match statement.kind {
                     StatementKind::Assign(_, ref mut rvalue) => {
                         mem::replace(rvalue, Rvalue::Use(new_operand))
                     }
+                    StatementKind::SetDiscriminant{ .. } => {
+                        span_bug!(statement.source_info.span,
+                                  "cannot promote SetDiscriminant {:?}",
+                                  statement);
+                    }
                 }
             }
             Candidate::ShuffleIndices(bb) => {
@@ -340,7 +352,11 @@ pub fn promote_candidates<'a, 'tcx>(mir: &mut Mir<'tcx>,
         let (span, ty) = match candidate {
             Candidate::Ref(Location { block: bb, statement_index: stmt_idx }) => {
                 let statement = &mir[bb].statements[stmt_idx];
-                let StatementKind::Assign(ref dest, _) = statement.kind;
+                let dest = match statement.kind {
+                    StatementKind::Assign(ref dest, _) => dest,
+                    StatementKind::SetDiscriminant{ .. } =>
+                        panic!("cannot promote SetDiscriminant"),
+                };
                 if let Lvalue::Temp(index) = *dest {
                     if temps[index] == TempState::PromotedOut {
                         // Already promoted.
diff --git a/src/librustc_mir/transform/type_check.rs b/src/librustc_mir/transform/type_check.rs
index 52f41741b08..934357c9e1d 100644
--- a/src/librustc_mir/transform/type_check.rs
+++ b/src/librustc_mir/transform/type_check.rs
@@ -14,7 +14,7 @@
 use rustc::infer::{self, InferCtxt, InferOk};
 use rustc::traits::{self, Reveal};
 use rustc::ty::fold::TypeFoldable;
-use rustc::ty::{self, Ty, TyCtxt};
+use rustc::ty::{self, Ty, TyCtxt, TypeVariants};
 use rustc::mir::repr::*;
 use rustc::mir::tcx::LvalueTy;
 use rustc::mir::transform::{MirPass, MirSource, Pass};
@@ -360,10 +360,27 @@ impl<'a, 'gcx, 'tcx> TypeChecker<'a, 'gcx, 'tcx> {
                         span_mirbug!(self, stmt, "bad assignment ({:?} = {:?}): {:?}",
                                      lv_ty, rv_ty, terr);
                     }
-                }
-
                 // FIXME: rvalue with undeterminable type - e.g. inline
                 // asm.
+                }
+            }
+            StatementKind::SetDiscriminant{ ref lvalue, variant_index } => {
+                let lvalue_type = lvalue.ty(mir, tcx).to_ty(tcx);
+                let adt = match lvalue_type.sty {
+                    TypeVariants::TyEnum(adt, _) => adt,
+                    _ => {
+                        span_bug!(stmt.source_info.span,
+                                  "bad set discriminant ({:?} = {:?}): lhs is not an enum",
+                                  lvalue,
+                                  variant_index);
+                    }
+                };
+                if variant_index >= adt.variants.len() {
+                     span_bug!(stmt.source_info.span,
+                               "bad set discriminant ({:?} = {:?}): value of of range",
+                               lvalue,
+                               variant_index);
+                };
             }
         }
     }
diff --git a/src/librustc_trans/mir/constant.rs b/src/librustc_trans/mir/constant.rs
index 35ded704296..7ca94b6356e 100644
--- a/src/librustc_trans/mir/constant.rs
+++ b/src/librustc_trans/mir/constant.rs
@@ -285,6 +285,9 @@ impl<'a, 'tcx> MirConstContext<'a, 'tcx> {
                             Err(err) => if failure.is_ok() { failure = Err(err); }
                         }
                     }
+                    mir::StatementKind::SetDiscriminant{ .. } => {
+                        span_bug!(span, "SetDiscriminant should not appear in constants?");
+                    }
                 }
             }
 
diff --git a/src/librustc_trans/mir/statement.rs b/src/librustc_trans/mir/statement.rs
index 44d264c7e98..7e3074f4ced 100644
--- a/src/librustc_trans/mir/statement.rs
+++ b/src/librustc_trans/mir/statement.rs
@@ -14,6 +14,8 @@ use common::{self, BlockAndBuilder};
 
 use super::MirContext;
 use super::LocalRef;
+use super::super::adt;
+use super::super::disr::Disr;
 
 impl<'bcx, 'tcx> MirContext<'bcx, 'tcx> {
     pub fn trans_statement(&mut self,
@@ -57,6 +59,18 @@ impl<'bcx, 'tcx> MirContext<'bcx, 'tcx> {
                     self.trans_rvalue(bcx, tr_dest, rvalue, debug_loc)
                 }
             }
+            mir::StatementKind::SetDiscriminant{ref lvalue, variant_index} => {
+                let ty = self.monomorphized_lvalue_ty(lvalue);
+                let repr = adt::represent_type(bcx.ccx(), ty);
+                let lvalue_transed = self.trans_lvalue(&bcx, lvalue);
+                bcx.with_block(|bcx|
+                    adt::trans_set_discr(bcx,
+                                         &repr,
+                                        lvalue_transed.llval,
+                                        Disr::from(variant_index))
+                );
+                bcx
+            }
         }
     }
 }
diff --git a/src/test/mir-opt/deaggregator_test_enum.rs b/src/test/mir-opt/deaggregator_test_enum.rs
new file mode 100644
index 00000000000..ccfa760a28c
--- /dev/null
+++ b/src/test/mir-opt/deaggregator_test_enum.rs
@@ -0,0 +1,45 @@
+// Copyright 2016 The Rust Project Developers. See the COPYRIGHT
+// file at the top-level directory of this distribution and at
+// http://rust-lang.org/COPYRIGHT.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+enum Baz {
+    Empty,
+    Foo { x: usize },
+}
+
+fn bar(a: usize) -> Baz {
+    Baz::Foo { x: a }
+}
+
+fn main() {
+    let x = bar(10);
+    match x {
+        Baz::Empty => println!("empty"),
+        Baz::Foo { x } => println!("{}", x),
+    };
+}
+
+// END RUST SOURCE
+// START rustc.node10.Deaggregator.before.mir
+// bb0: {
+//     var0 = arg0;                     // scope 0 at main.rs:7:8: 7:9
+//     tmp0 = var0;                     // scope 1 at main.rs:8:19: 8:20
+//     return = Baz::Foo { x: tmp0 };   // scope 1 at main.rs:8:5: 8:21
+//     goto -> bb1;                     // scope 1 at main.rs:7:1: 9:2
+// }
+// END rustc.node10.Deaggregator.before.mir
+// START rustc.node10.Deaggregator.after.mir
+// bb0: {
+//     var0 = arg0;                     // scope 0 at main.rs:7:8: 7:9
+//     tmp0 = var0;                     // scope 1 at main.rs:8:19: 8:20
+//     ((return as Foo).0: usize) = tmp0; // scope 1 at main.rs:8:5: 8:21
+//     discriminant(return) = 1;         // scope 1 at main.rs:8:5: 8:21
+//     goto -> bb1;                     // scope 1 at main.rs:7:1: 9:2
+// }
+// END rustc.node10.Deaggregator.after.mir
\ No newline at end of file