about summary refs log tree commit diff
path: root/compiler/rustc_codegen_ssa/src/mir/operand.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_codegen_ssa/src/mir/operand.rs')
-rw-r--r--compiler/rustc_codegen_ssa/src/mir/operand.rs139
1 files changed, 137 insertions, 2 deletions
diff --git a/compiler/rustc_codegen_ssa/src/mir/operand.rs b/compiler/rustc_codegen_ssa/src/mir/operand.rs
index acae09b2c25..cfebc8840fa 100644
--- a/compiler/rustc_codegen_ssa/src/mir/operand.rs
+++ b/compiler/rustc_codegen_ssa/src/mir/operand.rs
@@ -3,16 +3,17 @@ use std::fmt;
 use arrayvec::ArrayVec;
 use either::Either;
 use rustc_abi as abi;
-use rustc_abi::{Align, BackendRepr, Size};
+use rustc_abi::{Align, BackendRepr, FIRST_VARIANT, Primitive, Size, TagEncoding, Variants};
 use rustc_middle::mir::interpret::{Pointer, Scalar, alloc_range};
 use rustc_middle::mir::{self, ConstValue};
 use rustc_middle::ty::Ty;
 use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
 use rustc_middle::{bug, span_bug};
-use tracing::debug;
+use tracing::{debug, instrument};
 
 use super::place::{PlaceRef, PlaceValue};
 use super::{FunctionCx, LocalRef};
+use crate::common::IntPredicate;
 use crate::traits::*;
 use crate::{MemFlags, size_of_val};
 
@@ -415,6 +416,140 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
 
         OperandRef { val, layout: field }
     }
+
+    /// Obtain the actual discriminant of a value.
+    #[instrument(level = "trace", skip(fx, bx))]
+    pub fn codegen_get_discr<Bx: BuilderMethods<'a, 'tcx, Value = V>>(
+        self,
+        fx: &mut FunctionCx<'a, 'tcx, Bx>,
+        bx: &mut Bx,
+        cast_to: Ty<'tcx>,
+    ) -> V {
+        let dl = &bx.tcx().data_layout;
+        let cast_to_layout = bx.cx().layout_of(cast_to);
+        let cast_to = bx.cx().immediate_backend_type(cast_to_layout);
+        if self.layout.is_uninhabited() {
+            return bx.cx().const_poison(cast_to);
+        }
+        let (tag_scalar, tag_encoding, tag_field) = match self.layout.variants {
+            Variants::Empty => unreachable!("we already handled uninhabited types"),
+            Variants::Single { index } => {
+                let discr_val =
+                    if let Some(discr) = self.layout.ty.discriminant_for_variant(bx.tcx(), index) {
+                        discr.val
+                    } else {
+                        assert_eq!(index, FIRST_VARIANT);
+                        0
+                    };
+                return bx.cx().const_uint_big(cast_to, discr_val);
+            }
+            Variants::Multiple { tag, ref tag_encoding, tag_field, .. } => {
+                (tag, tag_encoding, tag_field)
+            }
+        };
+
+        // Read the tag/niche-encoded discriminant from memory.
+        let tag_op = match self.val {
+            OperandValue::ZeroSized => bug!(),
+            OperandValue::Immediate(_) | OperandValue::Pair(_, _) => {
+                self.extract_field(fx, bx, tag_field)
+            }
+            OperandValue::Ref(place) => {
+                let tag = place.with_type(self.layout).project_field(bx, tag_field);
+                bx.load_operand(tag)
+            }
+        };
+        let tag_imm = tag_op.immediate();
+
+        // Decode the discriminant (specifically if it's niche-encoded).
+        match *tag_encoding {
+            TagEncoding::Direct => {
+                let signed = match tag_scalar.primitive() {
+                    // We use `i1` for bytes that are always `0` or `1`,
+                    // e.g., `#[repr(i8)] enum E { A, B }`, but we can't
+                    // let LLVM interpret the `i1` as signed, because
+                    // then `i1 1` (i.e., `E::B`) is effectively `i8 -1`.
+                    Primitive::Int(_, signed) => !tag_scalar.is_bool() && signed,
+                    _ => false,
+                };
+                bx.intcast(tag_imm, cast_to, signed)
+            }
+            TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
+                // Cast to an integer so we don't have to treat a pointer as a
+                // special case.
+                let (tag, tag_llty) = match tag_scalar.primitive() {
+                    // FIXME(erikdesjardins): handle non-default addrspace ptr sizes
+                    Primitive::Pointer(_) => {
+                        let t = bx.type_from_integer(dl.ptr_sized_integer());
+                        let tag = bx.ptrtoint(tag_imm, t);
+                        (tag, t)
+                    }
+                    _ => (tag_imm, bx.cx().immediate_backend_type(tag_op.layout)),
+                };
+
+                let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
+
+                // We have a subrange `niche_start..=niche_end` inside `range`.
+                // If the value of the tag is inside this subrange, it's a
+                // "niche value", an increment of the discriminant. Otherwise it
+                // indicates the untagged variant.
+                // A general algorithm to extract the discriminant from the tag
+                // is:
+                // relative_tag = tag - niche_start
+                // is_niche = relative_tag <= (ule) relative_max
+                // discr = if is_niche {
+                //     cast(relative_tag) + niche_variants.start()
+                // } else {
+                //     untagged_variant
+                // }
+                // However, we will likely be able to emit simpler code.
+                let (is_niche, tagged_discr, delta) = if relative_max == 0 {
+                    // Best case scenario: only one tagged variant. This will
+                    // likely become just a comparison and a jump.
+                    // The algorithm is:
+                    // is_niche = tag == niche_start
+                    // discr = if is_niche {
+                    //     niche_start
+                    // } else {
+                    //     untagged_variant
+                    // }
+                    let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
+                    let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
+                    let tagged_discr =
+                        bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
+                    (is_niche, tagged_discr, 0)
+                } else {
+                    // The special cases don't apply, so we'll have to go with
+                    // the general algorithm.
+                    let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
+                    let cast_tag = bx.intcast(relative_discr, cast_to, false);
+                    let is_niche = bx.icmp(
+                        IntPredicate::IntULE,
+                        relative_discr,
+                        bx.cx().const_uint(tag_llty, relative_max as u64),
+                    );
+                    (is_niche, cast_tag, niche_variants.start().as_u32() as u128)
+                };
+
+                let tagged_discr = if delta == 0 {
+                    tagged_discr
+                } else {
+                    bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
+                };
+
+                let discr = bx.select(
+                    is_niche,
+                    tagged_discr,
+                    bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64),
+                );
+
+                // In principle we could insert assumes on the possible range of `discr`, but
+                // currently in LLVM this seems to be a pessimization.
+
+                discr
+            }
+        }
+    }
 }
 
 impl<'a, 'tcx, V: CodegenObject> OperandValue<V> {