about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
authorCamille GILLOT <gillot.camille@gmail.com>2023-03-20 21:37:36 +0000
committerCamille GILLOT <gillot.camille@gmail.com>2024-01-16 22:20:53 +0000
commit666030c51b87509a065867f99bf7317a5486470a (patch)
treedae7345129d87e308dde99c027cbce19e3e74715 /compiler/rustc_mir_transform/src
parent92f2e0aa62113a5f31076a9414daca55722556cf (diff)
downloadrust-666030c51b87509a065867f99bf7317a5486470a.tar.gz
rust-666030c51b87509a065867f99bf7317a5486470a.zip
Simplify binary ops.
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/gvn.rs111
1 files changed, 109 insertions, 2 deletions
diff --git a/compiler/rustc_mir_transform/src/gvn.rs b/compiler/rustc_mir_transform/src/gvn.rs
index dc3af038d80..3052369c3ca 100644
--- a/compiler/rustc_mir_transform/src/gvn.rs
+++ b/compiler/rustc_mir_transform/src/gvn.rs
@@ -345,11 +345,20 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
         Some(self.insert(Value::Constant { value, disambiguator }))
     }
 
+    fn insert_bool(&mut self, flag: bool) -> VnIndex {
+        // Booleans are deterministic.
+        self.insert(Value::Constant { value: Const::from_bool(self.tcx, flag), disambiguator: 0 })
+    }
+
     fn insert_scalar(&mut self, scalar: Scalar, ty: Ty<'tcx>) -> VnIndex {
         self.insert_constant(Const::from_scalar(self.tcx, scalar, ty))
             .expect("scalars are deterministic")
     }
 
+    fn insert_tuple(&mut self, values: Vec<VnIndex>) -> VnIndex {
+        self.insert(Value::Aggregate(AggregateTy::Tuple, VariantIdx::from_u32(0), values))
+    }
+
     #[instrument(level = "trace", skip(self), ret)]
     fn eval_to_const(&mut self, value: VnIndex) -> Option<OpTy<'tcx>> {
         use Value::*;
@@ -785,14 +794,26 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
                 Value::Cast { kind, value, from, to }
             }
             Rvalue::BinaryOp(op, box (ref mut lhs, ref mut rhs)) => {
+                let ty = lhs.ty(self.local_decls, self.tcx);
                 let lhs = self.simplify_operand(lhs, location);
                 let rhs = self.simplify_operand(rhs, location);
-                Value::BinaryOp(op, lhs?, rhs?)
+                let lhs = lhs?;
+                let rhs = rhs?;
+                if let Some(value) = self.simplify_binary(op, false, ty, lhs, rhs) {
+                    return Some(value);
+                }
+                Value::BinaryOp(op, lhs, rhs)
             }
             Rvalue::CheckedBinaryOp(op, box (ref mut lhs, ref mut rhs)) => {
+                let ty = lhs.ty(self.local_decls, self.tcx);
                 let lhs = self.simplify_operand(lhs, location);
                 let rhs = self.simplify_operand(rhs, location);
-                Value::CheckedBinaryOp(op, lhs?, rhs?)
+                let lhs = lhs?;
+                let rhs = rhs?;
+                if let Some(value) = self.simplify_binary(op, true, ty, lhs, rhs) {
+                    return Some(value);
+                }
+                Value::CheckedBinaryOp(op, lhs, rhs)
             }
             Rvalue::UnaryOp(op, ref mut arg) => {
                 let arg = self.simplify_operand(arg, location)?;
@@ -894,6 +915,92 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
 
         Some(self.insert(Value::Aggregate(ty, variant_index, fields)))
     }
+
+    #[instrument(level = "trace", skip(self), ret)]
+    fn simplify_binary(
+        &mut self,
+        op: BinOp,
+        checked: bool,
+        lhs_ty: Ty<'tcx>,
+        lhs: VnIndex,
+        rhs: VnIndex,
+    ) -> Option<VnIndex> {
+        // Floats are weird enough that none of the logic below applies.
+        let reasonable_ty =
+            lhs_ty.is_integral() || lhs_ty.is_bool() || lhs_ty.is_char() || lhs_ty.is_any_ptr();
+        if !reasonable_ty {
+            return None;
+        }
+
+        let layout = self.ecx.layout_of(lhs_ty).ok()?;
+
+        let as_bits = |value| {
+            let constant = self.evaluated[value].as_ref()?;
+            let scalar = self.ecx.read_scalar(constant).ok()?;
+            scalar.to_bits(constant.layout.size).ok()
+        };
+
+        // Represent the values as `Ok(bits)` or `Err(VnIndex)`.
+        let a = as_bits(lhs).ok_or(lhs);
+        let b = as_bits(rhs).ok_or(rhs);
+        let result = match (op, a, b) {
+            // Neutral elements.
+            (BinOp::Add | BinOp::BitOr | BinOp::BitXor, Ok(0), Err(p))
+            | (
+                BinOp::Add
+                | BinOp::BitOr
+                | BinOp::BitXor
+                | BinOp::Sub
+                | BinOp::Offset
+                | BinOp::Shl
+                | BinOp::Shr,
+                Err(p),
+                Ok(0),
+            )
+            | (BinOp::Mul, Ok(1), Err(p))
+            | (BinOp::Mul | BinOp::Div, Err(p), Ok(1)) => p,
+            // Attempt to simplify `x & ALL_ONES` to `x`, with `ALL_ONES` depending on type size.
+            (BinOp::BitAnd, Err(p), Ok(ones)) | (BinOp::BitAnd, Ok(ones), Err(p))
+                if ones == layout.size.truncate(u128::MAX)
+                    || (layout.ty.is_bool() && ones == 1) =>
+            {
+                p
+            }
+            // Absorbing elements.
+            (BinOp::Mul | BinOp::BitAnd, _, Ok(0))
+            | (BinOp::Rem, _, Ok(1))
+            | (
+                BinOp::Mul | BinOp::Div | BinOp::Rem | BinOp::BitAnd | BinOp::Shl | BinOp::Shr,
+                Ok(0),
+                _,
+            ) => self.insert_scalar(Scalar::from_uint(0u128, layout.size), lhs_ty),
+            // Attempt to simplify `x | ALL_ONES` to `ALL_ONES`.
+            (BinOp::BitOr, _, Ok(ones)) | (BinOp::BitOr, Ok(ones), _)
+                if ones == layout.size.truncate(u128::MAX)
+                    || (layout.ty.is_bool() && ones == 1) =>
+            {
+                self.insert_scalar(Scalar::from_uint(ones, layout.size), lhs_ty)
+            }
+            // Sub/Xor with itself.
+            (BinOp::Sub | BinOp::BitXor, a, b) if a == b => {
+                self.insert_scalar(Scalar::from_uint(0u128, layout.size), lhs_ty)
+            }
+            // Comparison:
+            // - if both operands can be computed as bits, just compare the bits;
+            // - if we proved that both operands have the same value, we can insert true/false;
+            // - otherwise, do nothing, as we do not try to prove inequality.
+            (BinOp::Eq, a, b) if (a.is_ok() && b.is_ok()) || a == b => self.insert_bool(a == b),
+            (BinOp::Ne, a, b) if (a.is_ok() && b.is_ok()) || a == b => self.insert_bool(a != b),
+            _ => return None,
+        };
+
+        if checked {
+            let false_val = self.insert_bool(false);
+            Some(self.insert_tuple(vec![result, false_val]))
+        } else {
+            Some(result)
+        }
+    }
 }
 
 fn op_to_prop_const<'tcx>(