summary refs log tree commit diff
path: root/compiler/rustc_mir/src/transform/match_branches.rs
blob: 70ae5474a4f909e42d6981021f5750c71dfec4d4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
use crate::transform::{MirPass, MirSource};
use rustc_middle::mir::*;
use rustc_middle::ty::TyCtxt;

pub struct MatchBranchSimplification;

/// If a source block is found that switches between two blocks that are exactly
/// the same modulo const bool assignments (e.g., one assigns true another false
/// to the same place), merge a target block statements into the source block,
/// using Eq / Ne comparison with switch value where const bools value differ.
///
/// For example:
///
/// ```rust
/// bb0: {
///     switchInt(move _3) -> [42_isize: bb1, otherwise: bb2];
/// }
///
/// bb1: {
///     _2 = const true;
///     goto -> bb3;
/// }
///
/// bb2: {
///     _2 = const false;
///     goto -> bb3;
/// }
/// ```
///
/// into:
///
/// ```rust
/// bb0: {
///    _2 = Eq(move _3, const 42_isize);
///    goto -> bb3;
/// }
/// ```

impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
    fn run_pass(&self, tcx: TyCtxt<'tcx>, src: MirSource<'tcx>, body: &mut Body<'tcx>) {
        // FIXME: This optimization can result in unsoundness, because it introduces
        // additional uses of a place holding the discriminant value without ensuring that
        // it is valid to do so.
        if !tcx.sess.opts.debugging_opts.unsound_mir_opts {
            return;
        }

        let param_env = tcx.param_env(src.def_id());
        let bbs = body.basic_blocks_mut();
        'outer: for bb_idx in bbs.indices() {
            let (discr, val, switch_ty, first, second) = match bbs[bb_idx].terminator().kind {
                TerminatorKind::SwitchInt {
                    discr: Operand::Copy(ref place) | Operand::Move(ref place),
                    switch_ty,
                    ref targets,
                    ref values,
                    ..
                } if targets.len() == 2 && values.len() == 1 && targets[0] != targets[1] => {
                    (place, values[0], switch_ty, targets[0], targets[1])
                }
                // Only optimize switch int statements
                _ => continue,
            };

            // Check that destinations are identical, and if not, then don't optimize this block
            if &bbs[first].terminator().kind != &bbs[second].terminator().kind {
                continue;
            }

            // Check that blocks are assignments of consts to the same place or same statement,
            // and match up 1-1, if not don't optimize this block.
            let first_stmts = &bbs[first].statements;
            let scnd_stmts = &bbs[second].statements;
            if first_stmts.len() != scnd_stmts.len() {
                continue;
            }
            for (f, s) in first_stmts.iter().zip(scnd_stmts.iter()) {
                match (&f.kind, &s.kind) {
                    // If two statements are exactly the same, we can optimize.
                    (f_s, s_s) if f_s == s_s => {}

                    // If two statements are const bool assignments to the same place, we can optimize.
                    (
                        StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))),
                        StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
                    ) if lhs_f == lhs_s
                        && f_c.literal.ty.is_bool()
                        && s_c.literal.ty.is_bool()
                        && f_c.literal.try_eval_bool(tcx, param_env).is_some()
                        && s_c.literal.try_eval_bool(tcx, param_env).is_some() => {}

                    // Otherwise we cannot optimize. Try another block.
                    _ => continue 'outer,
                }
            }
            // Take ownership of items now that we know we can optimize.
            let discr = discr.clone();

            // We already checked that first and second are different blocks,
            // and bb_idx has a different terminator from both of them.
            let (from, first, second) = bbs.pick3_mut(bb_idx, first, second);

            let new_stmts = first.statements.iter().zip(second.statements.iter()).map(|(f, s)| {
                match (&f.kind, &s.kind) {
                    (f_s, s_s) if f_s == s_s => (*f).clone(),

                    (
                        StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))),
                        StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))),
                    ) => {
                        // From earlier loop we know that we are dealing with bool constants only:
                        let f_b = f_c.literal.try_eval_bool(tcx, param_env).unwrap();
                        let s_b = s_c.literal.try_eval_bool(tcx, param_env).unwrap();
                        if f_b == s_b {
                            // Same value in both blocks. Use statement as is.
                            (*f).clone()
                        } else {
                            // Different value between blocks. Make value conditional on switch condition.
                            let size = tcx.layout_of(param_env.and(switch_ty)).unwrap().size;
                            let const_cmp = Operand::const_from_scalar(
                                tcx,
                                switch_ty,
                                crate::interpret::Scalar::from_uint(val, size),
                                rustc_span::DUMMY_SP,
                            );
                            let op = if f_b { BinOp::Eq } else { BinOp::Ne };
                            let rhs = Rvalue::BinaryOp(op, Operand::Copy(discr.clone()), const_cmp);
                            Statement {
                                source_info: f.source_info,
                                kind: StatementKind::Assign(box (*lhs, rhs)),
                            }
                        }
                    }

                    _ => unreachable!(),
                }
            });
            from.statements.extend(new_stmts);
            from.terminator_mut().kind = first.terminator().kind.clone();
        }
    }
}