diff options
Diffstat (limited to 'compiler/rustc_mir_transform')
| -rw-r--r-- | compiler/rustc_mir_transform/src/dataflow_const_prop.rs | 333 | 
1 files changed, 274 insertions, 59 deletions
| diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs index c084b20772f..b8e35587b87 100644 --- a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs +++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs @@ -2,6 +2,9 @@ //! //! Currently, this pass only propagates scalar values. +use std::assert_matches::assert_matches; +use std::fmt::Formatter; + use rustc_abi::{BackendRepr, FIRST_VARIANT, FieldIdx, Size, VariantIdx}; use rustc_const_eval::const_eval::{DummyMachine, throw_machine_stop_str}; use rustc_const_eval::interpret::{ @@ -15,9 +18,10 @@ use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; use rustc_middle::mir::*; use rustc_middle::ty::layout::{HasParamEnv, LayoutOf}; use rustc_middle::ty::{self, Ty, TyCtxt}; -use rustc_mir_dataflow::lattice::FlatSet; +use rustc_mir_dataflow::fmt::DebugWithContext; +use rustc_mir_dataflow::lattice::{FlatSet, HasBottom}; use rustc_mir_dataflow::value_analysis::{ - Map, PlaceIndex, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace, + Map, PlaceIndex, State, TrackElem, ValueOrPlace, debug_with_context, }; use rustc_mir_dataflow::{Analysis, Results, ResultsVisitor}; use rustc_span::DUMMY_SP; @@ -58,8 +62,8 @@ impl<'tcx> crate::MirPass<'tcx> for DataflowConstProp { // Perform the actual dataflow analysis. let analysis = ConstAnalysis::new(tcx, body, map); - let mut results = debug_span!("analyze") - .in_scope(|| analysis.wrap().iterate_to_fixpoint(tcx, body, None)); + let mut results = + debug_span!("analyze").in_scope(|| analysis.iterate_to_fixpoint(tcx, body, None)); // Collect results and patch the body afterwards. let mut visitor = Collector::new(tcx, &body.local_decls); @@ -69,6 +73,10 @@ impl<'tcx> crate::MirPass<'tcx> for DataflowConstProp { } } +// Note: Currently, places that have their reference taken cannot be tracked. Although this would +// be possible, it has to rely on some aliasing model, which we are not ready to commit to yet. +// Because of that, we can assume that the only way to change the value behind a tracked place is +// by direct assignment. struct ConstAnalysis<'a, 'tcx> { map: Map<'tcx>, tcx: TyCtxt<'tcx>, @@ -77,20 +85,198 @@ struct ConstAnalysis<'a, 'tcx> { param_env: ty::ParamEnv<'tcx>, } -impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { - type Value = FlatSet<Scalar>; +impl<'tcx> Analysis<'tcx> for ConstAnalysis<'_, 'tcx> { + type Domain = State<FlatSet<Scalar>>; const NAME: &'static str = "ConstAnalysis"; - fn map(&self) -> &Map<'tcx> { - &self.map + // The bottom state denotes uninitialized memory. Because we are only doing a sound + // approximation of the actual execution, we can also use this state for places where access + // would be UB. + fn bottom_value(&self, _body: &Body<'tcx>) -> Self::Domain { + State::Unreachable + } + + fn initialize_start_block(&self, body: &Body<'tcx>, state: &mut Self::Domain) { + // The initial state maps all tracked places of argument projections to ⊤ and the rest to ⊥. + assert_matches!(state, State::Unreachable); + *state = State::new_reachable(); + for arg in body.args_iter() { + state.flood(PlaceRef { local: arg, projection: &[] }, &self.map); + } + } + + fn apply_statement_effect( + &mut self, + state: &mut Self::Domain, + statement: &Statement<'tcx>, + _location: Location, + ) { + if state.is_reachable() { + self.handle_statement(statement, state); + } + } + + fn apply_terminator_effect<'mir>( + &mut self, + state: &mut Self::Domain, + terminator: &'mir Terminator<'tcx>, + _location: Location, + ) -> TerminatorEdges<'mir, 'tcx> { + if state.is_reachable() { + self.handle_terminator(terminator, state) + } else { + TerminatorEdges::None + } + } + + fn apply_call_return_effect( + &mut self, + state: &mut Self::Domain, + _block: BasicBlock, + return_places: CallReturnPlaces<'_, 'tcx>, + ) { + if state.is_reachable() { + self.handle_call_return(return_places, state) + } + } +} + +impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> { + fn new(tcx: TyCtxt<'tcx>, body: &'a Body<'tcx>, map: Map<'tcx>) -> Self { + let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); + Self { + map, + tcx, + local_decls: &body.local_decls, + ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine), + param_env, + } + } + + fn handle_statement(&self, statement: &Statement<'tcx>, state: &mut State<FlatSet<Scalar>>) { + match &statement.kind { + StatementKind::Assign(box (place, rvalue)) => { + self.handle_assign(*place, rvalue, state); + } + StatementKind::SetDiscriminant { box place, variant_index } => { + self.handle_set_discriminant(*place, *variant_index, state); + } + StatementKind::Intrinsic(box intrinsic) => { + self.handle_intrinsic(intrinsic); + } + StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => { + // StorageLive leaves the local in an uninitialized state. + // StorageDead makes it UB to access the local afterwards. + state.flood_with( + Place::from(*local).as_ref(), + &self.map, + FlatSet::<Scalar>::BOTTOM, + ); + } + StatementKind::Deinit(box place) => { + // Deinit makes the place uninitialized. + state.flood_with(place.as_ref(), &self.map, FlatSet::<Scalar>::BOTTOM); + } + StatementKind::Retag(..) => { + // We don't track references. + } + StatementKind::ConstEvalCounter + | StatementKind::Nop + | StatementKind::FakeRead(..) + | StatementKind::PlaceMention(..) + | StatementKind::Coverage(..) + | StatementKind::AscribeUserType(..) => (), + } + } + + fn handle_intrinsic(&self, intrinsic: &NonDivergingIntrinsic<'tcx>) { + match intrinsic { + NonDivergingIntrinsic::Assume(..) => { + // Could use this, but ignoring it is sound. + } + NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping { + dst: _, + src: _, + count: _, + }) => { + // This statement represents `*dst = *src`, `count` times. + } + } + } + + fn handle_operand( + &self, + operand: &Operand<'tcx>, + state: &mut State<FlatSet<Scalar>>, + ) -> ValueOrPlace<FlatSet<Scalar>> { + match operand { + Operand::Constant(box constant) => { + ValueOrPlace::Value(self.handle_constant(constant, state)) + } + Operand::Copy(place) | Operand::Move(place) => { + // On move, we would ideally flood the place with bottom. But with the current + // framework this is not possible (similar to `InterpCx::eval_operand`). + self.map.find(place.as_ref()).map(ValueOrPlace::Place).unwrap_or(ValueOrPlace::TOP) + } + } + } + + /// The effect of a successful function call return should not be + /// applied here, see [`Analysis::apply_terminator_effect`]. + fn handle_terminator<'mir>( + &self, + terminator: &'mir Terminator<'tcx>, + state: &mut State<FlatSet<Scalar>>, + ) -> TerminatorEdges<'mir, 'tcx> { + match &terminator.kind { + TerminatorKind::Call { .. } | TerminatorKind::InlineAsm { .. } => { + // Effect is applied by `handle_call_return`. + } + TerminatorKind::Drop { place, .. } => { + state.flood_with(place.as_ref(), &self.map, FlatSet::<Scalar>::BOTTOM); + } + TerminatorKind::Yield { .. } => { + // They would have an effect, but are not allowed in this phase. + bug!("encountered disallowed terminator"); + } + TerminatorKind::SwitchInt { discr, targets } => { + return self.handle_switch_int(discr, targets, state); + } + TerminatorKind::TailCall { .. } => { + // FIXME(explicit_tail_calls): determine if we need to do something here (probably + // not) + } + TerminatorKind::Goto { .. } + | TerminatorKind::UnwindResume + | TerminatorKind::UnwindTerminate(_) + | TerminatorKind::Return + | TerminatorKind::Unreachable + | TerminatorKind::Assert { .. } + | TerminatorKind::CoroutineDrop + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::FalseUnwind { .. } => { + // These terminators have no effect on the analysis. + } + } + terminator.edges() + } + + fn handle_call_return( + &self, + return_places: CallReturnPlaces<'_, 'tcx>, + state: &mut State<FlatSet<Scalar>>, + ) { + return_places.for_each(|place| { + state.flood(place.as_ref(), &self.map); + }) } fn handle_set_discriminant( &self, place: Place<'tcx>, variant_index: VariantIdx, - state: &mut State<Self::Value>, + state: &mut State<FlatSet<Scalar>>, ) { state.flood_discr(place.as_ref(), &self.map); if self.map.find_discr(place.as_ref()).is_some() { @@ -109,17 +295,17 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { &self, target: Place<'tcx>, rvalue: &Rvalue<'tcx>, - state: &mut State<Self::Value>, + state: &mut State<FlatSet<Scalar>>, ) { match rvalue { Rvalue::Use(operand) => { - state.flood(target.as_ref(), self.map()); + state.flood(target.as_ref(), &self.map); if let Some(target) = self.map.find(target.as_ref()) { self.assign_operand(state, target, operand); } } Rvalue::CopyForDeref(rhs) => { - state.flood(target.as_ref(), self.map()); + state.flood(target.as_ref(), &self.map); if let Some(target) = self.map.find(target.as_ref()) { self.assign_operand(state, target, &Operand::Copy(*rhs)); } @@ -127,9 +313,9 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { Rvalue::Aggregate(kind, operands) => { // If we assign `target = Enum::Variant#0(operand)`, // we must make sure that all `target as Variant#i` are `Top`. - state.flood(target.as_ref(), self.map()); + state.flood(target.as_ref(), &self.map); - let Some(target_idx) = self.map().find(target.as_ref()) else { return }; + let Some(target_idx) = self.map.find(target.as_ref()) else { return }; let (variant_target, variant_index) = match **kind { AggregateKind::Tuple | AggregateKind::Closure(..) => (Some(target_idx), None), @@ -148,14 +334,14 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { if let Some(variant_target_idx) = variant_target { for (field_index, operand) in operands.iter_enumerated() { if let Some(field) = - self.map().apply(variant_target_idx, TrackElem::Field(field_index)) + self.map.apply(variant_target_idx, TrackElem::Field(field_index)) { self.assign_operand(state, field, operand); } } } if let Some(variant_index) = variant_index - && let Some(discr_idx) = self.map().apply(target_idx, TrackElem::Discriminant) + && let Some(discr_idx) = self.map.apply(target_idx, TrackElem::Discriminant) { // We are assigning the discriminant as part of an aggregate. // This discriminant can only alias a variant field's value if the operand @@ -170,23 +356,23 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { } Rvalue::BinaryOp(op, box (left, right)) if op.is_overflowing() => { // Flood everything now, so we can use `insert_value_idx` directly later. - state.flood(target.as_ref(), self.map()); + state.flood(target.as_ref(), &self.map); - let Some(target) = self.map().find(target.as_ref()) else { return }; + let Some(target) = self.map.find(target.as_ref()) else { return }; - let value_target = self.map().apply(target, TrackElem::Field(0_u32.into())); - let overflow_target = self.map().apply(target, TrackElem::Field(1_u32.into())); + let value_target = self.map.apply(target, TrackElem::Field(0_u32.into())); + let overflow_target = self.map.apply(target, TrackElem::Field(1_u32.into())); if value_target.is_some() || overflow_target.is_some() { let (val, overflow) = self.binary_op(state, *op, left, right); if let Some(value_target) = value_target { // We have flooded `target` earlier. - state.insert_value_idx(value_target, val, self.map()); + state.insert_value_idx(value_target, val, &self.map); } if let Some(overflow_target) = overflow_target { // We have flooded `target` earlier. - state.insert_value_idx(overflow_target, overflow, self.map()); + state.insert_value_idx(overflow_target, overflow, &self.map); } } } @@ -196,27 +382,30 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { _, ) => { let pointer = self.handle_operand(operand, state); - state.assign(target.as_ref(), pointer, self.map()); + state.assign(target.as_ref(), pointer, &self.map); - if let Some(target_len) = self.map().find_len(target.as_ref()) + if let Some(target_len) = self.map.find_len(target.as_ref()) && let operand_ty = operand.ty(self.local_decls, self.tcx) && let Some(operand_ty) = operand_ty.builtin_deref(true) && let ty::Array(_, len) = operand_ty.kind() && let Some(len) = Const::Ty(self.tcx.types.usize, *len) .try_eval_scalar_int(self.tcx, self.param_env) { - state.insert_value_idx(target_len, FlatSet::Elem(len.into()), self.map()); + state.insert_value_idx(target_len, FlatSet::Elem(len.into()), &self.map); } } - _ => self.super_assign(target, rvalue, state), + _ => { + let result = self.handle_rvalue(rvalue, state); + state.assign(target.as_ref(), result, &self.map); + } } } fn handle_rvalue( &self, rvalue: &Rvalue<'tcx>, - state: &mut State<Self::Value>, - ) -> ValueOrPlace<Self::Value> { + state: &mut State<FlatSet<Scalar>>, + ) -> ValueOrPlace<FlatSet<Scalar>> { let val = match rvalue { Rvalue::Len(place) => { let place_ty = place.ty(self.local_decls, self.tcx); @@ -225,7 +414,7 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { .try_eval_scalar(self.tcx, self.param_env) .map_or(FlatSet::Top, FlatSet::Elem) } else if let [ProjectionElem::Deref] = place.projection[..] { - state.get_len(place.local.into(), self.map()) + state.get_len(place.local.into(), &self.map) } else { FlatSet::Top } @@ -296,8 +485,24 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { }; FlatSet::Elem(Scalar::from_target_usize(val, &self.tcx)) } - Rvalue::Discriminant(place) => state.get_discr(place.as_ref(), self.map()), - _ => return self.super_rvalue(rvalue, state), + Rvalue::Discriminant(place) => state.get_discr(place.as_ref(), &self.map), + Rvalue::Use(operand) => return self.handle_operand(operand, state), + Rvalue::CopyForDeref(place) => { + return self.handle_operand(&Operand::Copy(*place), state); + } + Rvalue::Ref(..) | Rvalue::RawPtr(..) => { + // We don't track such places. + return ValueOrPlace::TOP; + } + Rvalue::Repeat(..) + | Rvalue::ThreadLocalRef(..) + | Rvalue::Cast(..) + | Rvalue::BinaryOp(..) + | Rvalue::Aggregate(..) + | Rvalue::ShallowInitBox(..) => { + // No modification is possible through these r-values. + return ValueOrPlace::TOP; + } }; ValueOrPlace::Value(val) } @@ -305,8 +510,8 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { fn handle_constant( &self, constant: &ConstOperand<'tcx>, - _state: &mut State<Self::Value>, - ) -> Self::Value { + _state: &mut State<FlatSet<Scalar>>, + ) -> FlatSet<Scalar> { constant .const_ .try_eval_scalar(self.tcx, self.param_env) @@ -317,11 +522,11 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { &self, discr: &'mir Operand<'tcx>, targets: &'mir SwitchTargets, - state: &mut State<Self::Value>, + state: &mut State<FlatSet<Scalar>>, ) -> TerminatorEdges<'mir, 'tcx> { let value = match self.handle_operand(discr, state) { ValueOrPlace::Value(value) => value, - ValueOrPlace::Place(place) => state.get_idx(place, self.map()), + ValueOrPlace::Place(place) => state.get_idx(place, &self.map), }; match value { // We are branching on uninitialized data, this is UB, treat it as unreachable. @@ -334,19 +539,6 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { FlatSet::Top => TerminatorEdges::SwitchInt { discr, targets }, } } -} - -impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> { - fn new(tcx: TyCtxt<'tcx>, body: &'a Body<'tcx>, map: Map<'tcx>) -> Self { - let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); - Self { - map, - tcx, - local_decls: &body.local_decls, - ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine), - param_env, - } - } /// The caller must have flooded `place`. fn assign_operand( @@ -537,6 +729,30 @@ impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> { } } +/// This is used to visualize the dataflow analysis. +impl<'tcx> DebugWithContext<ConstAnalysis<'_, 'tcx>> for State<FlatSet<Scalar>> { + fn fmt_with(&self, ctxt: &ConstAnalysis<'_, 'tcx>, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + State::Reachable(values) => debug_with_context(values, None, &ctxt.map, f), + State::Unreachable => write!(f, "unreachable"), + } + } + + fn fmt_diff_with( + &self, + old: &Self, + ctxt: &ConstAnalysis<'_, 'tcx>, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + match (self, old) { + (State::Reachable(this), State::Reachable(old)) => { + debug_with_context(this, Some(old), &ctxt.map, f) + } + _ => Ok(()), // Consider printing something here. + } + } +} + pub(crate) struct Patch<'tcx> { tcx: TyCtxt<'tcx>, @@ -725,8 +941,7 @@ fn try_write_constant<'tcx>( interp_ok(()) } -impl<'mir, 'tcx> - ResultsVisitor<'mir, 'tcx, Results<'tcx, ValueAnalysisWrapper<ConstAnalysis<'_, 'tcx>>>> +impl<'mir, 'tcx> ResultsVisitor<'mir, 'tcx, Results<'tcx, ConstAnalysis<'_, 'tcx>>> for Collector<'_, 'tcx> { type Domain = State<FlatSet<Scalar>>; @@ -734,7 +949,7 @@ impl<'mir, 'tcx> #[instrument(level = "trace", skip(self, results, statement))] fn visit_statement_before_primary_effect( &mut self, - results: &mut Results<'tcx, ValueAnalysisWrapper<ConstAnalysis<'_, 'tcx>>>, + results: &mut Results<'tcx, ConstAnalysis<'_, 'tcx>>, state: &Self::Domain, statement: &'mir Statement<'tcx>, location: Location, @@ -744,8 +959,8 @@ impl<'mir, 'tcx> OperandCollector { state, visitor: self, - ecx: &mut results.analysis.0.ecx, - map: &results.analysis.0.map, + ecx: &mut results.analysis.ecx, + map: &results.analysis.map, } .visit_rvalue(rvalue, location); } @@ -756,7 +971,7 @@ impl<'mir, 'tcx> #[instrument(level = "trace", skip(self, results, statement))] fn visit_statement_after_primary_effect( &mut self, - results: &mut Results<'tcx, ValueAnalysisWrapper<ConstAnalysis<'_, 'tcx>>>, + results: &mut Results<'tcx, ConstAnalysis<'_, 'tcx>>, state: &Self::Domain, statement: &'mir Statement<'tcx>, location: Location, @@ -767,10 +982,10 @@ impl<'mir, 'tcx> } StatementKind::Assign(box (place, _)) => { if let Some(value) = self.try_make_constant( - &mut results.analysis.0.ecx, + &mut results.analysis.ecx, place, state, - &results.analysis.0.map, + &results.analysis.map, ) { self.patch.assignments.insert(location, value); } @@ -781,7 +996,7 @@ impl<'mir, 'tcx> fn visit_terminator_before_primary_effect( &mut self, - results: &mut Results<'tcx, ValueAnalysisWrapper<ConstAnalysis<'_, 'tcx>>>, + results: &mut Results<'tcx, ConstAnalysis<'_, 'tcx>>, state: &Self::Domain, terminator: &'mir Terminator<'tcx>, location: Location, @@ -789,8 +1004,8 @@ impl<'mir, 'tcx> OperandCollector { state, visitor: self, - ecx: &mut results.analysis.0.ecx, - map: &results.analysis.0.map, + ecx: &mut results.analysis.ecx, + map: &results.analysis.map, } .visit_terminator(terminator, location); } | 
