about summary refs log tree commit diff
diff options
context:
space:
mode:
authorkadmin <julianknodt@gmail.com>2021-05-28 04:17:00 +0000
committerkadmin <julianknodt@gmail.com>2023-02-07 09:37:55 +0000
commit96db5e9c7b2f8b97b75a5afeae21e0e0abf7bdfe (patch)
tree6dd2afa70a88539d700fe65a1601908cf3d9f86e
parent18144b66e1515fa1391b7c7034ba55c47511fb9e (diff)
downloadrust-96db5e9c7b2f8b97b75a5afeae21e0e0abf7bdfe.tar.gz
rust-96db5e9c7b2f8b97b75a5afeae21e0e0abf7bdfe.zip
Add comments
Still need to make it so that it maps discriminants to variant indexes.
Maybe instead I can map the variant indexes to discriminants?
-rw-r--r--compiler/rustc_mir/src/transform/large_enums.rs82
1 files changed, 49 insertions, 33 deletions
diff --git a/compiler/rustc_mir/src/transform/large_enums.rs b/compiler/rustc_mir/src/transform/large_enums.rs
index b742b7a45e6..a8377c95dcb 100644
--- a/compiler/rustc_mir/src/transform/large_enums.rs
+++ b/compiler/rustc_mir/src/transform/large_enums.rs
@@ -4,7 +4,7 @@ use rustc_data_structures::stable_map::FxHashMap;
 use rustc_middle::mir::*;
 use rustc_middle::ty::{self, Const, List, Ty, TyCtxt};
 use rustc_span::def_id::DefId;
-use rustc_target::abi::{Size, Variants};
+use rustc_target::abi::{Size, TagEncoding, Variants};
 
 /// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large
 /// enough discrepanc between them
@@ -31,17 +31,25 @@ impl<const D: u64> EnumSizeOpt<D> {
                 match variants {
                     Variants::Single { .. } => None,
                     Variants::Multiple { variants, .. } if variants.len() <= 1 => None,
+                    Variants::Multiple { tag_encoding, .. }
+                        if matches!(tag_encoding, TagEncoding::Niche { .. }) =>
+                    {
+                        None
+                    }
                     Variants::Multiple { variants, .. } => {
                         let min = variants.iter().map(|v| v.size).min().unwrap();
                         let max = variants.iter().map(|v| v.size).max().unwrap();
                         if max.bytes() - min.bytes() < D {
                             return None;
                         }
-                        Some((
-                            layout.size,
-                            variants.len() as u64,
-                            variants.iter().map(|v| v.size).collect(),
-                        ))
+                        let mut discr_sizes = vec![Size::ZERO; adt_def.discriminants(tcx).count()];
+                        for (var_idx, layout) in variants.iter_enumerated() {
+                            let disc_idx =
+                                adt_def.discriminant_for_variant(tcx, var_idx).val as usize;
+                            assert_eq!(discr_sizes[disc_idx], Size::ZERO);
+                            discr_sizes[disc_idx] = layout.size;
+                        }
+                        Some((layout.size, variants.len() as u64, discr_sizes))
                     }
                 }
             }
@@ -49,7 +57,7 @@ impl<const D: u64> EnumSizeOpt<D> {
         }
     }
     fn optim(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
-        let mut match_cache = FxHashMap::default();
+        let mut alloc_cache = FxHashMap::default();
         let body_did = body.source.def_id();
         let mut patch = MirPatch::new(body);
         let (bbs, local_decls) = body.basic_blocks_and_local_decls_mut();
@@ -61,39 +69,45 @@ impl<const D: u64> EnumSizeOpt<D> {
                         Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)),
                     )) => {
                         let ty = lhs.ty(local_decls, tcx).ty;
+                        let source_info = st.source_info;
+                        let span = source_info.span;
+
                         let (total_size, num_variants, sizes) =
-                            if let Some((ts, nv, s)) = match_cache.get(ty) {
-                                (*ts, *nv, s)
-                            } else if let Some((ts, nv, s)) = Self::candidate(tcx, ty, body_did) {
-                                // FIXME(jknodt) use entry API.
-                                match_cache.insert(ty, (ts, nv, s));
-                                let (ts, nv, s) = match_cache.get(ty).unwrap();
-                                (*ts, *nv, s)
+                            if let Some((ts, nv, s)) = Self::candidate(tcx, ty, body_did) {
+                                (ts, nv, s)
                             } else {
                                 return None;
                             };
 
-                        let source_info = st.source_info;
-                        let span = source_info.span;
+                        let alloc = if let Some(alloc) = alloc_cache.get(ty) {
+                            alloc
+                        } else {
+                            let mut data =
+                                vec![0; std::mem::size_of::<usize>() * num_variants as usize];
+                            data.copy_from_slice(unsafe { std::mem::transmute(&sizes[..]) });
+                            let alloc = interpret::Allocation::from_bytes(
+                                data,
+                                tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi,
+                                Mutability::Not,
+                            );
+                            let alloc = tcx.intern_const_alloc(alloc);
+                            alloc_cache.insert(ty, alloc);
+                            // FIXME(jknodt) use entry API
+                            alloc_cache.get(ty).unwrap()
+                        };
 
                         let tmp_ty = tcx.mk_ty(ty::Array(
                             tcx.types.usize,
                             Const::from_usize(tcx, num_variants),
                         ));
 
-                        let new_local = patch.new_temp(tmp_ty, span);
-                        let store_live =
-                            Statement { source_info, kind: StatementKind::StorageLive(new_local) };
-
-                        let place = Place { local: new_local, projection: List::empty() };
-                        let mut data =
-                            vec![0; std::mem::size_of::<usize>() * num_variants as usize];
-                        data.copy_from_slice(unsafe { std::mem::transmute(&sizes[..]) });
-                        let alloc = interpret::Allocation::from_bytes(
-                            data,
-                            tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi,
-                        );
-                        let alloc = tcx.intern_const_alloc(alloc);
+                        let size_array_local = patch.new_temp(tmp_ty, span);
+                        let store_live = Statement {
+                            source_info,
+                            kind: StatementKind::StorageLive(size_array_local),
+                        };
+
+                        let place = Place { local: size_array_local, projection: List::empty() };
                         let constant_vals = Constant {
                             span,
                             user_ty: None,
@@ -134,9 +148,9 @@ impl<const D: u64> EnumSizeOpt<D> {
                             kind: StatementKind::Assign(box (
                                 size_place,
                                 Rvalue::Use(Operand::Copy(Place {
-                                    local: discr_place.local,
+                                    local: size_array_local,
                                     projection: tcx
-                                        .intern_place_elems(&[PlaceElem::Index(size_place.local)]),
+                                        .intern_place_elems(&[PlaceElem::Index(discr_place.local)]),
                                 })),
                             )),
                         };
@@ -187,8 +201,10 @@ impl<const D: u64> EnumSizeOpt<D> {
                             }),
                         };
 
-                        let store_dead =
-                            Statement { source_info, kind: StatementKind::StorageDead(new_local) };
+                        let store_dead = Statement {
+                            source_info,
+                            kind: StatementKind::StorageDead(size_array_local),
+                        };
                         let iter = std::array::IntoIter::new([
                             store_live,
                             const_assign,