diff options
| author | Camille GILLOT <gillot.camille@gmail.com> | 2023-09-17 09:35:06 +0000 |
|---|---|---|
| committer | Camille GILLOT <gillot.camille@gmail.com> | 2023-10-21 16:26:05 +0000 |
| commit | 8c1b039d482149f643a88c6d7af526b18f56dd8f (patch) | |
| tree | 92c781f404a4632691b9b226be17f833c0b6cb44 /compiler/rustc_mir_transform/src | |
| parent | 31d101093c134e69a31ba58893e56647a5831299 (diff) | |
| download | rust-8c1b039d482149f643a88c6d7af526b18f56dd8f.tar.gz rust-8c1b039d482149f643a88c6d7af526b18f56dd8f.zip | |
Use a ConstValue instead.
Diffstat (limited to 'compiler/rustc_mir_transform/src')
| -rw-r--r-- | compiler/rustc_mir_transform/src/dataflow_const_prop.rs | 255 |
1 files changed, 159 insertions, 96 deletions
diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs index 762f57ad29e..2c29978173f 100644 --- a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs +++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs @@ -2,13 +2,13 @@ //! //! Currently, this pass only propagates scalar values. -use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable}; +use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, PlaceTy, Projectable}; use rustc_data_structures::fx::FxHashMap; use rustc_hir::def::DefKind; use rustc_middle::mir::interpret::{AllocId, ConstAllocation, InterpResult, Scalar}; use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; use rustc_middle::mir::*; -use rustc_middle::ty::layout::TyAndLayout; +use rustc_middle::ty::layout::{LayoutOf, TyAndLayout}; use rustc_middle::ty::{self, Ty, TyCtxt}; use rustc_mir_dataflow::value_analysis::{ Map, PlaceIndex, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace, @@ -16,8 +16,9 @@ use rustc_mir_dataflow::value_analysis::{ use rustc_mir_dataflow::{lattice::FlatSet, Analysis, Results, ResultsVisitor}; use rustc_span::def_id::DefId; use rustc_span::DUMMY_SP; -use rustc_target::abi::{FieldIdx, VariantIdx}; +use rustc_target::abi::{Abi, FieldIdx, Size, VariantIdx, FIRST_VARIANT}; +use crate::const_prop::throw_machine_stop_str; use crate::MirPass; // These constants are somewhat random guesses and have not been optimized. @@ -553,107 +554,151 @@ impl<'tcx, 'locals> Collector<'tcx, 'locals> { fn try_make_constant( &self, + ecx: &mut InterpCx<'tcx, 'tcx, DummyMachine>, place: Place<'tcx>, state: &State<FlatSet<Scalar>>, map: &Map, ) -> Option<Const<'tcx>> { let ty = place.ty(self.local_decls, self.patch.tcx).ty; + let layout = ecx.layout_of(ty).ok()?; + + if layout.is_zst() { + return Some(Const::zero_sized(ty)); + } + + if layout.is_unsized() { + return None; + } + let place = map.find(place.as_ref())?; - if let FlatSet::Elem(Scalar::Int(value)) = state.get_idx(place, map) { - Some(Const::Val(ConstValue::Scalar(value.into()), ty)) - } else { - let valtree = self.try_make_valtree(place, ty, state, map)?; - let constant = ty::Const::new_value(self.patch.tcx, valtree, ty); - Some(Const::Ty(constant)) + if layout.abi.is_scalar() + && let Some(value) = propagatable_scalar(place, state, map) + { + return Some(Const::Val(ConstValue::Scalar(value), ty)); + } + + if matches!(layout.abi, Abi::Scalar(..) | Abi::ScalarPair(..)) { + let alloc_id = ecx + .intern_with_temp_alloc(layout, |ecx, dest| { + try_write_constant(ecx, dest, place, ty, state, map) + }) + .ok()?; + return Some(Const::Val(ConstValue::Indirect { alloc_id, offset: Size::ZERO }, ty)); } + + None } +} - fn try_make_valtree( - &self, - place: PlaceIndex, - ty: Ty<'tcx>, - state: &State<FlatSet<Scalar>>, - map: &Map, - ) -> Option<ty::ValTree<'tcx>> { - let tcx = self.patch.tcx; - match ty.kind() { - // ZSTs. - ty::FnDef(..) => Some(ty::ValTree::zst()), - - // Scalars. - ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => { - if let FlatSet::Elem(Scalar::Int(value)) = state.get_idx(place, map) { - Some(ty::ValTree::Leaf(value)) - } else { - None - } - } +fn propagatable_scalar( + place: PlaceIndex, + state: &State<FlatSet<Scalar>>, + map: &Map, +) -> Option<Scalar> { + if let FlatSet::Elem(value) = state.get_idx(place, map) && value.try_to_int().is_ok() { + // Do not attempt to propagate pointers, as we may fail to preserve their identity. + Some(value) + } else { + None + } +} - // Unsupported for now. - ty::Array(_, _) => None, - - ty::Tuple(elem_tys) => { - let branches = elem_tys - .iter() - .enumerate() - .map(|(i, ty)| { - let field = map.apply(place, TrackElem::Field(FieldIdx::from_usize(i)))?; - self.try_make_valtree(field, ty, state, map) - }) - .collect::<Option<Vec<_>>>()?; - Some(ty::ValTree::Branch(tcx.arena.alloc_from_iter(branches.into_iter()))) - } +#[instrument(level = "trace", skip(ecx, state, map))] +fn try_write_constant<'tcx>( + ecx: &mut InterpCx<'_, 'tcx, DummyMachine>, + dest: &PlaceTy<'tcx>, + place: PlaceIndex, + ty: Ty<'tcx>, + state: &State<FlatSet<Scalar>>, + map: &Map, +) -> InterpResult<'tcx> { + let layout = ecx.layout_of(ty)?; + + // Fast path for ZSTs. + if layout.is_zst() { + return Ok(()); + } + + // Fast path for scalars. + if layout.abi.is_scalar() + && let Some(value) = propagatable_scalar(place, state, map) + { + return ecx.write_immediate(Immediate::Scalar(value), dest); + } - ty::Adt(def, args) => { - if def.is_union() { - return None; - } + match ty.kind() { + // ZSTs. Nothing to do. + ty::FnDef(..) => {} - let (variant_idx, variant_def, variant_place) = if def.is_enum() { - let discr = map.apply(place, TrackElem::Discriminant)?; - let FlatSet::Elem(Scalar::Int(discr)) = state.get_idx(discr, map) else { - return None; - }; - let discr_bits = discr.assert_bits(discr.size()); - let (variant, _) = - def.discriminants(tcx).find(|(_, var)| discr_bits == var.val)?; - let variant_place = map.apply(place, TrackElem::Variant(variant))?; - let variant_int = ty::ValTree::Leaf(variant.as_u32().into()); - (Some(variant_int), def.variant(variant), variant_place) - } else { - (None, def.non_enum_variant(), place) + // Those are scalars, must be handled above. + ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => throw_machine_stop_str!("primitive type with provenance"), + + ty::Tuple(elem_tys) => { + for (i, elem) in elem_tys.iter().enumerate() { + let Some(field) = map.apply(place, TrackElem::Field(FieldIdx::from_usize(i))) else { + throw_machine_stop_str!("missing field in tuple") }; + let field_dest = ecx.project_field(dest, i)?; + try_write_constant(ecx, &field_dest, field, elem, state, map)?; + } + } - let branches = variant_def - .fields - .iter_enumerated() - .map(|(i, field)| { - let ty = field.ty(tcx, args); - let field = map.apply(variant_place, TrackElem::Field(i))?; - self.try_make_valtree(field, ty, state, map) - }) - .collect::<Option<Vec<_>>>()?; - Some(ty::ValTree::Branch( - tcx.arena.alloc_from_iter(variant_idx.into_iter().chain(branches)), - )) + ty::Adt(def, args) => { + if def.is_union() { + throw_machine_stop_str!("cannot propagate unions") } - // Do not attempt to support indirection in constants. - ty::Ref(..) | ty::RawPtr(..) | ty::FnPtr(..) | ty::Str | ty::Slice(_) => None, + let (variant_idx, variant_def, variant_place, variant_dest) = if def.is_enum() { + let Some(discr) = map.apply(place, TrackElem::Discriminant) else { + throw_machine_stop_str!("missing discriminant for enum") + }; + let FlatSet::Elem(Scalar::Int(discr)) = state.get_idx(discr, map) else { + throw_machine_stop_str!("discriminant with provenance") + }; + let discr_bits = discr.assert_bits(discr.size()); + let Some((variant, _)) = def.discriminants(*ecx.tcx).find(|(_, var)| discr_bits == var.val) else { + throw_machine_stop_str!("illegal discriminant for enum") + }; + let Some(variant_place) = map.apply(place, TrackElem::Variant(variant)) else { + throw_machine_stop_str!("missing variant for enum") + }; + let variant_dest = ecx.project_downcast(dest, variant)?; + (variant, def.variant(variant), variant_place, variant_dest) + } else { + (FIRST_VARIANT, def.non_enum_variant(), place, dest.clone()) + }; + + for (i, field) in variant_def.fields.iter_enumerated() { + let ty = field.ty(*ecx.tcx, args); + let Some(field) = map.apply(variant_place, TrackElem::Field(i)) else { + throw_machine_stop_str!("missing field in ADT") + }; + let field_dest = ecx.project_field(&variant_dest, i.as_usize())?; + try_write_constant(ecx, &field_dest, field, ty, state, map)?; + } + ecx.write_discriminant(variant_idx, dest)?; + } - ty::Never - | ty::Foreign(..) - | ty::Alias(..) - | ty::Param(_) - | ty::Bound(..) - | ty::Placeholder(..) - | ty::Closure(..) - | ty::Coroutine(..) - | ty::Dynamic(..) => None, + // Unsupported for now. + ty::Array(_, _) - ty::Error(_) | ty::Infer(..) | ty::CoroutineWitness(..) => bug!(), - } + // Do not attempt to support indirection in constants. + | ty::Ref(..) | ty::RawPtr(..) | ty::FnPtr(..) | ty::Str | ty::Slice(_) + + | ty::Never + | ty::Foreign(..) + | ty::Alias(..) + | ty::Param(_) + | ty::Bound(..) + | ty::Placeholder(..) + | ty::Closure(..) + | ty::Coroutine(..) + | ty::Dynamic(..) => throw_machine_stop_str!("unsupported type"), + + ty::Error(_) | ty::Infer(..) | ty::CoroutineWitness(..) => bug!(), } + + Ok(()) } impl<'mir, 'tcx> @@ -671,8 +716,13 @@ impl<'mir, 'tcx> ) { match &statement.kind { StatementKind::Assign(box (_, rvalue)) => { - OperandCollector { state, visitor: self, map: &results.analysis.0.map } - .visit_rvalue(rvalue, location); + OperandCollector { + state, + visitor: self, + ecx: &mut results.analysis.0.ecx, + map: &results.analysis.0.map, + } + .visit_rvalue(rvalue, location); } _ => (), } @@ -690,7 +740,12 @@ impl<'mir, 'tcx> // Don't overwrite the assignment if it already uses a constant (to keep the span). } StatementKind::Assign(box (place, _)) => { - if let Some(value) = self.try_make_constant(place, state, &results.analysis.0.map) { + if let Some(value) = self.try_make_constant( + &mut results.analysis.0.ecx, + place, + state, + &results.analysis.0.map, + ) { self.patch.assignments.insert(location, value); } } @@ -705,8 +760,13 @@ impl<'mir, 'tcx> terminator: &'mir Terminator<'tcx>, location: Location, ) { - OperandCollector { state, visitor: self, map: &results.analysis.0.map } - .visit_terminator(terminator, location); + OperandCollector { + state, + visitor: self, + ecx: &mut results.analysis.0.ecx, + map: &results.analysis.0.map, + } + .visit_terminator(terminator, location); } } @@ -761,6 +821,7 @@ impl<'tcx> MutVisitor<'tcx> for Patch<'tcx> { struct OperandCollector<'tcx, 'map, 'locals, 'a> { state: &'a State<FlatSet<Scalar>>, visitor: &'a mut Collector<'tcx, 'locals>, + ecx: &'map mut InterpCx<'tcx, 'tcx, DummyMachine>, map: &'map Map, } @@ -773,7 +834,7 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> { location: Location, ) { if let PlaceElem::Index(local) = elem - && let Some(value) = self.visitor.try_make_constant(local.into(), self.state, self.map) + && let Some(value) = self.visitor.try_make_constant(self.ecx, local.into(), self.state, self.map) { self.visitor.patch.before_effect.insert((location, local.into()), value); } @@ -781,7 +842,9 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> { fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) { if let Some(place) = operand.place() { - if let Some(value) = self.visitor.try_make_constant(place, self.state, self.map) { + if let Some(value) = + self.visitor.try_make_constant(self.ecx, place, self.state, self.map) + { self.visitor.patch.before_effect.insert((location, place), value); } else if !place.projection.is_empty() { // Try to propagate into `Index` projections. @@ -804,7 +867,7 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm } fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool { - unimplemented!() + false } fn before_access_global( @@ -816,13 +879,13 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm is_write: bool, ) -> InterpResult<'tcx> { if is_write { - crate::const_prop::throw_machine_stop_str!("can't write to global"); + throw_machine_stop_str!("can't write to global"); } // If the static allocation is mutable, then we can't const prop it as its content // might be different at runtime. if alloc.inner().mutability.is_mut() { - crate::const_prop::throw_machine_stop_str!("can't access mutable globals in ConstProp"); + throw_machine_stop_str!("can't access mutable globals in ConstProp"); } Ok(()) @@ -872,7 +935,7 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm _left: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>, _right: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>, ) -> interpret::InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)> { - crate::const_prop::throw_machine_stop_str!("can't do pointer arithmetic"); + throw_machine_stop_str!("can't do pointer arithmetic"); } fn expose_ptr( |
