diff options
| author | Camille GILLOT <gillot.camille@gmail.com> | 2023-12-31 01:53:51 +0000 |
|---|---|---|
| committer | Camille GILLOT <gillot.camille@gmail.com> | 2024-01-18 22:53:07 +0000 |
| commit | be9668d3989b721bfaa2ec517b307965476431fa (patch) | |
| tree | 0937caa224afeb543c2cfb529e7f1d6110896f59 | |
| parent | 25f8d01fd8bda339612d0c0a8844173a09205f7c (diff) | |
| download | rust-be9668d3989b721bfaa2ec517b307965476431fa.tar.gz rust-be9668d3989b721bfaa2ec517b307965476431fa.zip | |
Use an interpreter in jump threading.
4 files changed, 197 insertions, 27 deletions
diff --git a/compiler/rustc_mir_transform/src/jump_threading.rs b/compiler/rustc_mir_transform/src/jump_threading.rs index dcab124505e..fc77076e4e7 100644 --- a/compiler/rustc_mir_transform/src/jump_threading.rs +++ b/compiler/rustc_mir_transform/src/jump_threading.rs @@ -36,16 +36,21 @@ //! cost by `MAX_COST`. use rustc_arena::DroplessArena; +use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable}; use rustc_data_structures::fx::FxHashSet; use rustc_index::bit_set::BitSet; use rustc_index::IndexVec; +use rustc_middle::mir::interpret::Scalar; use rustc_middle::mir::visit::Visitor; use rustc_middle::mir::*; -use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt}; +use rustc_middle::ty::layout::LayoutOf; +use rustc_middle::ty::{self, ScalarInt, TyCtxt}; use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem}; +use rustc_span::DUMMY_SP; use rustc_target::abi::{TagEncoding, Variants}; use crate::cost_checker::CostChecker; +use crate::dataflow_const_prop::DummyMachine; pub struct JumpThreading; @@ -71,6 +76,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading { let mut finder = TOFinder { tcx, param_env, + ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine), body, arena: &arena, map: &map, @@ -88,7 +94,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading { debug!(?discr, ?bb); let discr_ty = discr.ty(body, tcx).ty; - let Ok(discr_layout) = tcx.layout_of(param_env.and(discr_ty)) else { continue }; + let Ok(discr_layout) = finder.ecx.layout_of(discr_ty) else { continue }; let Some(discr) = finder.map.find(discr.as_ref()) else { continue }; debug!(?discr); @@ -142,6 +148,7 @@ struct ThreadingOpportunity { struct TOFinder<'tcx, 'a> { tcx: TyCtxt<'tcx>, param_env: ty::ParamEnv<'tcx>, + ecx: InterpCx<'tcx, 'tcx, DummyMachine>, body: &'a Body<'tcx>, map: &'a Map, loop_headers: &'a BitSet<BasicBlock>, @@ -329,11 +336,11 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { } #[instrument(level = "trace", skip(self))] - fn process_operand( + fn process_immediate( &mut self, bb: BasicBlock, lhs: PlaceIndex, - rhs: &Operand<'tcx>, + rhs: ImmTy<'tcx>, state: &mut State<ConditionSet<'a>>, ) -> Option<!> { let register_opportunity = |c: Condition| { @@ -341,13 +348,60 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target }) }; + let conditions = state.try_get_idx(lhs, self.map)?; + if let Immediate::Scalar(Scalar::Int(int)) = *rhs { + conditions.iter_matches(int).for_each(register_opportunity); + } + + None + } + + #[instrument(level = "trace", skip(self))] + fn process_operand( + &mut self, + bb: BasicBlock, + lhs: PlaceIndex, + rhs: &Operand<'tcx>, + state: &mut State<ConditionSet<'a>>, + ) -> Option<!> { match rhs { // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`. Operand::Constant(constant) => { - let conditions = state.try_get_idx(lhs, self.map)?; - let constant = - constant.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?; - conditions.iter_matches(constant).for_each(register_opportunity); + let constant = self.ecx.eval_mir_constant(&constant.const_, None, None).ok()?; + self.map.for_each_projection_value( + lhs, + constant, + &mut |elem, op| match elem { + TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(), + TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(), + TrackElem::Discriminant => { + let variant = self.ecx.read_discriminant(op).ok()?; + let discr_value = + self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?; + Some(discr_value.into()) + } + TrackElem::DerefLen => { + let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into(); + let len_usize = op.len(&self.ecx).ok()?; + let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap(); + Some(ImmTy::from_uint(len_usize, layout).into()) + } + }, + &mut |place, op| { + if let Some(conditions) = state.try_get_idx(place, self.map) + && let Ok(imm) = self.ecx.read_immediate_raw(op) + && let Some(imm) = imm.right() + && let Immediate::Scalar(Scalar::Int(int)) = *imm + { + conditions.iter_matches(int).for_each(|c: Condition| { + self.opportunities.push(ThreadingOpportunity { + chain: vec![bb], + target: c.target, + }) + }) + } + }, + ); } // Transfer the conditions on the copied rhs. Operand::Move(rhs) | Operand::Copy(rhs) => { @@ -374,18 +428,6 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { // Below, `lhs` is the return value of `mutated_statement`, // the place to which `conditions` apply. - let discriminant_for_variant = |enum_ty: Ty<'tcx>, variant_index| { - let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?; - let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?; - let scalar = ScalarInt::try_from_uint(discr.val, discr_layout.size)?; - Some(Operand::const_from_scalar( - self.tcx, - discr.ty, - scalar.into(), - rustc_span::DUMMY_SP, - )) - }; - match &stmt.kind { // If we expect `discriminant(place) ?= A`, // we have an opportunity if `variant_index ?= A`. @@ -395,7 +437,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { // `SetDiscriminant` may be a no-op if the assigned variant is the untagged variant // of a niche encoding. If we cannot ensure that we write to the discriminant, do // nothing. - let enum_layout = self.tcx.layout_of(self.param_env.and(enum_ty)).ok()?; + let enum_layout = self.ecx.layout_of(enum_ty).ok()?; let writes_discriminant = match enum_layout.variants { Variants::Single { index } => { assert_eq!(index, *variant_index); @@ -408,8 +450,8 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { } => *variant_index != untagged_variant, }; if writes_discriminant { - let discr = discriminant_for_variant(enum_ty, *variant_index)?; - self.process_operand(bb, discr_target, &discr, state)?; + let discr = self.ecx.discriminant_for_variant(enum_ty, *variant_index).ok()?; + self.process_immediate(bb, discr_target, discr, state)?; } } // If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`. @@ -440,10 +482,16 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => { if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant) - && let Some(discr_value) = - discriminant_for_variant(agg_ty, *variant_index) + && let Ok(discr_value) = self + .ecx + .discriminant_for_variant(agg_ty, *variant_index) { - self.process_operand(bb, discr_target, &discr_value, state); + self.process_immediate( + bb, + discr_target, + discr_value, + state, + ); } self.map.apply(lhs, TrackElem::Variant(*variant_index))? } @@ -577,7 +625,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { let discr = discr.place()?; let discr_ty = discr.ty(self.body, self.tcx).ty; - let discr_layout = self.tcx.layout_of(self.param_env.and(discr_ty)).ok()?; + let discr_layout = self.ecx.layout_of(discr_ty).ok()?; let conditions = state.try_get(discr.as_ref(), self.map)?; if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) { diff --git a/tests/mir-opt/jump_threading.aggregate.JumpThreading.panic-abort.diff b/tests/mir-opt/jump_threading.aggregate.JumpThreading.panic-abort.diff new file mode 100644 index 00000000000..66aa892c6f3 --- /dev/null +++ b/tests/mir-opt/jump_threading.aggregate.JumpThreading.panic-abort.diff @@ -0,0 +1,52 @@ +- // MIR for `aggregate` before JumpThreading ++ // MIR for `aggregate` after JumpThreading + + fn aggregate(_1: u8) -> u8 { + debug x => _1; + let mut _0: u8; + let _2: u8; + let _3: u8; + let mut _4: (u8, u8); + let mut _5: bool; + let mut _6: u8; + scope 1 { + debug a => _2; + debug b => _3; + } + + bb0: { + StorageLive(_4); + _4 = const _; + StorageLive(_2); + _2 = (_4.0: u8); + StorageLive(_3); + _3 = (_4.1: u8); + StorageDead(_4); + StorageLive(_5); + StorageLive(_6); + _6 = _2; + _5 = Eq(move _6, const 7_u8); +- switchInt(move _5) -> [0: bb2, otherwise: bb1]; ++ goto -> bb2; + } + + bb1: { + StorageDead(_6); + _0 = _3; + goto -> bb3; + } + + bb2: { + StorageDead(_6); + _0 = _2; + goto -> bb3; + } + + bb3: { + StorageDead(_5); + StorageDead(_3); + StorageDead(_2); + return; + } + } + diff --git a/tests/mir-opt/jump_threading.aggregate.JumpThreading.panic-unwind.diff b/tests/mir-opt/jump_threading.aggregate.JumpThreading.panic-unwind.diff new file mode 100644 index 00000000000..66aa892c6f3 --- /dev/null +++ b/tests/mir-opt/jump_threading.aggregate.JumpThreading.panic-unwind.diff @@ -0,0 +1,52 @@ +- // MIR for `aggregate` before JumpThreading ++ // MIR for `aggregate` after JumpThreading + + fn aggregate(_1: u8) -> u8 { + debug x => _1; + let mut _0: u8; + let _2: u8; + let _3: u8; + let mut _4: (u8, u8); + let mut _5: bool; + let mut _6: u8; + scope 1 { + debug a => _2; + debug b => _3; + } + + bb0: { + StorageLive(_4); + _4 = const _; + StorageLive(_2); + _2 = (_4.0: u8); + StorageLive(_3); + _3 = (_4.1: u8); + StorageDead(_4); + StorageLive(_5); + StorageLive(_6); + _6 = _2; + _5 = Eq(move _6, const 7_u8); +- switchInt(move _5) -> [0: bb2, otherwise: bb1]; ++ goto -> bb2; + } + + bb1: { + StorageDead(_6); + _0 = _3; + goto -> bb3; + } + + bb2: { + StorageDead(_6); + _0 = _2; + goto -> bb3; + } + + bb3: { + StorageDead(_5); + StorageDead(_3); + StorageDead(_2); + return; + } + } + diff --git a/tests/mir-opt/jump_threading.rs b/tests/mir-opt/jump_threading.rs index 0cbdaa085bc..7c2fa42828b 100644 --- a/tests/mir-opt/jump_threading.rs +++ b/tests/mir-opt/jump_threading.rs @@ -453,7 +453,23 @@ fn disappearing_bb(x: u8) -> u8 { ) } +/// Verify that we can thread jumps when we assign from an aggregate constant. +fn aggregate(x: u8) -> u8 { + // CHECK-LABEL: fn aggregate( + // CHECK-NOT: switchInt( + + const FOO: (u8, u8) = (5, 13); + + let (a, b) = FOO; + if a == 7 { + b + } else { + a + } +} + fn main() { + // CHECK-LABEL: fn main( too_complex(Ok(0)); identity(Ok(0)); custom_discr(false); @@ -464,6 +480,7 @@ fn main() { mutable_ref(); renumbered_bb(true); disappearing_bb(7); + aggregate(7); } // EMIT_MIR jump_threading.too_complex.JumpThreading.diff @@ -476,3 +493,4 @@ fn main() { // EMIT_MIR jump_threading.mutable_ref.JumpThreading.diff // EMIT_MIR jump_threading.renumbered_bb.JumpThreading.diff // EMIT_MIR jump_threading.disappearing_bb.JumpThreading.diff +// EMIT_MIR jump_threading.aggregate.JumpThreading.diff |
