about summary refs log tree commit diff
path: root/src/librustc_codegen_ssa
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2019-07-13 20:45:40 +0000
committerbors <bors@rust-lang.org>2019-07-13 20:45:40 +0000
commitd32a7250dbf797c9a89f56de0842d7ad43bfe85f (patch)
tree5a5cffe20a744ad2cb71ea31b9397f5e744c7bf9 /src/librustc_codegen_ssa
parent69656fa4cbafc378fd63f9186d93b0df3cdd9320 (diff)
parentc063057beb96cd4901ab300eed2267c9b73ed589 (diff)
downloadrust-d32a7250dbf797c9a89f56de0842d7ad43bfe85f.tar.gz
rust-d32a7250dbf797c9a89f56de0842d7ad43bfe85f.zip
Auto merge of #62584 - eddyb:circular-math-is-hard, r=pnkfelix
 rustc_codegen_ssa: fix range check in codegen_get_discr.

Fixes #61696, see https://github.com/rust-lang/rust/issues/61696#issuecomment-505473018 for more details.

In short, I had wanted to use `x - a <= b - a` to check whether `x` is in `a..=b` (as it's 1 comparison instead of 2 *and* `b - a` is guaranteed to fit in the same data type, while `b` itself might not), but I ended up with `x - a + c <= b - a + c` instead, because `x - a + c` was the final value needed.

That latter comparison is equivalent to checking that `x` is in `(a - c)..=b`, i.e. it also includes `(a - c)..a`, not just `a..=b`, so if `c` is not `0`, it will cause false positives.

This presented itself as the non-niche ("dataful") variant sometimes being treated like a niche variant, in the presence of uninhabited variants (which made `c`, aka the index of the first niche variant, arbitrarily large).

r? @nagisa, @rkruppe or @oli-obk
Diffstat (limited to 'src/librustc_codegen_ssa')
-rw-r--r--src/librustc_codegen_ssa/mir/place.rs86
1 files changed, 62 insertions, 24 deletions
diff --git a/src/librustc_codegen_ssa/mir/place.rs b/src/librustc_codegen_ssa/mir/place.rs
index 010be3e8c74..588badfa11a 100644
--- a/src/librustc_codegen_ssa/mir/place.rs
+++ b/src/librustc_codegen_ssa/mir/place.rs
@@ -228,8 +228,11 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
             }
         };
 
-        let discr = self.project_field(bx, discr_index);
-        let lldiscr = bx.load_operand(discr).immediate();
+        // Read the tag/niche-encoded discriminant from memory.
+        let encoded_discr = self.project_field(bx, discr_index);
+        let encoded_discr = bx.load_operand(encoded_discr);
+
+        // Decode the discriminant (specifically if it's niche-encoded).
         match *discr_kind {
             layout::DiscriminantKind::Tag => {
                 let signed = match discr_scalar.value {
@@ -240,38 +243,73 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
                     layout::Int(_, signed) => !discr_scalar.is_bool() && signed,
                     _ => false
                 };
-                bx.intcast(lldiscr, cast_to, signed)
+                bx.intcast(encoded_discr.immediate(), cast_to, signed)
             }
             layout::DiscriminantKind::Niche {
                 dataful_variant,
                 ref niche_variants,
                 niche_start,
             } => {
-                let niche_llty = bx.cx().immediate_backend_type(discr.layout);
-                if niche_variants.start() == niche_variants.end() {
-                    // FIXME(eddyb): check the actual primitive type here.
-                    let niche_llval = if niche_start == 0 {
-                        // HACK(eddyb): using `c_null` as it works on all types.
+                // Rebase from niche values to discriminants, and check
+                // whether the result is in range for the niche variants.
+                let niche_llty = bx.cx().immediate_backend_type(encoded_discr.layout);
+                let encoded_discr = encoded_discr.immediate();
+
+                // We first compute the "relative discriminant" (wrt `niche_variants`),
+                // that is, if `n = niche_variants.end() - niche_variants.start()`,
+                // we remap `niche_start..=niche_start + n` (which may wrap around)
+                // to (non-wrap-around) `0..=n`, to be able to check whether the
+                // discriminant corresponds to a niche variant with one comparison.
+                // We also can't go directly to the (variant index) discriminant
+                // and check that it is in the range `niche_variants`, because
+                // that might not fit in the same type, on top of needing an extra
+                // comparison (see also the comment on `let niche_discr`).
+                let relative_discr = if niche_start == 0 {
+                    // Avoid subtracting `0`, which wouldn't work for pointers.
+                    // FIXME(eddyb) check the actual primitive type here.
+                    encoded_discr
+                } else {
+                    bx.sub(encoded_discr, bx.cx().const_uint_big(niche_llty, niche_start))
+                };
+                let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
+                let is_niche = {
+                    let relative_max = if relative_max == 0 {
+                        // Avoid calling `const_uint`, which wouldn't work for pointers.
+                        // FIXME(eddyb) check the actual primitive type here.
                         bx.cx().const_null(niche_llty)
                     } else {
-                        bx.cx().const_uint_big(niche_llty, niche_start)
+                        bx.cx().const_uint(niche_llty, relative_max as u64)
+                    };
+                    bx.icmp(IntPredicate::IntULE, relative_discr, relative_max)
+                };
+
+                // NOTE(eddyb) this addition needs to be performed on the final
+                // type, in case the niche itself can't represent all variant
+                // indices (e.g. `u8` niche with more than `256` variants,
+                // but enough uninhabited variants so that the remaining variants
+                // fit in the niche).
+                // In other words, `niche_variants.end - niche_variants.start`
+                // is representable in the niche, but `niche_variants.end`
+                // might not be, in extreme cases.
+                let niche_discr = {
+                    let relative_discr = if relative_max == 0 {
+                        // HACK(eddyb) since we have only one niche, we know which
+                        // one it is, and we can avoid having a dynamic value here.
+                        bx.cx().const_uint(cast_to, 0)
+                    } else {
+                        bx.intcast(relative_discr, cast_to, false)
                     };
-                    let select_arg = bx.icmp(IntPredicate::IntEQ, lldiscr, niche_llval);
-                    bx.select(select_arg,
+                    bx.add(
+                        relative_discr,
                         bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64),
-                        bx.cx().const_uint(cast_to, dataful_variant.as_u32() as u64))
-                } else {
-                    // Rebase from niche values to discriminant values.
-                    let delta = niche_start.wrapping_sub(niche_variants.start().as_u32() as u128);
-                    let lldiscr = bx.sub(lldiscr, bx.cx().const_uint_big(niche_llty, delta));
-                    let lldiscr_max =
-                        bx.cx().const_uint(niche_llty, niche_variants.end().as_u32() as u64);
-                    let select_arg = bx.icmp(IntPredicate::IntULE, lldiscr, lldiscr_max);
-                    let cast = bx.intcast(lldiscr, cast_to, false);
-                    bx.select(select_arg,
-                        cast,
-                        bx.cx().const_uint(cast_to, dataful_variant.as_u32() as u64))
-                }
+                    )
+                };
+
+                bx.select(
+                    is_niche,
+                    niche_discr,
+                    bx.cx().const_uint(cast_to, dataful_variant.as_u32() as u64),
+                )
             }
         }
     }