about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src/jump_threading.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src/jump_threading.rs')
-rw-r--r--compiler/rustc_mir_transform/src/jump_threading.rs262
1 files changed, 155 insertions, 107 deletions
diff --git a/compiler/rustc_mir_transform/src/jump_threading.rs b/compiler/rustc_mir_transform/src/jump_threading.rs
index dcab124505e..e87f68a0905 100644
--- a/compiler/rustc_mir_transform/src/jump_threading.rs
+++ b/compiler/rustc_mir_transform/src/jump_threading.rs
@@ -36,16 +36,21 @@
 //! cost by `MAX_COST`.
 
 use rustc_arena::DroplessArena;
+use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
 use rustc_data_structures::fx::FxHashSet;
 use rustc_index::bit_set::BitSet;
 use rustc_index::IndexVec;
+use rustc_middle::mir::interpret::Scalar;
 use rustc_middle::mir::visit::Visitor;
 use rustc_middle::mir::*;
-use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
+use rustc_middle::ty::layout::LayoutOf;
+use rustc_middle::ty::{self, ScalarInt, TyCtxt};
 use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
+use rustc_span::DUMMY_SP;
 use rustc_target::abi::{TagEncoding, Variants};
 
 use crate::cost_checker::CostChecker;
+use crate::dataflow_const_prop::DummyMachine;
 
 pub struct JumpThreading;
 
@@ -71,6 +76,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
         let mut finder = TOFinder {
             tcx,
             param_env,
+            ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
             body,
             arena: &arena,
             map: &map,
@@ -88,7 +94,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
             debug!(?discr, ?bb);
 
             let discr_ty = discr.ty(body, tcx).ty;
-            let Ok(discr_layout) = tcx.layout_of(param_env.and(discr_ty)) else { continue };
+            let Ok(discr_layout) = finder.ecx.layout_of(discr_ty) else { continue };
 
             let Some(discr) = finder.map.find(discr.as_ref()) else { continue };
             debug!(?discr);
@@ -142,6 +148,7 @@ struct ThreadingOpportunity {
 struct TOFinder<'tcx, 'a> {
     tcx: TyCtxt<'tcx>,
     param_env: ty::ParamEnv<'tcx>,
+    ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
     body: &'a Body<'tcx>,
     map: &'a Map,
     loop_headers: &'a BitSet<BasicBlock>,
@@ -329,11 +336,11 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
     }
 
     #[instrument(level = "trace", skip(self))]
-    fn process_operand(
+    fn process_immediate(
         &mut self,
         bb: BasicBlock,
         lhs: PlaceIndex,
-        rhs: &Operand<'tcx>,
+        rhs: ImmTy<'tcx>,
         state: &mut State<ConditionSet<'a>>,
     ) -> Option<!> {
         let register_opportunity = |c: Condition| {
@@ -341,13 +348,70 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
             self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
         };
 
+        let conditions = state.try_get_idx(lhs, self.map)?;
+        if let Immediate::Scalar(Scalar::Int(int)) = *rhs {
+            conditions.iter_matches(int).for_each(register_opportunity);
+        }
+
+        None
+    }
+
+    /// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
+    #[instrument(level = "trace", skip(self))]
+    fn process_constant(
+        &mut self,
+        bb: BasicBlock,
+        lhs: PlaceIndex,
+        constant: OpTy<'tcx>,
+        state: &mut State<ConditionSet<'a>>,
+    ) {
+        self.map.for_each_projection_value(
+            lhs,
+            constant,
+            &mut |elem, op| match elem {
+                TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(),
+                TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(),
+                TrackElem::Discriminant => {
+                    let variant = self.ecx.read_discriminant(op).ok()?;
+                    let discr_value =
+                        self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?;
+                    Some(discr_value.into())
+                }
+                TrackElem::DerefLen => {
+                    let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into();
+                    let len_usize = op.len(&self.ecx).ok()?;
+                    let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
+                    Some(ImmTy::from_uint(len_usize, layout).into())
+                }
+            },
+            &mut |place, op| {
+                if let Some(conditions) = state.try_get_idx(place, self.map)
+                    && let Ok(imm) = self.ecx.read_immediate_raw(op)
+                    && let Some(imm) = imm.right()
+                    && let Immediate::Scalar(Scalar::Int(int)) = *imm
+                {
+                    conditions.iter_matches(int).for_each(|c: Condition| {
+                        self.opportunities
+                            .push(ThreadingOpportunity { chain: vec![bb], target: c.target })
+                    })
+                }
+            },
+        );
+    }
+
+    #[instrument(level = "trace", skip(self))]
+    fn process_operand(
+        &mut self,
+        bb: BasicBlock,
+        lhs: PlaceIndex,
+        rhs: &Operand<'tcx>,
+        state: &mut State<ConditionSet<'a>>,
+    ) -> Option<!> {
         match rhs {
             // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
             Operand::Constant(constant) => {
-                let conditions = state.try_get_idx(lhs, self.map)?;
-                let constant =
-                    constant.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
-                conditions.iter_matches(constant).for_each(register_opportunity);
+                let constant = self.ecx.eval_mir_constant(&constant.const_, None, None).ok()?;
+                self.process_constant(bb, lhs, constant, state);
             }
             // Transfer the conditions on the copied rhs.
             Operand::Move(rhs) | Operand::Copy(rhs) => {
@@ -360,6 +424,84 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
     }
 
     #[instrument(level = "trace", skip(self))]
+    fn process_assign(
+        &mut self,
+        bb: BasicBlock,
+        lhs_place: &Place<'tcx>,
+        rhs: &Rvalue<'tcx>,
+        state: &mut State<ConditionSet<'a>>,
+    ) -> Option<!> {
+        let lhs = self.map.find(lhs_place.as_ref())?;
+        match rhs {
+            Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?,
+            // Transfer the conditions on the copy rhs.
+            Rvalue::CopyForDeref(rhs) => {
+                self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)?
+            }
+            Rvalue::Discriminant(rhs) => {
+                let rhs = self.map.find_discr(rhs.as_ref())?;
+                state.insert_place_idx(rhs, lhs, self.map);
+            }
+            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
+            Rvalue::Aggregate(box ref kind, ref operands) => {
+                let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
+                let lhs = match kind {
+                    // Do not support unions.
+                    AggregateKind::Adt(.., Some(_)) => return None,
+                    AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
+                        if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant)
+                            && let Ok(discr_value) =
+                                self.ecx.discriminant_for_variant(agg_ty, *variant_index)
+                        {
+                            self.process_immediate(bb, discr_target, discr_value, state);
+                        }
+                        self.map.apply(lhs, TrackElem::Variant(*variant_index))?
+                    }
+                    _ => lhs,
+                };
+                for (field_index, operand) in operands.iter_enumerated() {
+                    if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) {
+                        self.process_operand(bb, field, operand, state);
+                    }
+                }
+            }
+            // Transfer the conditions on the copy rhs, after inversing polarity.
+            Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
+                let conditions = state.try_get_idx(lhs, self.map)?;
+                let place = self.map.find(place.as_ref())?;
+                let conds = conditions.map(self.arena, Condition::inv);
+                state.insert_value_idx(place, conds, self.map);
+            }
+            // We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
+            // Create a condition on `rhs ?= B`.
+            Rvalue::BinaryOp(
+                op,
+                box (Operand::Move(place) | Operand::Copy(place), Operand::Constant(value))
+                | box (Operand::Constant(value), Operand::Move(place) | Operand::Copy(place)),
+            ) => {
+                let conditions = state.try_get_idx(lhs, self.map)?;
+                let place = self.map.find(place.as_ref())?;
+                let equals = match op {
+                    BinOp::Eq => ScalarInt::TRUE,
+                    BinOp::Ne => ScalarInt::FALSE,
+                    _ => return None,
+                };
+                let value = value.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
+                let conds = conditions.map(self.arena, |c| Condition {
+                    value,
+                    polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne },
+                    ..c
+                });
+                state.insert_value_idx(place, conds, self.map);
+            }
+
+            _ => {}
+        }
+
+        None
+    }
+
+    #[instrument(level = "trace", skip(self))]
     fn process_statement(
         &mut self,
         bb: BasicBlock,
@@ -374,18 +516,6 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
         // Below, `lhs` is the return value of `mutated_statement`,
         // the place to which `conditions` apply.
 
-        let discriminant_for_variant = |enum_ty: Ty<'tcx>, variant_index| {
-            let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
-            let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
-            let scalar = ScalarInt::try_from_uint(discr.val, discr_layout.size)?;
-            Some(Operand::const_from_scalar(
-                self.tcx,
-                discr.ty,
-                scalar.into(),
-                rustc_span::DUMMY_SP,
-            ))
-        };
-
         match &stmt.kind {
             // If we expect `discriminant(place) ?= A`,
             // we have an opportunity if `variant_index ?= A`.
@@ -395,7 +525,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
                 // `SetDiscriminant` may be a no-op if the assigned variant is the untagged variant
                 // of a niche encoding. If we cannot ensure that we write to the discriminant, do
                 // nothing.
-                let enum_layout = self.tcx.layout_of(self.param_env.and(enum_ty)).ok()?;
+                let enum_layout = self.ecx.layout_of(enum_ty).ok()?;
                 let writes_discriminant = match enum_layout.variants {
                     Variants::Single { index } => {
                         assert_eq!(index, *variant_index);
@@ -408,8 +538,8 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
                     } => *variant_index != untagged_variant,
                 };
                 if writes_discriminant {
-                    let discr = discriminant_for_variant(enum_ty, *variant_index)?;
-                    self.process_operand(bb, discr_target, &discr, state)?;
+                    let discr = self.ecx.discriminant_for_variant(enum_ty, *variant_index).ok()?;
+                    self.process_immediate(bb, discr_target, discr, state)?;
                 }
             }
             // If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
@@ -420,89 +550,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
                 conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity);
             }
             StatementKind::Assign(box (lhs_place, rhs)) => {
-                if let Some(lhs) = self.map.find(lhs_place.as_ref()) {
-                    match rhs {
-                        Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?,
-                        // Transfer the conditions on the copy rhs.
-                        Rvalue::CopyForDeref(rhs) => {
-                            self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)?
-                        }
-                        Rvalue::Discriminant(rhs) => {
-                            let rhs = self.map.find_discr(rhs.as_ref())?;
-                            state.insert_place_idx(rhs, lhs, self.map);
-                        }
-                        // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
-                        Rvalue::Aggregate(box ref kind, ref operands) => {
-                            let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
-                            let lhs = match kind {
-                                // Do not support unions.
-                                AggregateKind::Adt(.., Some(_)) => return None,
-                                AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
-                                    if let Some(discr_target) =
-                                        self.map.apply(lhs, TrackElem::Discriminant)
-                                        && let Some(discr_value) =
-                                            discriminant_for_variant(agg_ty, *variant_index)
-                                    {
-                                        self.process_operand(bb, discr_target, &discr_value, state);
-                                    }
-                                    self.map.apply(lhs, TrackElem::Variant(*variant_index))?
-                                }
-                                _ => lhs,
-                            };
-                            for (field_index, operand) in operands.iter_enumerated() {
-                                if let Some(field) =
-                                    self.map.apply(lhs, TrackElem::Field(field_index))
-                                {
-                                    self.process_operand(bb, field, operand, state);
-                                }
-                            }
-                        }
-                        // Transfer the conditions on the copy rhs, after inversing polarity.
-                        Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
-                            let conditions = state.try_get_idx(lhs, self.map)?;
-                            let place = self.map.find(place.as_ref())?;
-                            let conds = conditions.map(self.arena, Condition::inv);
-                            state.insert_value_idx(place, conds, self.map);
-                        }
-                        // We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
-                        // Create a condition on `rhs ?= B`.
-                        Rvalue::BinaryOp(
-                            op,
-                            box (
-                                Operand::Move(place) | Operand::Copy(place),
-                                Operand::Constant(value),
-                            )
-                            | box (
-                                Operand::Constant(value),
-                                Operand::Move(place) | Operand::Copy(place),
-                            ),
-                        ) => {
-                            let conditions = state.try_get_idx(lhs, self.map)?;
-                            let place = self.map.find(place.as_ref())?;
-                            let equals = match op {
-                                BinOp::Eq => ScalarInt::TRUE,
-                                BinOp::Ne => ScalarInt::FALSE,
-                                _ => return None,
-                            };
-                            let value = value
-                                .const_
-                                .normalize(self.tcx, self.param_env)
-                                .try_to_scalar_int()?;
-                            let conds = conditions.map(self.arena, |c| Condition {
-                                value,
-                                polarity: if c.matches(equals) {
-                                    Polarity::Eq
-                                } else {
-                                    Polarity::Ne
-                                },
-                                ..c
-                            });
-                            state.insert_value_idx(place, conds, self.map);
-                        }
-
-                        _ => {}
-                    }
-                }
+                self.process_assign(bb, lhs_place, rhs, state)?;
             }
             _ => {}
         }
@@ -577,7 +625,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
 
         let discr = discr.place()?;
         let discr_ty = discr.ty(self.body, self.tcx).ty;
-        let discr_layout = self.tcx.layout_of(self.param_env.and(discr_ty)).ok()?;
+        let discr_layout = self.ecx.layout_of(discr_ty).ok()?;
         let conditions = state.try_get(discr.as_ref(), self.map)?;
 
         if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) {