about summary refs log tree commit diff
diff options
context:
space:
mode:
authorCaleb Zulawski <caleb.zulawski@gmail.com>2023-12-03 11:27:57 -0500
committerCaleb Zulawski <caleb.zulawski@gmail.com>2023-12-03 11:42:26 -0500
commit289c1d14f0dfd80d5e94141a3b9b59bed41c3539 (patch)
treeb8678434da6d652a52c5f35f8da97af495e47902
parente0e9a4517f9fc021283514da387e70a56061bd3e (diff)
downloadrust-289c1d14f0dfd80d5e94141a3b9b59bed41c3539.tar.gz
rust-289c1d14f0dfd80d5e94141a3b9b59bed41c3539.zip
Fix bitmask vector bit order
-rw-r--r--crates/core_simd/src/masks/full_masks.rs6
-rw-r--r--crates/core_simd/tests/masks.rs42
2 files changed, 48 insertions, 0 deletions
diff --git a/crates/core_simd/src/masks/full_masks.rs b/crates/core_simd/src/masks/full_masks.rs
index 63964f455e0..b184b98a147 100644
--- a/crates/core_simd/src/masks/full_masks.rs
+++ b/crates/core_simd/src/masks/full_masks.rs
@@ -157,6 +157,9 @@ where
                 for x in bytes.as_mut() {
                     *x = x.reverse_bits()
                 }
+                if N % 8 > 0 {
+                    bytes.as_mut()[N / 8] >>= 8 - N % 8;
+                }
             }
 
             bitmask.as_mut_array()[..bytes.as_ref().len()].copy_from_slice(bytes.as_ref());
@@ -180,6 +183,9 @@ where
                 for x in bytes.as_mut() {
                     *x = x.reverse_bits();
                 }
+                if N % 8 > 0 {
+                    bytes.as_mut()[N / 8] >>= 8 - N % 8;
+                }
             }
 
             // Compute the regular mask
diff --git a/crates/core_simd/tests/masks.rs b/crates/core_simd/tests/masks.rs
index 00fc2a24e27..fc6a3476b7c 100644
--- a/crates/core_simd/tests/masks.rs
+++ b/crates/core_simd/tests/masks.rs
@@ -99,6 +99,19 @@ macro_rules! test_mask_api {
                 assert_eq!(Mask::<$type, 2>::from_bitmask(bitmask), mask);
             }
 
+            #[cfg(feature = "all_lane_counts")]
+            #[test]
+            fn roundtrip_bitmask_conversion_odd() {
+                let values = [
+                    true, false, true, false, true, true, false, false, false, true, true,
+                ];
+                let mask = Mask::<$type, 11>::from_array(values);
+                let bitmask = mask.to_bitmask();
+                assert_eq!(bitmask, 0b11000110101);
+                assert_eq!(Mask::<$type, 11>::from_bitmask(bitmask), mask);
+            }
+
+
             #[test]
             fn cast() {
                 fn cast_impl<T: core_simd::simd::MaskElement>()
@@ -134,6 +147,35 @@ macro_rules! test_mask_api {
                 assert_eq!(bitmask.resize::<2>(0).to_ne_bytes()[..2], [0b01001001, 0b10000011]);
                 assert_eq!(Mask::<$type, 16>::from_bitmask_vector(bitmask), mask);
             }
+
+            // rust-lang/portable-simd#379
+            #[test]
+            fn roundtrip_bitmask_vector_conversion_small() {
+                use core_simd::simd::ToBytes;
+                let values = [
+                    true, false, true, true
+                ];
+                let mask = Mask::<$type, 4>::from_array(values);
+                let bitmask = mask.to_bitmask_vector();
+                assert_eq!(bitmask.resize::<1>(0).to_ne_bytes()[0], 0b00001101);
+                assert_eq!(Mask::<$type, 4>::from_bitmask_vector(bitmask), mask);
+            }
+
+            /* FIXME doesn't work with non-powers-of-two, yet
+            // rust-lang/portable-simd#379
+            #[cfg(feature = "all_lane_counts")]
+            #[test]
+            fn roundtrip_bitmask_vector_conversion_odd() {
+                use core_simd::simd::ToBytes;
+                let values = [
+                    true, false, true, false, true, true, false, false, false, true, true,
+                ];
+                let mask = Mask::<$type, 11>::from_array(values);
+                let bitmask = mask.to_bitmask_vector();
+                assert_eq!(bitmask.resize::<2>(0).to_ne_bytes()[..2], [0b00110101, 0b00000110]);
+                assert_eq!(Mask::<$type, 11>::from_bitmask_vector(bitmask), mask);
+            }
+            */
         }
     }
 }