about summary refs log tree commit diff
diff options
context:
space:
mode:
authorCamille GILLOT <gillot.camille@gmail.com>2023-05-03 18:36:53 +0000
committerCamille GILLOT <gillot.camille@gmail.com>2023-10-25 06:46:47 +0000
commit23d4857080a3968447adbb1d55b2720dba46d666 (patch)
tree3a72edf51913b424cf2f8ff2af955ebdfadebfd3
parentf110f22060b86dd2873d7ac6d7a4e444c9348b1d (diff)
downloadrust-23d4857080a3968447adbb1d55b2720dba46d666.tar.gz
rust-23d4857080a3968447adbb1d55b2720dba46d666.zip
Do not compute actual aggregate type.
-rw-r--r--compiler/rustc_mir_transform/src/gvn.rs58
1 files changed, 47 insertions, 11 deletions
diff --git a/compiler/rustc_mir_transform/src/gvn.rs b/compiler/rustc_mir_transform/src/gvn.rs
index 66442849863..9880e239957 100644
--- a/compiler/rustc_mir_transform/src/gvn.rs
+++ b/compiler/rustc_mir_transform/src/gvn.rs
@@ -56,6 +56,7 @@
 use rustc_const_eval::interpret::{ImmTy, InterpCx, MemPlaceMeta, OpTy, Projectable, Scalar};
 use rustc_data_structures::fx::{FxHashMap, FxIndexSet};
 use rustc_data_structures::graph::dominators::Dominators;
+use rustc_hir::def::DefKind;
 use rustc_index::bit_set::BitSet;
 use rustc_index::IndexVec;
 use rustc_macros::newtype_index;
@@ -64,6 +65,7 @@ use rustc_middle::mir::visit::*;
 use rustc_middle::mir::*;
 use rustc_middle::ty::layout::LayoutOf;
 use rustc_middle::ty::{self, Ty, TyCtxt, TypeAndMut};
+use rustc_span::def_id::DefId;
 use rustc_span::DUMMY_SP;
 use rustc_target::abi::{self, Abi, Size, VariantIdx, FIRST_VARIANT};
 use std::borrow::Cow;
@@ -136,6 +138,16 @@ newtype_index! {
     struct VnIndex {}
 }
 
+/// Computing the aggregate's type can be quite slow, so we only keep the minimal amount of
+/// information to reconstruct it when needed.
+#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
+enum AggregateTy<'tcx> {
+    /// Invariant: this must not be used for an empty array.
+    Array,
+    Tuple,
+    Def(DefId, ty::GenericArgsRef<'tcx>),
+}
+
 #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
 enum AddressKind {
     Ref(BorrowKind),
@@ -152,7 +164,7 @@ enum Value<'tcx> {
     Constant(Const<'tcx>),
     /// An aggregate value, either tuple/closure/struct/enum.
     /// This does not contain unions, as we cannot reason with the value.
-    Aggregate(Ty<'tcx>, VariantIdx, Vec<VnIndex>),
+    Aggregate(AggregateTy<'tcx>, VariantIdx, Vec<VnIndex>),
     /// This corresponds to a `[value; count]` expression.
     Repeat(VnIndex, ty::Const<'tcx>),
     /// The address of a place.
@@ -289,11 +301,23 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
             Repeat(..) => return None,
 
             Constant(ref constant) => self.ecx.eval_mir_constant(constant, None, None).ok()?,
-            Aggregate(ty, variant, ref fields) => {
+            Aggregate(kind, variant, ref fields) => {
                 let fields = fields
                     .iter()
                     .map(|&f| self.evaluated[f].as_ref())
                     .collect::<Option<Vec<_>>>()?;
+                let ty = match kind {
+                    AggregateTy::Array => {
+                        assert!(fields.len() > 0);
+                        Ty::new_array(self.tcx, fields[0].layout.ty, fields.len() as u64)
+                    }
+                    AggregateTy::Tuple => {
+                        Ty::new_tup_from_iter(self.tcx, fields.iter().map(|f| f.layout.ty))
+                    }
+                    AggregateTy::Def(def_id, args) => {
+                        self.tcx.type_of(def_id).instantiate(self.tcx, args)
+                    }
+                };
                 let variant = if ty.is_enum() { Some(variant) } else { None };
                 let ty = self.ecx.layout_of(ty).ok()?;
                 if ty.is_zst() {
@@ -510,7 +534,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
                     Value::Repeat(inner, _) => {
                         return Some(*inner);
                     }
-                    Value::Aggregate(ty, _, operands) if ty.is_array() => {
+                    Value::Aggregate(AggregateTy::Array, _, operands) => {
                         let offset = if from_end {
                             operands.len() - offset as usize
                         } else {
@@ -659,12 +683,23 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
             }
             Rvalue::NullaryOp(op, ty) => Value::NullaryOp(op, ty),
             Rvalue::Aggregate(box ref kind, ref mut fields) => {
-                let variant_index = match *kind {
-                    AggregateKind::Array(..)
-                    | AggregateKind::Tuple
-                    | AggregateKind::Closure(..)
-                    | AggregateKind::Coroutine(..) => FIRST_VARIANT,
-                    AggregateKind::Adt(_, variant_index, _, _, None) => variant_index,
+                let (ty, variant_index) = match *kind {
+                    // For empty arrays, we have not mean to recover the type. They are ZSTs
+                    // anyway, so return them as such.
+                    AggregateKind::Array(..) | AggregateKind::Tuple if fields.is_empty() => {
+                        return Some(self.insert(Value::Constant(Const::zero_sized(
+                            rvalue.ty(self.local_decls, self.tcx),
+                        ))));
+                    }
+                    AggregateKind::Array(..) => (AggregateTy::Array, FIRST_VARIANT),
+                    AggregateKind::Tuple => (AggregateTy::Tuple, FIRST_VARIANT),
+                    AggregateKind::Closure(did, substs)
+                    | AggregateKind::Coroutine(did, substs, _) => {
+                        (AggregateTy::Def(did, substs), FIRST_VARIANT)
+                    }
+                    AggregateKind::Adt(did, variant_index, substs, _, None) => {
+                        (AggregateTy::Def(did, substs), variant_index)
+                    }
                     // Do not track unions.
                     AggregateKind::Adt(_, _, _, _, Some(_)) => return None,
                 };
@@ -672,7 +707,6 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
                     .iter_mut()
                     .map(|op| self.simplify_operand(op, location).or_else(|| self.new_opaque()))
                     .collect();
-                let ty = rvalue.ty(self.local_decls, self.tcx);
                 Value::Aggregate(ty, variant_index, fields?)
             }
             Rvalue::Ref(_, borrow_kind, ref mut place) => {
@@ -725,8 +759,10 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
 
     fn simplify_discriminant(&mut self, place: VnIndex) -> Option<VnIndex> {
         if let Value::Aggregate(enum_ty, variant, _) = *self.get(place)
-            && enum_ty.is_enum()
+            && let AggregateTy::Def(enum_did, enum_substs) = enum_ty
+            && let DefKind::Enum = self.tcx.def_kind(enum_did)
         {
+            let enum_ty = self.tcx.type_of(enum_did).instantiate(self.tcx, enum_substs);
             let discr = self.ecx.discriminant_for_variant(enum_ty, variant).ok()?;
             return Some(self.insert_scalar(discr.to_scalar(), discr.layout.ty));
         }