about summary refs log tree commit diff
diff options
context:
space:
mode:
authorAlexander <alex.m.vlasov@gmail.com>2021-10-06 17:31:35 +0200
committerAlexander <alex.m.vlasov@gmail.com>2021-10-06 17:31:35 +0200
commit94e1413f60a1725b994f986a607a6f370bb230b4 (patch)
treedc25710cba692958d0d609ce1fec09d59f1de26b
parentd7539a6af09e5889ed9bcb8b49571b7a59c32e65 (diff)
downloadrust-94e1413f60a1725b994f986a607a6f370bb230b4.tar.gz
rust-94e1413f60a1725b994f986a607a6f370bb230b4.zip
reset and cleanup
-rw-r--r--compiler/rustc_mir_transform/src/lib.rs2
-rw-r--r--compiler/rustc_mir_transform/src/normalize_array_len.rs287
-rw-r--r--src/test/mir-opt/lower_array_len.rs50
3 files changed, 339 insertions, 0 deletions
diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs
index 0ca640cd7b1..9b11c8f0b24 100644
--- a/compiler/rustc_mir_transform/src/lib.rs
+++ b/compiler/rustc_mir_transform/src/lib.rs
@@ -58,6 +58,7 @@ mod lower_intrinsics;
 mod lower_slice_len;
 mod match_branches;
 mod multiple_return_terminators;
+mod normalize_array_len;
 mod nrvo;
 mod remove_noop_landing_pads;
 mod remove_storage_markers;
@@ -488,6 +489,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
     // machine than on MIR with async primitives.
     let optimizations_with_generators: &[&dyn MirPass<'tcx>] = &[
         &lower_slice_len::LowerSliceLenCalls, // has to be done before inlining, otherwise actual call will be almost always inlined. Also simple, so can just do first
+        &normalize_array_len::NormalizeArrayLen, // has to run after `slice::len` lowering
         &unreachable_prop::UnreachablePropagation,
         &uninhabited_enum_branching::UninhabitedEnumBranching,
         &simplify::SimplifyCfg::new("after-uninhabited-enum-branching"),
diff --git a/compiler/rustc_mir_transform/src/normalize_array_len.rs b/compiler/rustc_mir_transform/src/normalize_array_len.rs
new file mode 100644
index 00000000000..e6ec7171a47
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/normalize_array_len.rs
@@ -0,0 +1,287 @@
+//! This pass eliminates casting of arrays into slices when their length
+//! is taken using `.len()` method. Handy to preserve information in MIR for const prop
+
+use crate::transform::MirPass;
+use rustc_data_structures::fx::FxIndexMap;
+use rustc_index::bit_set::BitSet;
+use rustc_index::vec::IndexVec;
+use rustc_middle::mir::*;
+use rustc_middle::ty::{self, TyCtxt};
+
+const MAX_NUM_BLOCKS: usize = 800;
+const MAX_NUM_LOCALS: usize = 3000;
+
+pub struct NormalizeArrayLen;
+
+impl<'tcx> MirPass<'tcx> for NormalizeArrayLen {
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        // if tcx.sess.mir_opt_level() < 3 {
+        //     return;
+        // }
+
+        // early returns for edge cases of highly unrolled functions
+        if body.basic_blocks().len() > MAX_NUM_BLOCKS {
+            return;
+        }
+        if body.local_decls().len() > MAX_NUM_LOCALS {
+            return;
+        }
+        normalize_array_len_calls(tcx, body)
+    }
+}
+
+pub fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+    let (basic_blocks, local_decls) = body.basic_blocks_and_local_decls_mut();
+
+    // do a preliminary analysis to see if we ever have locals of type `[T;N]` or `&[T;N]`
+    let mut interesting_locals = BitSet::new_empty(local_decls.len());
+    for (local, decl) in local_decls.iter_enumerated() {
+        match decl.ty.kind() {
+            ty::Array(..) => {
+                interesting_locals.insert(local);
+            }
+            ty::Ref(.., ty, Mutability::Not) => match ty.kind() {
+                ty::Array(..) => {
+                    interesting_locals.insert(local);
+                }
+                _ => {}
+            },
+            _ => {}
+        }
+    }
+    if interesting_locals.is_empty() {
+        // we have found nothing to analyze
+        return;
+    }
+    let num_intesting_locals = interesting_locals.count();
+    let mut state = FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
+    let mut patches_scratchpad =
+        FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
+    let mut replacements_scratchpad =
+        FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
+    for block in basic_blocks {
+        // make length calls for arrays [T; N] not to decay into length calls for &[T]
+        // that forbids constant propagation
+        normalize_array_len_call(
+            tcx,
+            block,
+            local_decls,
+            &interesting_locals,
+            &mut state,
+            &mut patches_scratchpad,
+            &mut replacements_scratchpad,
+        );
+        state.clear();
+        patches_scratchpad.clear();
+        replacements_scratchpad.clear();
+    }
+}
+
+struct Patcher<'a, 'tcx> {
+    tcx: TyCtxt<'tcx>,
+    patches_scratchpad: &'a FxIndexMap<usize, usize>,
+    replacements_scratchpad: &'a mut FxIndexMap<usize, Local>,
+    local_decls: &'a mut IndexVec<Local, LocalDecl<'tcx>>,
+    statement_idx: usize,
+}
+
+impl<'a, 'tcx> Patcher<'a, 'tcx> {
+    fn patch_expand_statement(
+        &mut self,
+        statement: &mut Statement<'tcx>,
+    ) -> Option<std::vec::IntoIter<Statement<'tcx>>> {
+        let idx = self.statement_idx;
+        if let Some(len_statemnt_idx) = self.patches_scratchpad.get(&idx).copied() {
+            let mut statements = Vec::with_capacity(2);
+
+            // we are at statement that performs a cast. The only sound way is
+            // to create another local that performs a similar copy without a cast and then
+            // use this copy in the Len operation
+
+            match &statement.kind {
+                StatementKind::Assign(box (
+                    ..,
+                    Rvalue::Cast(
+                        CastKind::Pointer(ty::adjustment::PointerCast::Unsize),
+                        operand,
+                        _,
+                    ),
+                )) => {
+                    match operand {
+                        Operand::Copy(place) | Operand::Move(place) => {
+                            // create new local
+                            let ty = operand.ty(self.local_decls, self.tcx);
+                            let local_decl =
+                                LocalDecl::with_source_info(ty, statement.source_info.clone());
+                            let local = self.local_decls.push(local_decl);
+                            // make it live
+                            let mut make_live_statement = statement.clone();
+                            make_live_statement.kind = StatementKind::StorageLive(local);
+                            statements.push(make_live_statement);
+                            // copy into it
+
+                            let operand = Operand::Copy(*place);
+                            let mut make_copy_statement = statement.clone();
+                            let assign_to = Place::from(local);
+                            let rvalue = Rvalue::Use(operand);
+                            make_copy_statement.kind =
+                                StatementKind::Assign(box (assign_to, rvalue));
+                            statements.push(make_copy_statement);
+
+                            // to reorder we have to copy and make NOP
+                            statements.push(statement.clone());
+                            statement.make_nop();
+
+                            self.replacements_scratchpad.insert(len_statemnt_idx, local);
+                        }
+                        _ => {
+                            unreachable!("it's a bug in the implementation")
+                        }
+                    }
+                }
+                _ => {
+                    unreachable!("it's a bug in the implementation")
+                }
+            }
+
+            self.statement_idx += 1;
+
+            Some(statements.into_iter())
+        } else if let Some(local) = self.replacements_scratchpad.get(&idx).copied() {
+            let mut statements = Vec::with_capacity(2);
+
+            match &statement.kind {
+                StatementKind::Assign(box (into, Rvalue::Len(place))) => {
+                    let add_deref = if let Some(..) = place.as_local() {
+                        false
+                    } else if let Some(..) = place.local_or_deref_local() {
+                        true
+                    } else {
+                        unreachable!("it's a bug in the implementation")
+                    };
+                    // replace len statement
+                    let mut len_statement = statement.clone();
+                    let mut place = Place::from(local);
+                    if add_deref {
+                        place = self.tcx.mk_place_deref(place);
+                    }
+                    len_statement.kind = StatementKind::Assign(box (*into, Rvalue::Len(place)));
+                    statements.push(len_statement);
+
+                    // make temporary dead
+                    let mut make_dead_statement = statement.clone();
+                    make_dead_statement.kind = StatementKind::StorageDead(local);
+                    statements.push(make_dead_statement);
+
+                    // make original statement NOP
+                    statement.make_nop();
+                }
+                _ => {
+                    unreachable!("it's a bug in the implementation")
+                }
+            }
+
+            self.statement_idx += 1;
+
+            Some(statements.into_iter())
+        } else {
+            self.statement_idx += 1;
+            None
+        }
+    }
+}
+
+fn normalize_array_len_call<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    block: &mut BasicBlockData<'tcx>,
+    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
+    interesting_locals: &BitSet<Local>,
+    state: &mut FxIndexMap<Local, usize>,
+    patches_scratchpad: &mut FxIndexMap<usize, usize>,
+    replacements_scratchpad: &mut FxIndexMap<usize, Local>,
+) {
+    for (statement_idx, statement) in block.statements.iter_mut().enumerate() {
+        match &mut statement.kind {
+            StatementKind::Assign(box (place, rvalue)) => {
+                match rvalue {
+                    Rvalue::Cast(
+                        CastKind::Pointer(ty::adjustment::PointerCast::Unsize),
+                        operand,
+                        cast_ty,
+                    ) => {
+                        let local = if let Some(local) = place.as_local() { local } else { return };
+                        match operand {
+                            Operand::Copy(place) | Operand::Move(place) => {
+                                let operand_local =
+                                    if let Some(local) = place.local_or_deref_local() {
+                                        local
+                                    } else {
+                                        return;
+                                    };
+                                if !interesting_locals.contains(operand_local) {
+                                    return;
+                                }
+                                let operand_ty = local_decls[operand_local].ty;
+                                match (operand_ty.kind(), cast_ty.kind()) {
+                                    (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => {
+                                        if of_ty_src == of_ty_dst {
+                                            // this is a cast from [T; N] into [T], so we are good
+                                            state.insert(local, statement_idx);
+                                        }
+                                    }
+                                    // current way of patching doesn't allow to work with `mut`
+                                    (
+                                        ty::Ref(
+                                            ty::RegionKind::ReErased,
+                                            operand_ty,
+                                            Mutability::Not,
+                                        ),
+                                        ty::Ref(ty::RegionKind::ReErased, cast_ty, Mutability::Not),
+                                    ) => {
+                                        match (operand_ty.kind(), cast_ty.kind()) {
+                                            // current way of patching doesn't allow to work with `mut`
+                                            (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => {
+                                                if of_ty_src == of_ty_dst {
+                                                    // this is a cast from [T; N] into [T], so we are good
+                                                    state.insert(local, statement_idx);
+                                                }
+                                            }
+                                            _ => {}
+                                        }
+                                    }
+                                    _ => {}
+                                }
+                            }
+                            _ => {}
+                        }
+                    }
+                    Rvalue::Len(place) => {
+                        let local = if let Some(local) = place.local_or_deref_local() {
+                            local
+                        } else {
+                            return;
+                        };
+                        if let Some(cast_statement_idx) = state.get(&local).copied() {
+                            patches_scratchpad.insert(cast_statement_idx, statement_idx);
+                        }
+                    }
+                    _ => {
+                        // invalidate
+                        state.remove(&place.local);
+                    }
+                }
+            }
+            _ => {}
+        }
+    }
+
+    let mut patcher = Patcher {
+        tcx,
+        patches_scratchpad: &*patches_scratchpad,
+        replacements_scratchpad,
+        local_decls,
+        statement_idx: 0,
+    };
+
+    block.expand_statements(|st| patcher.patch_expand_statement(st));
+}
diff --git a/src/test/mir-opt/lower_array_len.rs b/src/test/mir-opt/lower_array_len.rs
new file mode 100644
index 00000000000..2b4845bc151
--- /dev/null
+++ b/src/test/mir-opt/lower_array_len.rs
@@ -0,0 +1,50 @@
+
+Viewed
+@@ -0,0 +1,47 @@
+// compile-flags: -Z mir-opt-level=3
+
+// EMIT_MIR lower_array_len.array_bound.NormalizeArrayLen.diff
+// EMIT_MIR lower_array_len.array_bound.SimplifyLocals.diff
+// EMIT_MIR lower_array_len.array_bound.InstCombine.diff
+pub fn array_bound<const N: usize>(index: usize, slice: &[u8; N]) -> u8 {
+    if index < slice.len() {
+        slice[index]
+    } else {
+        42
+    }
+}
+
+// EMIT_MIR lower_array_len.array_bound_mut.NormalizeArrayLen.diff
+// EMIT_MIR lower_array_len.array_bound_mut.SimplifyLocals.diff
+// EMIT_MIR lower_array_len.array_bound_mut.InstCombine.diff
+pub fn array_bound_mut<const N: usize>(index: usize, slice: &mut [u8; N]) -> u8 {
+    if index < slice.len() {
+        slice[index]
+    } else {
+        slice[0] = 42;
+
+        42
+    }
+}
+
+// EMIT_MIR lower_array_len.array_len.NormalizeArrayLen.diff
+// EMIT_MIR lower_array_len.array_len.SimplifyLocals.diff
+// EMIT_MIR lower_array_len.array_len.InstCombine.diff
+pub fn array_len<const N: usize>(arr: &[u8; N]) -> usize {
+    arr.len()
+}
+
+// EMIT_MIR lower_array_len.array_len_by_value.NormalizeArrayLen.diff
+// EMIT_MIR lower_array_len.array_len_by_value.SimplifyLocals.diff
+// EMIT_MIR lower_array_len.array_len_by_value.InstCombine.diff
+pub fn array_len_by_value<const N: usize>(arr: [u8; N]) -> usize {
+    arr.len()
+}
+
+fn main() {
+    let _ = array_bound(3, &[0, 1, 2, 3]);
+    let mut tmp = [0, 1, 2, 3, 4];
+    let _ = array_bound_mut(3, &mut [0, 1, 2, 3]);
+    let _ = array_len(&[0]);
+    let _ = array_len_by_value([0, 2]);
+}
\ No newline at end of file