about summary refs log tree commit diff
diff options
context:
space:
mode:
authorCamille GILLOT <gillot.camille@gmail.com>2023-12-31 01:53:51 +0000
committerCamille GILLOT <gillot.camille@gmail.com>2024-01-18 22:53:07 +0000
commitbe9668d3989b721bfaa2ec517b307965476431fa (patch)
tree0937caa224afeb543c2cfb529e7f1d6110896f59
parent25f8d01fd8bda339612d0c0a8844173a09205f7c (diff)
downloadrust-be9668d3989b721bfaa2ec517b307965476431fa.tar.gz
rust-be9668d3989b721bfaa2ec517b307965476431fa.zip
Use an interpreter in jump threading.
-rw-r--r--compiler/rustc_mir_transform/src/jump_threading.rs102
-rw-r--r--tests/mir-opt/jump_threading.aggregate.JumpThreading.panic-abort.diff52
-rw-r--r--tests/mir-opt/jump_threading.aggregate.JumpThreading.panic-unwind.diff52
-rw-r--r--tests/mir-opt/jump_threading.rs18
4 files changed, 197 insertions, 27 deletions
diff --git a/compiler/rustc_mir_transform/src/jump_threading.rs b/compiler/rustc_mir_transform/src/jump_threading.rs
index dcab124505e..fc77076e4e7 100644
--- a/compiler/rustc_mir_transform/src/jump_threading.rs
+++ b/compiler/rustc_mir_transform/src/jump_threading.rs
@@ -36,16 +36,21 @@
 //! cost by `MAX_COST`.
 
 use rustc_arena::DroplessArena;
+use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
 use rustc_data_structures::fx::FxHashSet;
 use rustc_index::bit_set::BitSet;
 use rustc_index::IndexVec;
+use rustc_middle::mir::interpret::Scalar;
 use rustc_middle::mir::visit::Visitor;
 use rustc_middle::mir::*;
-use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
+use rustc_middle::ty::layout::LayoutOf;
+use rustc_middle::ty::{self, ScalarInt, TyCtxt};
 use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
+use rustc_span::DUMMY_SP;
 use rustc_target::abi::{TagEncoding, Variants};
 
 use crate::cost_checker::CostChecker;
+use crate::dataflow_const_prop::DummyMachine;
 
 pub struct JumpThreading;
 
@@ -71,6 +76,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
         let mut finder = TOFinder {
             tcx,
             param_env,
+            ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
             body,
             arena: &arena,
             map: &map,
@@ -88,7 +94,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
             debug!(?discr, ?bb);
 
             let discr_ty = discr.ty(body, tcx).ty;
-            let Ok(discr_layout) = tcx.layout_of(param_env.and(discr_ty)) else { continue };
+            let Ok(discr_layout) = finder.ecx.layout_of(discr_ty) else { continue };
 
             let Some(discr) = finder.map.find(discr.as_ref()) else { continue };
             debug!(?discr);
@@ -142,6 +148,7 @@ struct ThreadingOpportunity {
 struct TOFinder<'tcx, 'a> {
     tcx: TyCtxt<'tcx>,
     param_env: ty::ParamEnv<'tcx>,
+    ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
     body: &'a Body<'tcx>,
     map: &'a Map,
     loop_headers: &'a BitSet<BasicBlock>,
@@ -329,11 +336,11 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
     }
 
     #[instrument(level = "trace", skip(self))]
-    fn process_operand(
+    fn process_immediate(
         &mut self,
         bb: BasicBlock,
         lhs: PlaceIndex,
-        rhs: &Operand<'tcx>,
+        rhs: ImmTy<'tcx>,
         state: &mut State<ConditionSet<'a>>,
     ) -> Option<!> {
         let register_opportunity = |c: Condition| {
@@ -341,13 +348,60 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
             self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
         };
 
+        let conditions = state.try_get_idx(lhs, self.map)?;
+        if let Immediate::Scalar(Scalar::Int(int)) = *rhs {
+            conditions.iter_matches(int).for_each(register_opportunity);
+        }
+
+        None
+    }
+
+    #[instrument(level = "trace", skip(self))]
+    fn process_operand(
+        &mut self,
+        bb: BasicBlock,
+        lhs: PlaceIndex,
+        rhs: &Operand<'tcx>,
+        state: &mut State<ConditionSet<'a>>,
+    ) -> Option<!> {
         match rhs {
             // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
             Operand::Constant(constant) => {
-                let conditions = state.try_get_idx(lhs, self.map)?;
-                let constant =
-                    constant.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
-                conditions.iter_matches(constant).for_each(register_opportunity);
+                let constant = self.ecx.eval_mir_constant(&constant.const_, None, None).ok()?;
+                self.map.for_each_projection_value(
+                    lhs,
+                    constant,
+                    &mut |elem, op| match elem {
+                        TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(),
+                        TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(),
+                        TrackElem::Discriminant => {
+                            let variant = self.ecx.read_discriminant(op).ok()?;
+                            let discr_value =
+                                self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?;
+                            Some(discr_value.into())
+                        }
+                        TrackElem::DerefLen => {
+                            let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into();
+                            let len_usize = op.len(&self.ecx).ok()?;
+                            let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
+                            Some(ImmTy::from_uint(len_usize, layout).into())
+                        }
+                    },
+                    &mut |place, op| {
+                        if let Some(conditions) = state.try_get_idx(place, self.map)
+                            && let Ok(imm) = self.ecx.read_immediate_raw(op)
+                            && let Some(imm) = imm.right()
+                            && let Immediate::Scalar(Scalar::Int(int)) = *imm
+                        {
+                            conditions.iter_matches(int).for_each(|c: Condition| {
+                                self.opportunities.push(ThreadingOpportunity {
+                                    chain: vec![bb],
+                                    target: c.target,
+                                })
+                            })
+                        }
+                    },
+                );
             }
             // Transfer the conditions on the copied rhs.
             Operand::Move(rhs) | Operand::Copy(rhs) => {
@@ -374,18 +428,6 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
         // Below, `lhs` is the return value of `mutated_statement`,
         // the place to which `conditions` apply.
 
-        let discriminant_for_variant = |enum_ty: Ty<'tcx>, variant_index| {
-            let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
-            let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
-            let scalar = ScalarInt::try_from_uint(discr.val, discr_layout.size)?;
-            Some(Operand::const_from_scalar(
-                self.tcx,
-                discr.ty,
-                scalar.into(),
-                rustc_span::DUMMY_SP,
-            ))
-        };
-
         match &stmt.kind {
             // If we expect `discriminant(place) ?= A`,
             // we have an opportunity if `variant_index ?= A`.
@@ -395,7 +437,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
                 // `SetDiscriminant` may be a no-op if the assigned variant is the untagged variant
                 // of a niche encoding. If we cannot ensure that we write to the discriminant, do
                 // nothing.
-                let enum_layout = self.tcx.layout_of(self.param_env.and(enum_ty)).ok()?;
+                let enum_layout = self.ecx.layout_of(enum_ty).ok()?;
                 let writes_discriminant = match enum_layout.variants {
                     Variants::Single { index } => {
                         assert_eq!(index, *variant_index);
@@ -408,8 +450,8 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
                     } => *variant_index != untagged_variant,
                 };
                 if writes_discriminant {
-                    let discr = discriminant_for_variant(enum_ty, *variant_index)?;
-                    self.process_operand(bb, discr_target, &discr, state)?;
+                    let discr = self.ecx.discriminant_for_variant(enum_ty, *variant_index).ok()?;
+                    self.process_immediate(bb, discr_target, discr, state)?;
                 }
             }
             // If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
@@ -440,10 +482,16 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
                                 AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
                                     if let Some(discr_target) =
                                         self.map.apply(lhs, TrackElem::Discriminant)
-                                        && let Some(discr_value) =
-                                            discriminant_for_variant(agg_ty, *variant_index)
+                                        && let Ok(discr_value) = self
+                                            .ecx
+                                            .discriminant_for_variant(agg_ty, *variant_index)
                                     {
-                                        self.process_operand(bb, discr_target, &discr_value, state);
+                                        self.process_immediate(
+                                            bb,
+                                            discr_target,
+                                            discr_value,
+                                            state,
+                                        );
                                     }
                                     self.map.apply(lhs, TrackElem::Variant(*variant_index))?
                                 }
@@ -577,7 +625,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
 
         let discr = discr.place()?;
         let discr_ty = discr.ty(self.body, self.tcx).ty;
-        let discr_layout = self.tcx.layout_of(self.param_env.and(discr_ty)).ok()?;
+        let discr_layout = self.ecx.layout_of(discr_ty).ok()?;
         let conditions = state.try_get(discr.as_ref(), self.map)?;
 
         if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) {
diff --git a/tests/mir-opt/jump_threading.aggregate.JumpThreading.panic-abort.diff b/tests/mir-opt/jump_threading.aggregate.JumpThreading.panic-abort.diff
new file mode 100644
index 00000000000..66aa892c6f3
--- /dev/null
+++ b/tests/mir-opt/jump_threading.aggregate.JumpThreading.panic-abort.diff
@@ -0,0 +1,52 @@
+- // MIR for `aggregate` before JumpThreading
++ // MIR for `aggregate` after JumpThreading
+  
+  fn aggregate(_1: u8) -> u8 {
+      debug x => _1;
+      let mut _0: u8;
+      let _2: u8;
+      let _3: u8;
+      let mut _4: (u8, u8);
+      let mut _5: bool;
+      let mut _6: u8;
+      scope 1 {
+          debug a => _2;
+          debug b => _3;
+      }
+  
+      bb0: {
+          StorageLive(_4);
+          _4 = const _;
+          StorageLive(_2);
+          _2 = (_4.0: u8);
+          StorageLive(_3);
+          _3 = (_4.1: u8);
+          StorageDead(_4);
+          StorageLive(_5);
+          StorageLive(_6);
+          _6 = _2;
+          _5 = Eq(move _6, const 7_u8);
+-         switchInt(move _5) -> [0: bb2, otherwise: bb1];
++         goto -> bb2;
+      }
+  
+      bb1: {
+          StorageDead(_6);
+          _0 = _3;
+          goto -> bb3;
+      }
+  
+      bb2: {
+          StorageDead(_6);
+          _0 = _2;
+          goto -> bb3;
+      }
+  
+      bb3: {
+          StorageDead(_5);
+          StorageDead(_3);
+          StorageDead(_2);
+          return;
+      }
+  }
+  
diff --git a/tests/mir-opt/jump_threading.aggregate.JumpThreading.panic-unwind.diff b/tests/mir-opt/jump_threading.aggregate.JumpThreading.panic-unwind.diff
new file mode 100644
index 00000000000..66aa892c6f3
--- /dev/null
+++ b/tests/mir-opt/jump_threading.aggregate.JumpThreading.panic-unwind.diff
@@ -0,0 +1,52 @@
+- // MIR for `aggregate` before JumpThreading
++ // MIR for `aggregate` after JumpThreading
+  
+  fn aggregate(_1: u8) -> u8 {
+      debug x => _1;
+      let mut _0: u8;
+      let _2: u8;
+      let _3: u8;
+      let mut _4: (u8, u8);
+      let mut _5: bool;
+      let mut _6: u8;
+      scope 1 {
+          debug a => _2;
+          debug b => _3;
+      }
+  
+      bb0: {
+          StorageLive(_4);
+          _4 = const _;
+          StorageLive(_2);
+          _2 = (_4.0: u8);
+          StorageLive(_3);
+          _3 = (_4.1: u8);
+          StorageDead(_4);
+          StorageLive(_5);
+          StorageLive(_6);
+          _6 = _2;
+          _5 = Eq(move _6, const 7_u8);
+-         switchInt(move _5) -> [0: bb2, otherwise: bb1];
++         goto -> bb2;
+      }
+  
+      bb1: {
+          StorageDead(_6);
+          _0 = _3;
+          goto -> bb3;
+      }
+  
+      bb2: {
+          StorageDead(_6);
+          _0 = _2;
+          goto -> bb3;
+      }
+  
+      bb3: {
+          StorageDead(_5);
+          StorageDead(_3);
+          StorageDead(_2);
+          return;
+      }
+  }
+  
diff --git a/tests/mir-opt/jump_threading.rs b/tests/mir-opt/jump_threading.rs
index 0cbdaa085bc..7c2fa42828b 100644
--- a/tests/mir-opt/jump_threading.rs
+++ b/tests/mir-opt/jump_threading.rs
@@ -453,7 +453,23 @@ fn disappearing_bb(x: u8) -> u8 {
     )
 }
 
+/// Verify that we can thread jumps when we assign from an aggregate constant.
+fn aggregate(x: u8) -> u8 {
+    // CHECK-LABEL: fn aggregate(
+    // CHECK-NOT: switchInt(
+
+    const FOO: (u8, u8) = (5, 13);
+
+    let (a, b) = FOO;
+    if a == 7 {
+        b
+    } else {
+        a
+    }
+}
+
 fn main() {
+    // CHECK-LABEL: fn main(
     too_complex(Ok(0));
     identity(Ok(0));
     custom_discr(false);
@@ -464,6 +480,7 @@ fn main() {
     mutable_ref();
     renumbered_bb(true);
     disappearing_bb(7);
+    aggregate(7);
 }
 
 // EMIT_MIR jump_threading.too_complex.JumpThreading.diff
@@ -476,3 +493,4 @@ fn main() {
 // EMIT_MIR jump_threading.mutable_ref.JumpThreading.diff
 // EMIT_MIR jump_threading.renumbered_bb.JumpThreading.diff
 // EMIT_MIR jump_threading.disappearing_bb.JumpThreading.diff
+// EMIT_MIR jump_threading.aggregate.JumpThreading.diff