about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2023-10-24 21:47:53 +0000
committerbors <bors@rust-lang.org>2023-10-24 21:47:53 +0000
commitdf871fbf053de3a855398964cd05fadbe91cf4fd (patch)
treef9bde4acab06d6aad30130ff26aaaca350578b4f /compiler/rustc_mir_transform/src
parent151256bd4b577f92922c0fbdf94b12d69cfb08d3 (diff)
parent9c85dfa1d755004b5499f45e2fd921c298f90a2e (diff)
downloadrust-df871fbf053de3a855398964cd05fadbe91cf4fd.tar.gz
rust-df871fbf053de3a855398964cd05fadbe91cf4fd.zip
Auto merge of #115796 - cjgillot:const-prop-rvalue, r=oli-obk
Generate aggregate constants in DataflowConstProp.
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/dataflow_const_prop.rs190
1 files changed, 172 insertions, 18 deletions
diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
index 85a0be8a44c..2c29978173f 100644
--- a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
+++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
@@ -2,13 +2,13 @@
 //!
 //! Currently, this pass only propagates scalar values.
 
-use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
+use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, PlaceTy, Projectable};
 use rustc_data_structures::fx::FxHashMap;
 use rustc_hir::def::DefKind;
 use rustc_middle::mir::interpret::{AllocId, ConstAllocation, InterpResult, Scalar};
 use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
 use rustc_middle::mir::*;
-use rustc_middle::ty::layout::TyAndLayout;
+use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
 use rustc_middle::ty::{self, Ty, TyCtxt};
 use rustc_mir_dataflow::value_analysis::{
     Map, PlaceIndex, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace,
@@ -16,8 +16,9 @@ use rustc_mir_dataflow::value_analysis::{
 use rustc_mir_dataflow::{lattice::FlatSet, Analysis, Results, ResultsVisitor};
 use rustc_span::def_id::DefId;
 use rustc_span::DUMMY_SP;
-use rustc_target::abi::{FieldIdx, VariantIdx};
+use rustc_target::abi::{Abi, FieldIdx, Size, VariantIdx, FIRST_VARIANT};
 
+use crate::const_prop::throw_machine_stop_str;
 use crate::MirPass;
 
 // These constants are somewhat random guesses and have not been optimized.
@@ -553,16 +554,151 @@ impl<'tcx, 'locals> Collector<'tcx, 'locals> {
 
     fn try_make_constant(
         &self,
+        ecx: &mut InterpCx<'tcx, 'tcx, DummyMachine>,
         place: Place<'tcx>,
         state: &State<FlatSet<Scalar>>,
         map: &Map,
     ) -> Option<Const<'tcx>> {
-        let FlatSet::Elem(Scalar::Int(value)) = state.get(place.as_ref(), &map) else {
-            return None;
-        };
         let ty = place.ty(self.local_decls, self.patch.tcx).ty;
-        Some(Const::Val(ConstValue::Scalar(value.into()), ty))
+        let layout = ecx.layout_of(ty).ok()?;
+
+        if layout.is_zst() {
+            return Some(Const::zero_sized(ty));
+        }
+
+        if layout.is_unsized() {
+            return None;
+        }
+
+        let place = map.find(place.as_ref())?;
+        if layout.abi.is_scalar()
+            && let Some(value) = propagatable_scalar(place, state, map)
+        {
+            return Some(Const::Val(ConstValue::Scalar(value), ty));
+        }
+
+        if matches!(layout.abi, Abi::Scalar(..) | Abi::ScalarPair(..)) {
+            let alloc_id = ecx
+                .intern_with_temp_alloc(layout, |ecx, dest| {
+                    try_write_constant(ecx, dest, place, ty, state, map)
+                })
+                .ok()?;
+            return Some(Const::Val(ConstValue::Indirect { alloc_id, offset: Size::ZERO }, ty));
+        }
+
+        None
+    }
+}
+
+fn propagatable_scalar(
+    place: PlaceIndex,
+    state: &State<FlatSet<Scalar>>,
+    map: &Map,
+) -> Option<Scalar> {
+    if let FlatSet::Elem(value) = state.get_idx(place, map) && value.try_to_int().is_ok() {
+        // Do not attempt to propagate pointers, as we may fail to preserve their identity.
+        Some(value)
+    } else {
+        None
+    }
+}
+
+#[instrument(level = "trace", skip(ecx, state, map))]
+fn try_write_constant<'tcx>(
+    ecx: &mut InterpCx<'_, 'tcx, DummyMachine>,
+    dest: &PlaceTy<'tcx>,
+    place: PlaceIndex,
+    ty: Ty<'tcx>,
+    state: &State<FlatSet<Scalar>>,
+    map: &Map,
+) -> InterpResult<'tcx> {
+    let layout = ecx.layout_of(ty)?;
+
+    // Fast path for ZSTs.
+    if layout.is_zst() {
+        return Ok(());
+    }
+
+    // Fast path for scalars.
+    if layout.abi.is_scalar()
+        && let Some(value) = propagatable_scalar(place, state, map)
+    {
+        return ecx.write_immediate(Immediate::Scalar(value), dest);
+    }
+
+    match ty.kind() {
+        // ZSTs. Nothing to do.
+        ty::FnDef(..) => {}
+
+        // Those are scalars, must be handled above.
+        ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => throw_machine_stop_str!("primitive type with provenance"),
+
+        ty::Tuple(elem_tys) => {
+            for (i, elem) in elem_tys.iter().enumerate() {
+                let Some(field) = map.apply(place, TrackElem::Field(FieldIdx::from_usize(i))) else {
+                    throw_machine_stop_str!("missing field in tuple")
+                };
+                let field_dest = ecx.project_field(dest, i)?;
+                try_write_constant(ecx, &field_dest, field, elem, state, map)?;
+            }
+        }
+
+        ty::Adt(def, args) => {
+            if def.is_union() {
+                throw_machine_stop_str!("cannot propagate unions")
+            }
+
+            let (variant_idx, variant_def, variant_place, variant_dest) = if def.is_enum() {
+                let Some(discr) = map.apply(place, TrackElem::Discriminant) else {
+                    throw_machine_stop_str!("missing discriminant for enum")
+                };
+                let FlatSet::Elem(Scalar::Int(discr)) = state.get_idx(discr, map) else {
+                    throw_machine_stop_str!("discriminant with provenance")
+                };
+                let discr_bits = discr.assert_bits(discr.size());
+                let Some((variant, _)) = def.discriminants(*ecx.tcx).find(|(_, var)| discr_bits == var.val) else {
+                    throw_machine_stop_str!("illegal discriminant for enum")
+                };
+                let Some(variant_place) = map.apply(place, TrackElem::Variant(variant)) else {
+                    throw_machine_stop_str!("missing variant for enum")
+                };
+                let variant_dest = ecx.project_downcast(dest, variant)?;
+                (variant, def.variant(variant), variant_place, variant_dest)
+            } else {
+                (FIRST_VARIANT, def.non_enum_variant(), place, dest.clone())
+            };
+
+            for (i, field) in variant_def.fields.iter_enumerated() {
+                let ty = field.ty(*ecx.tcx, args);
+                let Some(field) = map.apply(variant_place, TrackElem::Field(i)) else {
+                    throw_machine_stop_str!("missing field in ADT")
+                };
+                let field_dest = ecx.project_field(&variant_dest, i.as_usize())?;
+                try_write_constant(ecx, &field_dest, field, ty, state, map)?;
+            }
+            ecx.write_discriminant(variant_idx, dest)?;
+        }
+
+        // Unsupported for now.
+        ty::Array(_, _)
+
+        // Do not attempt to support indirection in constants.
+        | ty::Ref(..) | ty::RawPtr(..) | ty::FnPtr(..) | ty::Str | ty::Slice(_)
+
+        | ty::Never
+        | ty::Foreign(..)
+        | ty::Alias(..)
+        | ty::Param(_)
+        | ty::Bound(..)
+        | ty::Placeholder(..)
+        | ty::Closure(..)
+        | ty::Coroutine(..)
+        | ty::Dynamic(..) => throw_machine_stop_str!("unsupported type"),
+
+        ty::Error(_) | ty::Infer(..) | ty::CoroutineWitness(..) => bug!(),
     }
+
+    Ok(())
 }
 
 impl<'mir, 'tcx>
@@ -580,8 +716,13 @@ impl<'mir, 'tcx>
     ) {
         match &statement.kind {
             StatementKind::Assign(box (_, rvalue)) => {
-                OperandCollector { state, visitor: self, map: &results.analysis.0.map }
-                    .visit_rvalue(rvalue, location);
+                OperandCollector {
+                    state,
+                    visitor: self,
+                    ecx: &mut results.analysis.0.ecx,
+                    map: &results.analysis.0.map,
+                }
+                .visit_rvalue(rvalue, location);
             }
             _ => (),
         }
@@ -599,7 +740,12 @@ impl<'mir, 'tcx>
                 // Don't overwrite the assignment if it already uses a constant (to keep the span).
             }
             StatementKind::Assign(box (place, _)) => {
-                if let Some(value) = self.try_make_constant(place, state, &results.analysis.0.map) {
+                if let Some(value) = self.try_make_constant(
+                    &mut results.analysis.0.ecx,
+                    place,
+                    state,
+                    &results.analysis.0.map,
+                ) {
                     self.patch.assignments.insert(location, value);
                 }
             }
@@ -614,8 +760,13 @@ impl<'mir, 'tcx>
         terminator: &'mir Terminator<'tcx>,
         location: Location,
     ) {
-        OperandCollector { state, visitor: self, map: &results.analysis.0.map }
-            .visit_terminator(terminator, location);
+        OperandCollector {
+            state,
+            visitor: self,
+            ecx: &mut results.analysis.0.ecx,
+            map: &results.analysis.0.map,
+        }
+        .visit_terminator(terminator, location);
     }
 }
 
@@ -670,6 +821,7 @@ impl<'tcx> MutVisitor<'tcx> for Patch<'tcx> {
 struct OperandCollector<'tcx, 'map, 'locals, 'a> {
     state: &'a State<FlatSet<Scalar>>,
     visitor: &'a mut Collector<'tcx, 'locals>,
+    ecx: &'map mut InterpCx<'tcx, 'tcx, DummyMachine>,
     map: &'map Map,
 }
 
@@ -682,7 +834,7 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> {
         location: Location,
     ) {
         if let PlaceElem::Index(local) = elem
-            && let Some(value) = self.visitor.try_make_constant(local.into(), self.state, self.map)
+            && let Some(value) = self.visitor.try_make_constant(self.ecx, local.into(), self.state, self.map)
         {
             self.visitor.patch.before_effect.insert((location, local.into()), value);
         }
@@ -690,7 +842,9 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> {
 
     fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) {
         if let Some(place) = operand.place() {
-            if let Some(value) = self.visitor.try_make_constant(place, self.state, self.map) {
+            if let Some(value) =
+                self.visitor.try_make_constant(self.ecx, place, self.state, self.map)
+            {
                 self.visitor.patch.before_effect.insert((location, place), value);
             } else if !place.projection.is_empty() {
                 // Try to propagate into `Index` projections.
@@ -713,7 +867,7 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
     }
 
     fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool {
-        unimplemented!()
+        false
     }
 
     fn before_access_global(
@@ -725,13 +879,13 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
         is_write: bool,
     ) -> InterpResult<'tcx> {
         if is_write {
-            crate::const_prop::throw_machine_stop_str!("can't write to global");
+            throw_machine_stop_str!("can't write to global");
         }
 
         // If the static allocation is mutable, then we can't const prop it as its content
         // might be different at runtime.
         if alloc.inner().mutability.is_mut() {
-            crate::const_prop::throw_machine_stop_str!("can't access mutable globals in ConstProp");
+            throw_machine_stop_str!("can't access mutable globals in ConstProp");
         }
 
         Ok(())
@@ -781,7 +935,7 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
         _left: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>,
         _right: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>,
     ) -> interpret::InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)> {
-        crate::const_prop::throw_machine_stop_str!("can't do pointer arithmetic");
+        throw_machine_stop_str!("can't do pointer arithmetic");
     }
 
     fn expose_ptr(