about summary refs log tree commit diff
diff options
context:
space:
mode:
authorUrgau <urgau@numericable.fr>2024-10-26 18:44:05 +0200
committerUrgau <urgau@numericable.fr>2024-10-26 18:46:41 +0200
commit00444bab2602ca487be08d5e2eaa6179833333b8 (patch)
tree6a00b3f33fae026dab91af8b5afc2ae9283628be
parent80d0d927d5069b67cc08c0c65b48e7b6e0cdeeb5 (diff)
downloadrust-00444bab2602ca487be08d5e2eaa6179833333b8.tar.gz
rust-00444bab2602ca487be08d5e2eaa6179833333b8.zip
Round negative signed integer towards zero in `iN::midpoint`
Instead of towards negative infinity as is currently the case.

This done so that the obvious expectations of
`midpoint(a, b) == midpoint(b, a)` and
`midpoint(-a, -b) == -midpoint(a, b)` are true, which makes the even
more obvious implementation `(a + b) / 2` true.

https://github.com/rust-lang/rust/issues/110840#issuecomment-2336753931
-rw-r--r--library/core/src/num/int_macros.rs38
-rw-r--r--library/core/src/num/mod.rs65
-rw-r--r--library/core/tests/num/int_macros.rs4
3 files changed, 67 insertions, 40 deletions
diff --git a/library/core/src/num/int_macros.rs b/library/core/src/num/int_macros.rs
index 1d640ea74c4..01ecaf2710f 100644
--- a/library/core/src/num/int_macros.rs
+++ b/library/core/src/num/int_macros.rs
@@ -3181,44 +3181,6 @@ macro_rules! int_impl {
             }
         }
 
-        /// Calculates the middle point of `self` and `rhs`.
-        ///
-        /// `midpoint(a, b)` is `(a + b) >> 1` as if it were performed in a
-        /// sufficiently-large signed integral type. This implies that the result is
-        /// always rounded towards negative infinity and that no overflow will ever occur.
-        ///
-        /// # Examples
-        ///
-        /// ```
-        /// #![feature(num_midpoint)]
-        #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(4), 2);")]
-        #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(-1), -1);")]
-        #[doc = concat!("assert_eq!((-1", stringify!($SelfT), ").midpoint(0), -1);")]
-        /// ```
-        #[unstable(feature = "num_midpoint", issue = "110840")]
-        #[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")]
-        #[rustc_allow_const_fn_unstable(const_num_midpoint)]
-        #[must_use = "this returns the result of the operation, \
-                      without modifying the original"]
-        #[inline]
-        pub const fn midpoint(self, rhs: Self) -> Self {
-            const U: $UnsignedT = <$SelfT>::MIN.unsigned_abs();
-
-            // Map an $SelfT to an $UnsignedT
-            // ex: i8 [-128; 127] to [0; 255]
-            const fn map(a: $SelfT) -> $UnsignedT {
-                (a as $UnsignedT) ^ U
-            }
-
-            // Map an $UnsignedT to an $SelfT
-            // ex: u8 [0; 255] to [-128; 127]
-            const fn demap(a: $UnsignedT) -> $SelfT {
-                (a ^ U) as $SelfT
-            }
-
-            demap(<$UnsignedT>::midpoint(map(self), map(rhs)))
-        }
-
         /// Returns the logarithm of the number with respect to an arbitrary base,
         /// rounded down.
         ///
diff --git a/library/core/src/num/mod.rs b/library/core/src/num/mod.rs
index f95cfd33ae5..9a5e211dd60 100644
--- a/library/core/src/num/mod.rs
+++ b/library/core/src/num/mod.rs
@@ -124,6 +124,37 @@ macro_rules! midpoint_impl {
             ((self ^ rhs) >> 1) + (self & rhs)
         }
     };
+    ($SelfT:ty, signed) => {
+        /// Calculates the middle point of `self` and `rhs`.
+        ///
+        /// `midpoint(a, b)` is `(a + b) / 2` as if it were performed in a
+        /// sufficiently-large signed integral type. This implies that the result is
+        /// always rounded towards zero and that no overflow will ever occur.
+        ///
+        /// # Examples
+        ///
+        /// ```
+        /// #![feature(num_midpoint)]
+        #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(4), 2);")]
+        #[doc = concat!("assert_eq!((-1", stringify!($SelfT), ").midpoint(2), 0);")]
+        #[doc = concat!("assert_eq!((-7", stringify!($SelfT), ").midpoint(0), -3);")]
+        #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(-7), -3);")]
+        #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(7), 3);")]
+        /// ```
+        #[unstable(feature = "num_midpoint", issue = "110840")]
+        #[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")]
+        #[must_use = "this returns the result of the operation, \
+                      without modifying the original"]
+        #[inline]
+        pub const fn midpoint(self, rhs: Self) -> Self {
+            // Use the well known branchless algorithm from Hacker's Delight to compute
+            // `(a + b) / 2` without overflowing: `((a ^ b) >> 1) + (a & b)`.
+            let t = ((self ^ rhs) >> 1) + (self & rhs);
+            // Except that it fails for integers whose sum is an odd negative number as
+            // their floor is one less than their average. So we adjust the result.
+            t + (if t < 0 { 1 } else { 0 } & (self ^ rhs))
+        }
+    };
     ($SelfT:ty, $WideT:ty, unsigned) => {
         /// Calculates the middle point of `self` and `rhs`.
         ///
@@ -147,6 +178,32 @@ macro_rules! midpoint_impl {
             ((self as $WideT + rhs as $WideT) / 2) as $SelfT
         }
     };
+    ($SelfT:ty, $WideT:ty, signed) => {
+        /// Calculates the middle point of `self` and `rhs`.
+        ///
+        /// `midpoint(a, b)` is `(a + b) / 2` as if it were performed in a
+        /// sufficiently-large signed integral type. This implies that the result is
+        /// always rounded towards zero and that no overflow will ever occur.
+        ///
+        /// # Examples
+        ///
+        /// ```
+        /// #![feature(num_midpoint)]
+        #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(4), 2);")]
+        #[doc = concat!("assert_eq!((-1", stringify!($SelfT), ").midpoint(2), 0);")]
+        #[doc = concat!("assert_eq!((-7", stringify!($SelfT), ").midpoint(0), -3);")]
+        #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(-7), -3);")]
+        #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(7), 3);")]
+        /// ```
+        #[unstable(feature = "num_midpoint", issue = "110840")]
+        #[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")]
+        #[must_use = "this returns the result of the operation, \
+                      without modifying the original"]
+        #[inline]
+        pub const fn midpoint(self, rhs: $SelfT) -> $SelfT {
+            ((self as $WideT + rhs as $WideT) / 2) as $SelfT
+        }
+    };
 }
 
 macro_rules! widening_impl {
@@ -300,6 +357,7 @@ impl i8 {
         from_xe_bytes_doc = "",
         bound_condition = "",
     }
+    midpoint_impl! { i8, i16, signed }
 }
 
 impl i16 {
@@ -323,6 +381,7 @@ impl i16 {
         from_xe_bytes_doc = "",
         bound_condition = "",
     }
+    midpoint_impl! { i16, i32, signed }
 }
 
 impl i32 {
@@ -346,6 +405,7 @@ impl i32 {
         from_xe_bytes_doc = "",
         bound_condition = "",
     }
+    midpoint_impl! { i32, i64, signed }
 }
 
 impl i64 {
@@ -369,6 +429,7 @@ impl i64 {
         from_xe_bytes_doc = "",
         bound_condition = "",
     }
+    midpoint_impl! { i64, i128, signed }
 }
 
 impl i128 {
@@ -394,6 +455,7 @@ impl i128 {
         from_xe_bytes_doc = "",
         bound_condition = "",
     }
+    midpoint_impl! { i128, signed }
 }
 
 #[cfg(target_pointer_width = "16")]
@@ -418,6 +480,7 @@ impl isize {
         from_xe_bytes_doc = usize_isize_from_xe_bytes_doc!(),
         bound_condition = " on 16-bit targets",
     }
+    midpoint_impl! { isize, i32, signed }
 }
 
 #[cfg(target_pointer_width = "32")]
@@ -442,6 +505,7 @@ impl isize {
         from_xe_bytes_doc = usize_isize_from_xe_bytes_doc!(),
         bound_condition = " on 32-bit targets",
     }
+    midpoint_impl! { isize, i64, signed }
 }
 
 #[cfg(target_pointer_width = "64")]
@@ -466,6 +530,7 @@ impl isize {
         from_xe_bytes_doc = usize_isize_from_xe_bytes_doc!(),
         bound_condition = " on 64-bit targets",
     }
+    midpoint_impl! { isize, i128, signed }
 }
 
 /// If the 6th bit is set ascii is lower case.
diff --git a/library/core/tests/num/int_macros.rs b/library/core/tests/num/int_macros.rs
index 1608080d6b6..474d57049ab 100644
--- a/library/core/tests/num/int_macros.rs
+++ b/library/core/tests/num/int_macros.rs
@@ -369,8 +369,8 @@ macro_rules! int_module {
                 assert_eq_const_safe!(<$T>::midpoint(3, 4), 3);
                 assert_eq_const_safe!(<$T>::midpoint(4, 3), 3);
 
-                assert_eq_const_safe!(<$T>::midpoint(<$T>::MIN, <$T>::MAX), -1);
-                assert_eq_const_safe!(<$T>::midpoint(<$T>::MAX, <$T>::MIN), -1);
+                assert_eq_const_safe!(<$T>::midpoint(<$T>::MIN, <$T>::MAX), 0);
+                assert_eq_const_safe!(<$T>::midpoint(<$T>::MAX, <$T>::MIN), 0);
                 assert_eq_const_safe!(<$T>::midpoint(<$T>::MIN, <$T>::MIN), <$T>::MIN);
                 assert_eq_const_safe!(<$T>::midpoint(<$T>::MAX, <$T>::MAX), <$T>::MAX);