about summary refs log tree commit diff
diff options
context:
space:
mode:
authorLoïc BRANSTETT <loic.branstett@epitech.eu>2021-12-17 20:01:19 +0100
committerUrgau <urgau@numericable.fr>2023-04-26 10:18:53 +0200
commit1a72d7c7c4e7050d3c07d8b6fe586ddef2c36305 (patch)
treeb0ea3d7affe03b9f3fef4fd8fd3c3d597ec12826
parent23a76a8ab5f7b29a1eb7aca5f3e4c4a9b866d5b7 (diff)
downloadrust-1a72d7c7c4e7050d3c07d8b6fe586ddef2c36305.tar.gz
rust-1a72d7c7c4e7050d3c07d8b6fe586ddef2c36305.zip
Implement midpoint for all signed and unsigned integers
-rw-r--r--library/core/src/lib.rs1
-rw-r--r--library/core/src/num/int_macros.rs38
-rw-r--r--library/core/src/num/mod.rs59
-rw-r--r--library/core/tests/lib.rs1
-rw-r--r--library/core/tests/num/int_macros.rs26
-rw-r--r--library/core/tests/num/uint_macros.rs26
6 files changed, 151 insertions, 0 deletions
diff --git a/library/core/src/lib.rs b/library/core/src/lib.rs
index 24a9d81d037..23bf199f482 100644
--- a/library/core/src/lib.rs
+++ b/library/core/src/lib.rs
@@ -132,6 +132,7 @@
 #![feature(const_maybe_uninit_assume_init)]
 #![feature(const_maybe_uninit_uninit_array)]
 #![feature(const_nonnull_new)]
+#![feature(const_num_midpoint)]
 #![feature(const_option)]
 #![feature(const_option_ext)]
 #![feature(const_pin)]
diff --git a/library/core/src/num/int_macros.rs b/library/core/src/num/int_macros.rs
index 17715c9291f..1199d09b563 100644
--- a/library/core/src/num/int_macros.rs
+++ b/library/core/src/num/int_macros.rs
@@ -2332,6 +2332,44 @@ macro_rules! int_impl {
             }
         }
 
+        /// Calculates the middle point of `self` and `rhs`.
+        ///
+        /// `midpoint(a, b)` is `(a + b) >> 1` as if it were performed in a
+        /// sufficiently-large signed integral type. This implies that the result is
+        /// always rounded towards negative infinity and that no overflow will ever occur.
+        ///
+        /// # Examples
+        ///
+        /// ```
+        /// #![feature(num_midpoint)]
+        #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(4), 2);")]
+        #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(-1), -1);")]
+        #[doc = concat!("assert_eq!((-1", stringify!($SelfT), ").midpoint(0), -1);")]
+        /// ```
+        #[unstable(feature = "num_midpoint", issue = "110840")]
+        #[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")]
+        #[rustc_allow_const_fn_unstable(const_num_midpoint)]
+        #[must_use = "this returns the result of the operation, \
+                      without modifying the original"]
+        #[inline]
+        pub const fn midpoint(self, rhs: Self) -> Self {
+            const U: $UnsignedT = <$SelfT>::MIN.unsigned_abs();
+
+            // Map an $SelfT to an $UnsignedT
+            // ex: i8 [-128; 127] to [0; 255]
+            const fn map(a: $SelfT) -> $UnsignedT {
+                (a as $UnsignedT) ^ U
+            }
+
+            // Map an $UnsignedT to an $SelfT
+            // ex: u8 [0; 255] to [-128; 127]
+            const fn demap(a: $UnsignedT) -> $SelfT {
+                (a ^ U) as $SelfT
+            }
+
+            demap(<$UnsignedT>::midpoint(map(self), map(rhs)))
+        }
+
         /// Returns the logarithm of the number with respect to an arbitrary base,
         /// rounded down.
         ///
diff --git a/library/core/src/num/mod.rs b/library/core/src/num/mod.rs
index fdd7be625ed..5f04c06a8d3 100644
--- a/library/core/src/num/mod.rs
+++ b/library/core/src/num/mod.rs
@@ -95,6 +95,57 @@ depending on the target pointer size.
     };
 }
 
+macro_rules! midpoint_impl {
+    ($SelfT:ty, unsigned) => {
+        /// Calculates the middle point of `self` and `rhs`.
+        ///
+        /// `midpoint(a, b)` is `(a + b) >> 1` as if it were performed in a
+        /// sufficiently-large signed integral type. This implies that the result is
+        /// always rounded towards negative infinity and that no overflow will ever occur.
+        ///
+        /// # Examples
+        ///
+        /// ```
+        /// #![feature(num_midpoint)]
+        #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(4), 2);")]
+        #[doc = concat!("assert_eq!(1", stringify!($SelfT), ".midpoint(4), 2);")]
+        /// ```
+        #[unstable(feature = "num_midpoint", issue = "110840")]
+        #[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")]
+        #[must_use = "this returns the result of the operation, \
+                      without modifying the original"]
+        #[inline]
+        pub const fn midpoint(self, rhs: $SelfT) -> $SelfT {
+            // Use the well known branchless algorthim from Hacker's Delight to compute
+            // `(a + b) / 2` without overflowing: `((a ^ b) >> 1) + (a & b)`.
+            ((self ^ rhs) >> 1) + (self & rhs)
+        }
+    };
+    ($SelfT:ty, $WideT:ty, unsigned) => {
+        /// Calculates the middle point of `self` and `rhs`.
+        ///
+        /// `midpoint(a, b)` is `(a + b) >> 1` as if it were performed in a
+        /// sufficiently-large signed integral type. This implies that the result is
+        /// always rounded towards negative infinity and that no overflow will ever occur.
+        ///
+        /// # Examples
+        ///
+        /// ```
+        /// #![feature(num_midpoint)]
+        #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(4), 2);")]
+        #[doc = concat!("assert_eq!(1", stringify!($SelfT), ".midpoint(4), 2);")]
+        /// ```
+        #[unstable(feature = "num_midpoint", issue = "110840")]
+        #[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")]
+        #[must_use = "this returns the result of the operation, \
+                      without modifying the original"]
+        #[inline]
+        pub const fn midpoint(self, rhs: $SelfT) -> $SelfT {
+            ((self as $WideT + rhs as $WideT) / 2) as $SelfT
+        }
+    };
+}
+
 macro_rules! widening_impl {
     ($SelfT:ty, $WideT:ty, $BITS:literal, unsigned) => {
         /// Calculates the complete product `self * rhs` without the possibility to overflow.
@@ -455,6 +506,7 @@ impl u8 {
         bound_condition = "",
     }
     widening_impl! { u8, u16, 8, unsigned }
+    midpoint_impl! { u8, u16, unsigned }
 
     /// Checks if the value is within the ASCII range.
     ///
@@ -1057,6 +1109,7 @@ impl u16 {
         bound_condition = "",
     }
     widening_impl! { u16, u32, 16, unsigned }
+    midpoint_impl! { u16, u32, unsigned }
 
     /// Checks if the value is a Unicode surrogate code point, which are disallowed values for [`char`].
     ///
@@ -1105,6 +1158,7 @@ impl u32 {
         bound_condition = "",
     }
     widening_impl! { u32, u64, 32, unsigned }
+    midpoint_impl! { u32, u64, unsigned }
 }
 
 impl u64 {
@@ -1128,6 +1182,7 @@ impl u64 {
         bound_condition = "",
     }
     widening_impl! { u64, u128, 64, unsigned }
+    midpoint_impl! { u64, u128, unsigned }
 }
 
 impl u128 {
@@ -1152,6 +1207,7 @@ impl u128 {
         from_xe_bytes_doc = "",
         bound_condition = "",
     }
+    midpoint_impl! { u128, unsigned }
 }
 
 #[cfg(target_pointer_width = "16")]
@@ -1176,6 +1232,7 @@ impl usize {
         bound_condition = " on 16-bit targets",
     }
     widening_impl! { usize, u32, 16, unsigned }
+    midpoint_impl! { usize, u32, unsigned }
 }
 
 #[cfg(target_pointer_width = "32")]
@@ -1200,6 +1257,7 @@ impl usize {
         bound_condition = " on 32-bit targets",
     }
     widening_impl! { usize, u64, 32, unsigned }
+    midpoint_impl! { usize, u64, unsigned }
 }
 
 #[cfg(target_pointer_width = "64")]
@@ -1224,6 +1282,7 @@ impl usize {
         bound_condition = " on 64-bit targets",
     }
     widening_impl! { usize, u128, 64, unsigned }
+    midpoint_impl! { usize, u128, unsigned }
 }
 
 impl usize {
diff --git a/library/core/tests/lib.rs b/library/core/tests/lib.rs
index f460da35dd3..79b6771a7b5 100644
--- a/library/core/tests/lib.rs
+++ b/library/core/tests/lib.rs
@@ -54,6 +54,7 @@
 #![feature(maybe_uninit_uninit_array_transpose)]
 #![feature(min_specialization)]
 #![feature(numfmt)]
+#![feature(num_midpoint)]
 #![feature(step_trait)]
 #![feature(str_internals)]
 #![feature(std_internals)]
diff --git a/library/core/tests/num/int_macros.rs b/library/core/tests/num/int_macros.rs
index 18c55e43aac..439bbe66997 100644
--- a/library/core/tests/num/int_macros.rs
+++ b/library/core/tests/num/int_macros.rs
@@ -364,6 +364,32 @@ macro_rules! int_module {
                 assert_eq!((0 as $T).borrowing_sub($T::MIN, false), ($T::MIN, true));
                 assert_eq!((0 as $T).borrowing_sub($T::MIN, true), ($T::MAX, false));
             }
+
+            #[test]
+            fn test_midpoint() {
+                assert_eq!(<$T>::midpoint(1, 3), 2);
+                assert_eq!(<$T>::midpoint(3, 1), 2);
+
+                assert_eq!(<$T>::midpoint(0, 0), 0);
+                assert_eq!(<$T>::midpoint(0, 2), 1);
+                assert_eq!(<$T>::midpoint(2, 0), 1);
+                assert_eq!(<$T>::midpoint(2, 2), 2);
+
+                assert_eq!(<$T>::midpoint(1, 4), 2);
+                assert_eq!(<$T>::midpoint(4, 1), 2);
+                assert_eq!(<$T>::midpoint(3, 4), 3);
+                assert_eq!(<$T>::midpoint(4, 3), 3);
+
+                assert_eq!(<$T>::midpoint(<$T>::MIN, <$T>::MAX), -1);
+                assert_eq!(<$T>::midpoint(<$T>::MAX, <$T>::MIN), -1);
+                assert_eq!(<$T>::midpoint(<$T>::MIN, <$T>::MIN), <$T>::MIN);
+                assert_eq!(<$T>::midpoint(<$T>::MAX, <$T>::MAX), <$T>::MAX);
+
+                assert_eq!(<$T>::midpoint(<$T>::MIN, 6), <$T>::MIN / 2 + 3);
+                assert_eq!(<$T>::midpoint(6, <$T>::MIN), <$T>::MIN / 2 + 3);
+                assert_eq!(<$T>::midpoint(<$T>::MAX, 6), <$T>::MAX / 2 + 3);
+                assert_eq!(<$T>::midpoint(6, <$T>::MAX), <$T>::MAX / 2 + 3);
+            }
         }
     };
 }
diff --git a/library/core/tests/num/uint_macros.rs b/library/core/tests/num/uint_macros.rs
index 15ae9f2324f..7d6203db0b9 100644
--- a/library/core/tests/num/uint_macros.rs
+++ b/library/core/tests/num/uint_macros.rs
@@ -252,6 +252,32 @@ macro_rules! uint_module {
                 assert_eq!($T::MAX.borrowing_sub(0, true), ($T::MAX - 1, false));
                 assert_eq!($T::MAX.borrowing_sub($T::MAX, true), ($T::MAX, true));
             }
+
+            #[test]
+            fn test_midpoint() {
+                assert_eq!(<$T>::midpoint(1, 3), 2);
+                assert_eq!(<$T>::midpoint(3, 1), 2);
+
+                assert_eq!(<$T>::midpoint(0, 0), 0);
+                assert_eq!(<$T>::midpoint(0, 2), 1);
+                assert_eq!(<$T>::midpoint(2, 0), 1);
+                assert_eq!(<$T>::midpoint(2, 2), 2);
+
+                assert_eq!(<$T>::midpoint(1, 4), 2);
+                assert_eq!(<$T>::midpoint(4, 1), 2);
+                assert_eq!(<$T>::midpoint(3, 4), 3);
+                assert_eq!(<$T>::midpoint(4, 3), 3);
+
+                assert_eq!(<$T>::midpoint(<$T>::MIN, <$T>::MAX), (<$T>::MAX - <$T>::MIN) / 2);
+                assert_eq!(<$T>::midpoint(<$T>::MAX, <$T>::MIN), (<$T>::MAX - <$T>::MIN) / 2);
+                assert_eq!(<$T>::midpoint(<$T>::MIN, <$T>::MIN), <$T>::MIN);
+                assert_eq!(<$T>::midpoint(<$T>::MAX, <$T>::MAX), <$T>::MAX);
+
+                assert_eq!(<$T>::midpoint(<$T>::MIN, 6), <$T>::MIN / 2 + 3);
+                assert_eq!(<$T>::midpoint(6, <$T>::MIN), <$T>::MIN / 2 + 3);
+                assert_eq!(<$T>::midpoint(<$T>::MAX, 6), (<$T>::MAX - <$T>::MIN) / 2 + 3);
+                assert_eq!(<$T>::midpoint(6, <$T>::MAX), (<$T>::MAX - <$T>::MIN) / 2 + 3);
+            }
         }
     };
 }