about summary refs log tree commit diff
path: root/compiler/rustc_codegen_ssa/src
diff options
context:
space:
mode:
authorScott McMurray <scottmcm@users.noreply.github.com>2025-07-11 05:07:37 -0700
committerScott McMurray <scottmcm@users.noreply.github.com>2025-07-12 04:53:24 -0700
commitd5bcfb334b616e01c3589ba7b1697c7ebc47a024 (patch)
treec154ef170b3ef1d3fb0795b10823f99b3c2bf9cb /compiler/rustc_codegen_ssa/src
parent13b1e4030aecb24b5a78cfa44594641e7d49b4b6 (diff)
downloadrust-d5bcfb334b616e01c3589ba7b1697c7ebc47a024.tar.gz
rust-d5bcfb334b616e01c3589ba7b1697c7ebc47a024.zip
Simplify codegen for niche-encoded variant tests
Diffstat (limited to 'compiler/rustc_codegen_ssa/src')
-rw-r--r--compiler/rustc_codegen_ssa/src/mir/operand.rs77
1 files changed, 50 insertions, 27 deletions
diff --git a/compiler/rustc_codegen_ssa/src/mir/operand.rs b/compiler/rustc_codegen_ssa/src/mir/operand.rs
index b0d191528a8..9910cb5469f 100644
--- a/compiler/rustc_codegen_ssa/src/mir/operand.rs
+++ b/compiler/rustc_codegen_ssa/src/mir/operand.rs
@@ -486,6 +486,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
                 // value and the variant index match, since that's all `Niche` can encode.
 
                 let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
+                let niche_start_const = bx.cx().const_uint_big(tag_llty, niche_start);
 
                 // We have a subrange `niche_start..=niche_end` inside `range`.
                 // If the value of the tag is inside this subrange, it's a
@@ -511,35 +512,44 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
                     // } else {
                     //     untagged_variant
                     // }
-                    let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
-                    let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
+                    let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start_const);
                     let tagged_discr =
                         bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
                     (is_niche, tagged_discr, 0)
                 } else {
                     // The special cases don't apply, so we'll have to go with
                     // the general algorithm.
-                    let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
+
+                    let tag_range = tag_scalar.valid_range(&dl);
+                    let tag_size = tag_scalar.size(&dl);
+                    let niche_end = u128::from(relative_max).wrapping_add(niche_start);
+                    let niche_end = tag_size.truncate(niche_end);
+
+                    let relative_discr = bx.sub(tag, niche_start_const);
                     let cast_tag = bx.intcast(relative_discr, cast_to, false);
-                    let is_niche = bx.icmp(
-                        IntPredicate::IntULE,
-                        relative_discr,
-                        bx.cx().const_uint(tag_llty, relative_max as u64),
-                    );
-
-                    // Thanks to parameter attributes and load metadata, LLVM already knows
-                    // the general valid range of the tag. It's possible, though, for there
-                    // to be an impossible value *in the middle*, which those ranges don't
-                    // communicate, so it's worth an `assume` to let the optimizer know.
-                    if niche_variants.contains(&untagged_variant)
-                        && bx.cx().sess().opts.optimize != OptLevel::No
-                    {
-                        let impossible =
-                            u64::from(untagged_variant.as_u32() - niche_variants.start().as_u32());
-                        let impossible = bx.cx().const_uint(tag_llty, impossible);
-                        let ne = bx.icmp(IntPredicate::IntNE, relative_discr, impossible);
-                        bx.assume(ne);
-                    }
+                    let is_niche = if tag_range.no_unsigned_wraparound(tag_size) == Ok(true) {
+                        if niche_start == tag_range.start {
+                            let niche_end_const = bx.cx().const_uint_big(tag_llty, niche_end);
+                            bx.icmp(IntPredicate::IntULE, tag, niche_end_const)
+                        } else {
+                            assert_eq!(niche_end, tag_range.end);
+                            bx.icmp(IntPredicate::IntUGE, tag, niche_start_const)
+                        }
+                    } else if tag_range.no_signed_wraparound(tag_size) == Ok(true) {
+                        if niche_start == tag_range.start {
+                            let niche_end_const = bx.cx().const_uint_big(tag_llty, niche_end);
+                            bx.icmp(IntPredicate::IntSLE, tag, niche_end_const)
+                        } else {
+                            assert_eq!(niche_end, tag_range.end);
+                            bx.icmp(IntPredicate::IntSGE, tag, niche_start_const)
+                        }
+                    } else {
+                        bx.icmp(
+                            IntPredicate::IntULE,
+                            relative_discr,
+                            bx.cx().const_uint(tag_llty, relative_max as u64),
+                        )
+                    };
 
                     (is_niche, cast_tag, niche_variants.start().as_u32() as u128)
                 };
@@ -550,11 +560,24 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
                     bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
                 };
 
-                let discr = bx.select(
-                    is_niche,
-                    tagged_discr,
-                    bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64),
-                );
+                let untagged_variant_const =
+                    bx.cx().const_uint(cast_to, u64::from(untagged_variant.as_u32()));
+
+                // Thanks to parameter attributes and load metadata, LLVM already knows
+                // the general valid range of the tag. It's possible, though, for there
+                // to be an impossible value *in the middle*, which those ranges don't
+                // communicate, so it's worth an `assume` to let the optimizer know.
+                // Most importantly, this means when optimizing a variant test like
+                // `SELECT(is_niche, complex, CONST) == CONST` it's ok to simplify that
+                // to `!is_niche` because the `complex` part can't possibly match.
+                if niche_variants.contains(&untagged_variant)
+                    && bx.cx().sess().opts.optimize != OptLevel::No
+                {
+                    let ne = bx.icmp(IntPredicate::IntNE, tagged_discr, untagged_variant_const);
+                    bx.assume(ne);
+                }
+
+                let discr = bx.select(is_niche, tagged_discr, untagged_variant_const);
 
                 // In principle we could insert assumes on the possible range of `discr`, but
                 // currently in LLVM this isn't worth it because the original `tag` will