about summary refs log tree commit diff
diff options
context:
space:
mode:
authorEduard Burtescu <edy.burt@gmail.com>2016-03-09 14:20:50 +0200
committerEduard Burtescu <edy.burt@gmail.com>2016-03-17 21:51:55 +0200
commit41499f45635d94003a9461c059d9b395b1a6e3ea (patch)
tree5a0934fba8a418e71a312c857010ae1a44176d8c
parentd9277b163c36448c5fbc39fa089a78256b45ffc1 (diff)
downloadrust-41499f45635d94003a9461c059d9b395b1a6e3ea.tar.gz
rust-41499f45635d94003a9461c059d9b395b1a6e3ea.zip
mir: Match against slices by calling PartialEq::eq.
-rw-r--r--src/librustc_mir/build/matches/test.rs93
-rw-r--r--src/librustc_mir/hair/cx/mod.rs30
2 files changed, 99 insertions, 24 deletions
diff --git a/src/librustc_mir/build/matches/test.rs b/src/librustc_mir/build/matches/test.rs
index bdec261ce65..4e3a69bf745 100644
--- a/src/librustc_mir/build/matches/test.rs
+++ b/src/librustc_mir/build/matches/test.rs
@@ -174,33 +174,78 @@ impl<'a,'tcx> Builder<'a,'tcx> {
                 targets
             }
 
-            TestKind::Eq { ref value, ty } => {
-                // If we're matching against &[u8] with b"...", we need to insert
-                // an unsizing coercion, as the byte string has type &[u8; N].
-                let expect = match *value {
-                    ConstVal::ByteStr(ref bytes) if ty.is_slice() => {
-                        let tcx = self.hir.tcx();
-                        let array_ty = tcx.mk_array(tcx.types.u8, bytes.len());
-                        let ref_ty = tcx.mk_imm_ref(tcx.mk_region(ty::ReStatic), array_ty);
-                        let array = self.literal_operand(test.span, ref_ty, Literal::Value {
-                            value: value.clone()
-                        });
-
-                        let sliced = self.temp(ty);
-                        self.cfg.push_assign(block, test.span, &sliced,
-                                             Rvalue::Cast(CastKind::Unsize, array, ty));
-                        Operand::Consume(sliced)
-                    }
-                    _ => {
-                        self.literal_operand(test.span, ty, Literal::Value {
-                            value: value.clone()
-                        })
+            TestKind::Eq { ref value, mut ty } => {
+                let mut val = Operand::Consume(lvalue.clone());
+
+                // If we're using b"..." as a pattern, we need to insert an
+                // unsizing coercion, as the byte string has the type &[u8; N].
+                let expect = if let ConstVal::ByteStr(ref bytes) = *value {
+                    let tcx = self.hir.tcx();
+
+                    // Unsize the lvalue to &[u8], too, if necessary.
+                    if let ty::TyRef(region, mt) = ty.sty {
+                        if let ty::TyArray(_, _) = mt.ty.sty {
+                            ty = tcx.mk_imm_ref(region, tcx.mk_slice(tcx.types.u8));
+                            let val_slice = self.temp(ty);
+                            self.cfg.push_assign(block, test.span, &val_slice,
+                                                 Rvalue::Cast(CastKind::Unsize, val, ty));
+                            val = Operand::Consume(val_slice);
+                        }
                     }
+
+                    assert!(ty.is_slice());
+
+                    let array_ty = tcx.mk_array(tcx.types.u8, bytes.len());
+                    let array_ref = tcx.mk_imm_ref(tcx.mk_region(ty::ReStatic), array_ty);
+                    let array = self.literal_operand(test.span, array_ref, Literal::Value {
+                        value: value.clone()
+                    });
+
+                    let slice = self.temp(ty);
+                    self.cfg.push_assign(block, test.span, &slice,
+                                         Rvalue::Cast(CastKind::Unsize, array, ty));
+                    Operand::Consume(slice)
+                } else {
+                    self.literal_operand(test.span, ty, Literal::Value {
+                        value: value.clone()
+                    })
                 };
-                let val = Operand::Consume(lvalue.clone());
+
+                // Use PartialEq::eq for &str and &[u8] slices, instead of BinOp::Eq.
                 let fail = self.cfg.start_new_block();
-                let block = self.compare(block, fail, test.span, BinOp::Eq, expect, val.clone());
-                vec![block, fail]
+                if let ty::TyRef(_, mt) = ty.sty {
+                    assert!(ty.is_slice());
+                    let eq_def_id = self.hir.tcx().lang_items.eq_trait().unwrap();
+                    let ty = mt.ty;
+                    let (mty, method) = self.hir.trait_method(eq_def_id, "eq", ty, vec![ty]);
+
+                    let bool_ty = self.hir.bool_ty();
+                    let eq_result = self.temp(bool_ty);
+                    let eq_block = self.cfg.start_new_block();
+                    let cleanup = self.diverge_cleanup();
+                    self.cfg.terminate(block, Terminator::Call {
+                        func: Operand::Constant(Constant {
+                            span: test.span,
+                            ty: mty,
+                            literal: method
+                        }),
+                        args: vec![val, expect],
+                        destination: Some((eq_result.clone(), eq_block)),
+                        cleanup: cleanup,
+                    });
+
+                    // check the result
+                    let block = self.cfg.start_new_block();
+                    self.cfg.terminate(eq_block, Terminator::If {
+                        cond: Operand::Consume(eq_result),
+                        targets: (block, fail),
+                    });
+
+                    vec![block, fail]
+                } else {
+                    let block = self.compare(block, fail, test.span, BinOp::Eq, expect, val);
+                    vec![block, fail]
+                }
             }
 
             TestKind::Range { ref lo, ref hi, ty } => {
diff --git a/src/librustc_mir/hair/cx/mod.rs b/src/librustc_mir/hair/cx/mod.rs
index d29d895e11d..b97bfaf5aef 100644
--- a/src/librustc_mir/hair/cx/mod.rs
+++ b/src/librustc_mir/hair/cx/mod.rs
@@ -19,7 +19,9 @@ use hair::*;
 use rustc::mir::repr::*;
 
 use rustc::middle::const_eval::{self, ConstVal};
+use rustc::middle::def_id::DefId;
 use rustc::middle::infer::InferCtxt;
+use rustc::middle::subst::{Subst, Substs};
 use rustc::middle::ty::{self, Ty, TyCtxt};
 use syntax::codemap::Span;
 use syntax::parse::token;
@@ -96,6 +98,34 @@ impl<'a,'tcx:'a> Cx<'a, 'tcx> {
         })
     }
 
+    pub fn trait_method(&mut self,
+                        trait_def_id: DefId,
+                        method_name: &str,
+                        self_ty: Ty<'tcx>,
+                        params: Vec<Ty<'tcx>>)
+                        -> (Ty<'tcx>, Literal<'tcx>) {
+        let method_name = token::intern(method_name);
+        let substs = Substs::new_trait(params, vec![], self_ty);
+        for trait_item in self.tcx.trait_items(trait_def_id).iter() {
+            match *trait_item {
+                ty::ImplOrTraitItem::MethodTraitItem(ref method) => {
+                    if method.name == method_name {
+                        let method_ty = self.tcx.lookup_item_type(method.def_id);
+                        let method_ty = method_ty.ty.subst(self.tcx, &substs);
+                        return (method_ty, Literal::Item {
+                            def_id: method.def_id,
+                            substs: self.tcx.mk_substs(substs),
+                        });
+                    }
+                }
+                ty::ImplOrTraitItem::ConstTraitItem(..) |
+                ty::ImplOrTraitItem::TypeTraitItem(..) => {}
+            }
+        }
+
+        self.tcx.sess.bug(&format!("found no method `{}` in `{:?}`", method_name, trait_def_id));
+    }
+
     pub fn num_variants(&mut self, adt_def: ty::AdtDef<'tcx>) -> usize {
         adt_def.variants.len()
     }