about summary refs log tree commit diff
diff options
context:
space:
mode:
authorCamille GILLOT <gillot.camille@gmail.com>2023-05-20 08:54:16 +0000
committerCamille GILLOT <gillot.camille@gmail.com>2024-01-08 22:42:07 +0000
commitcae0dc28332065dfa3fa168fdc1f2818bcbdf0c3 (patch)
treed234f063cf2d57c5afa8da73dc050a9ea85e328a
parent835680286207b633f4727310829a709b585f7b56 (diff)
downloadrust-cae0dc28332065dfa3fa168fdc1f2818bcbdf0c3.tar.gz
rust-cae0dc28332065dfa3fa168fdc1f2818bcbdf0c3.zip
Simplify code flow.
-rw-r--r--compiler/rustc_mir_transform/src/lib.rs1
-rw-r--r--compiler/rustc_mir_transform/src/promote_consts.rs378
2 files changed, 157 insertions, 222 deletions
diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs
index dfc4ff3b7a3..915ddcb2775 100644
--- a/compiler/rustc_mir_transform/src/lib.rs
+++ b/compiler/rustc_mir_transform/src/lib.rs
@@ -1,5 +1,6 @@
 #![deny(rustc::untranslatable_diagnostic)]
 #![deny(rustc::diagnostic_outside_of_impl)]
+#![feature(assert_matches)]
 #![feature(box_patterns)]
 #![feature(cow_is_borrowed)]
 #![feature(decl_macro)]
diff --git a/compiler/rustc_mir_transform/src/promote_consts.rs b/compiler/rustc_mir_transform/src/promote_consts.rs
index 9a60e83d322..9b9c9aae88b 100644
--- a/compiler/rustc_mir_transform/src/promote_consts.rs
+++ b/compiler/rustc_mir_transform/src/promote_consts.rs
@@ -12,6 +12,7 @@
 //! initialization and can otherwise silence errors, if
 //! move analysis runs after promotion on broken MIR.
 
+use either::{Left, Right};
 use rustc_hir as hir;
 use rustc_middle::mir;
 use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
@@ -22,6 +23,7 @@ use rustc_span::Span;
 
 use rustc_index::{Idx, IndexSlice, IndexVec};
 
+use std::assert_matches::assert_matches;
 use std::cell::Cell;
 use std::{cmp, iter, mem};
 
@@ -116,41 +118,38 @@ impl<'tcx> Visitor<'tcx> for Collector<'_, 'tcx> {
 
         let temp = &mut self.temps[index];
         debug!("visit_local: temp={:?}", temp);
-        if *temp == TempState::Undefined {
-            match context {
+        *temp = match *temp {
+            TempState::Undefined => match context {
                 PlaceContext::MutatingUse(MutatingUseContext::Store)
                 | PlaceContext::MutatingUse(MutatingUseContext::Call) => {
-                    *temp = TempState::Defined { location, uses: 0, valid: Err(()) };
+                    TempState::Defined { location, uses: 0, valid: Err(()) }
+                }
+                _ => TempState::Unpromotable,
+            },
+            TempState::Defined { ref mut uses, .. } => {
+                // We always allow borrows, even mutable ones, as we need
+                // to promote mutable borrows of some ZSTs e.g., `&mut []`.
+                let allowed_use = match context {
+                    PlaceContext::MutatingUse(MutatingUseContext::Borrow)
+                    | PlaceContext::NonMutatingUse(_) => true,
+                    PlaceContext::MutatingUse(_) | PlaceContext::NonUse(_) => false,
+                };
+                debug!("visit_local: allowed_use={:?}", allowed_use);
+                if allowed_use {
+                    *uses += 1;
                     return;
                 }
-                _ => { /* mark as unpromotable below */ }
+                TempState::Unpromotable
             }
-        } else if let TempState::Defined { uses, .. } = temp {
-            // We always allow borrows, even mutable ones, as we need
-            // to promote mutable borrows of some ZSTs e.g., `&mut []`.
-            let allowed_use = match context {
-                PlaceContext::MutatingUse(MutatingUseContext::Borrow)
-                | PlaceContext::NonMutatingUse(_) => true,
-                PlaceContext::MutatingUse(_) | PlaceContext::NonUse(_) => false,
-            };
-            debug!("visit_local: allowed_use={:?}", allowed_use);
-            if allowed_use {
-                *uses += 1;
-                return;
-            }
-            /* mark as unpromotable below */
-        }
-        *temp = TempState::Unpromotable;
+            _ => TempState::Unpromotable,
+        };
     }
 
     fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
         self.super_rvalue(rvalue, location);
 
-        match *rvalue {
-            Rvalue::Ref(..) => {
-                self.candidates.push(Candidate { location });
-            }
-            _ => {}
+        if let Rvalue::Ref(..) = *rvalue {
+            self.candidates.push(Candidate { location });
         }
     }
 }
@@ -189,230 +188,165 @@ struct Unpromotable;
 
 impl<'tcx> Validator<'_, 'tcx> {
     fn validate_candidate(&mut self, candidate: Candidate) -> Result<(), Unpromotable> {
-        let loc = candidate.location;
-        let statement = &self.body[loc.block].statements[loc.statement_index];
-        match &statement.kind {
-            StatementKind::Assign(box (_, Rvalue::Ref(_, kind, place))) => {
-                // We can only promote interior borrows of promotable temps (non-temps
-                // don't get promoted anyway).
-                self.validate_local(place.local)?;
-
-                // The reference operation itself must be promotable.
-                // (Needs to come after `validate_local` to avoid ICEs.)
-                self.validate_ref(*kind, place)?;
+        let Left(statement) = self.body.stmt_at(candidate.location) else { bug!() };
+        let Some((_, Rvalue::Ref(_, kind, place))) = statement.kind.as_assign() else { bug!() };
 
-                // We do not check all the projections (they do not get promoted anyway),
-                // but we do stay away from promoting anything involving a dereference.
-                if place.projection.contains(&ProjectionElem::Deref) {
-                    return Err(Unpromotable);
-                }
+        // We can only promote interior borrows of promotable temps (non-temps
+        // don't get promoted anyway).
+        self.validate_local(place.local)?;
 
-                Ok(())
-            }
-            _ => bug!(),
+        // The reference operation itself must be promotable.
+        // (Needs to come after `validate_local` to avoid ICEs.)
+        self.validate_ref(*kind, place)?;
+
+        // We do not check all the projections (they do not get promoted anyway),
+        // but we do stay away from promoting anything involving a dereference.
+        if place.projection.contains(&ProjectionElem::Deref) {
+            return Err(Unpromotable);
         }
+
+        Ok(())
     }
 
     // FIXME(eddyb) maybe cache this?
     fn qualif_local<Q: qualifs::Qualif>(&mut self, local: Local) -> bool {
-        if let TempState::Defined { location: loc, .. } = self.temps[local] {
-            let num_stmts = self.body[loc.block].statements.len();
-
-            if loc.statement_index < num_stmts {
-                let statement = &self.body[loc.block].statements[loc.statement_index];
-                match &statement.kind {
-                    StatementKind::Assign(box (_, rhs)) => qualifs::in_rvalue::<Q, _>(
-                        self.ccx,
-                        &mut |l| self.qualif_local::<Q>(l),
-                        rhs,
-                    ),
-                    _ => {
+        let TempState::Defined { location: loc, .. } = self.temps[local] else {
+            return false;
+        };
+
+        let stmt_or_term = self.body.stmt_at(loc);
+        match stmt_or_term {
+            Left(statement) => {
+                let Some((_, rhs)) = statement.kind.as_assign() else {
+                    span_bug!(statement.source_info.span, "{:?} is not an assignment", statement)
+                };
+                qualifs::in_rvalue::<Q, _>(self.ccx, &mut |l| self.qualif_local::<Q>(l), rhs)
+            }
+            Right(terminator) => {
+                assert_matches!(terminator.kind, TerminatorKind::Call { .. });
+                let return_ty = self.body.local_decls[local].ty;
+                Q::in_any_value_of_ty(self.ccx, return_ty)
+            }
+        }
+    }
+
+    fn validate_local(&mut self, local: Local) -> Result<(), Unpromotable> {
+        let TempState::Defined { location: loc, uses, valid } = self.temps[local] else {
+            return Err(Unpromotable);
+        };
+
+        // We cannot promote things that need dropping, since the promoted value would not get
+        // dropped.
+        if self.qualif_local::<qualifs::NeedsDrop>(local) {
+            return Err(Unpromotable);
+        }
+
+        if valid.is_ok() {
+            return Ok(());
+        }
+
+        let ok = {
+            let stmt_or_term = self.body.stmt_at(loc);
+            match stmt_or_term {
+                Left(statement) => {
+                    let Some((_, rhs)) = statement.kind.as_assign() else {
                         span_bug!(
                             statement.source_info.span,
                             "{:?} is not an assignment",
                             statement
-                        );
-                    }
+                        )
+                    };
+                    self.validate_rvalue(rhs)
                 }
-            } else {
-                let terminator = self.body[loc.block].terminator();
-                match &terminator.kind {
-                    TerminatorKind::Call { .. } => {
-                        let return_ty = self.body.local_decls[local].ty;
-                        Q::in_any_value_of_ty(self.ccx, return_ty)
-                    }
+                Right(terminator) => match &terminator.kind {
+                    TerminatorKind::Call { func, args, .. } => self.validate_call(func, args),
+                    TerminatorKind::Yield { .. } => Err(Unpromotable),
                     kind => {
                         span_bug!(terminator.source_info.span, "{:?} not promotable", kind);
                     }
-                }
+                },
             }
-        } else {
-            false
-        }
-    }
+        };
 
-    fn validate_local(&mut self, local: Local) -> Result<(), Unpromotable> {
-        if let TempState::Defined { location: loc, uses, valid } = self.temps[local] {
-            // We cannot promote things that need dropping, since the promoted value
-            // would not get dropped.
-            if self.qualif_local::<qualifs::NeedsDrop>(local) {
-                return Err(Unpromotable);
-            }
-            valid.or_else(|_| {
-                let ok = {
-                    let block = &self.body[loc.block];
-                    let num_stmts = block.statements.len();
-
-                    if loc.statement_index < num_stmts {
-                        let statement = &block.statements[loc.statement_index];
-                        match &statement.kind {
-                            StatementKind::Assign(box (_, rhs)) => self.validate_rvalue(rhs),
-                            _ => {
-                                span_bug!(
-                                    statement.source_info.span,
-                                    "{:?} is not an assignment",
-                                    statement
-                                );
-                            }
-                        }
-                    } else {
-                        let terminator = block.terminator();
-                        match &terminator.kind {
-                            TerminatorKind::Call { func, args, .. } => {
-                                self.validate_call(func, args)
-                            }
-                            TerminatorKind::Yield { .. } => Err(Unpromotable),
-                            kind => {
-                                span_bug!(terminator.source_info.span, "{:?} not promotable", kind);
-                            }
-                        }
-                    }
-                };
-                self.temps[local] = match ok {
-                    Ok(()) => TempState::Defined { location: loc, uses, valid: Ok(()) },
-                    Err(_) => TempState::Unpromotable,
-                };
-                ok
-            })
-        } else {
-            Err(Unpromotable)
-        }
+        self.temps[local] = match ok {
+            Ok(()) => TempState::Defined { location: loc, uses, valid: Ok(()) },
+            Err(_) => TempState::Unpromotable,
+        };
+
+        ok
     }
 
     fn validate_place(&mut self, place: PlaceRef<'tcx>) -> Result<(), Unpromotable> {
-        match place.last_projection() {
-            None => self.validate_local(place.local),
-            Some((place_base, elem)) => {
-                // Validate topmost projection, then recurse.
-                match elem {
-                    ProjectionElem::Deref => {
-                        let mut promotable = false;
-                        // When a static is used by-value, that gets desugared to `*STATIC_ADDR`,
-                        // and we need to be able to promote this. So check if this deref matches
-                        // that specific pattern.
-
-                        // We need to make sure this is a `Deref` of a local with no further projections.
-                        // Discussion can be found at
-                        // https://github.com/rust-lang/rust/pull/74945#discussion_r463063247
-                        if let Some(local) = place_base.as_local() {
-                            if let TempState::Defined { location, .. } = self.temps[local] {
-                                let def_stmt = self.body[location.block]
-                                    .statements
-                                    .get(location.statement_index);
-                                if let Some(Statement {
-                                    kind:
-                                        StatementKind::Assign(box (
-                                            _,
-                                            Rvalue::Use(Operand::Constant(c)),
-                                        )),
-                                    ..
-                                }) = def_stmt
-                                {
-                                    if let Some(did) = c.check_static_ptr(self.tcx) {
-                                        // Evaluating a promoted may not read statics except if it got
-                                        // promoted from a static (this is a CTFE check). So we
-                                        // can only promote static accesses inside statics.
-                                        if let Some(hir::ConstContext::Static(..)) = self.const_kind
-                                        {
-                                            if !self.tcx.is_thread_local_static(did) {
-                                                promotable = true;
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                        }
-                        if !promotable {
-                            return Err(Unpromotable);
-                        }
-                    }
-                    ProjectionElem::OpaqueCast(..) | ProjectionElem::Downcast(..) => {
-                        return Err(Unpromotable);
-                    }
+        let Some((place_base, elem)) = place.last_projection() else {
+            return self.validate_local(place.local);
+        };
 
-                    ProjectionElem::ConstantIndex { .. }
-                    | ProjectionElem::Subtype(_)
-                    | ProjectionElem::Subslice { .. } => {}
-
-                    ProjectionElem::Index(local) => {
-                        let mut promotable = false;
-                        // Only accept if we can predict the index and are indexing an array.
-                        let val = if let TempState::Defined { location: loc, .. } =
-                            self.temps[local]
-                        {
-                            let block = &self.body[loc.block];
-                            if loc.statement_index < block.statements.len() {
-                                let statement = &block.statements[loc.statement_index];
-                                match &statement.kind {
-                                    StatementKind::Assign(box (
-                                        _,
-                                        Rvalue::Use(Operand::Constant(c)),
-                                    )) => c.const_.try_eval_target_usize(self.tcx, self.param_env),
-                                    _ => None,
-                                }
-                            } else {
-                                None
-                            }
-                        } else {
-                            None
-                        };
-                        if let Some(idx) = val {
-                            // Determine the type of the thing we are indexing.
-                            let ty = place_base.ty(self.body, self.tcx).ty;
-                            match ty.kind() {
-                                ty::Array(_, len) => {
-                                    // It's an array; determine its length.
-                                    if let Some(len) =
-                                        len.try_eval_target_usize(self.tcx, self.param_env)
-                                    {
-                                        // If the index is in-bounds, go ahead.
-                                        if idx < len {
-                                            promotable = true;
-                                        }
-                                    }
-                                }
-                                _ => {}
-                            }
-                        }
-                        if !promotable {
-                            return Err(Unpromotable);
-                        }
+        // Validate topmost projection, then recurse.
+        match elem {
+            // Recurse directly.
+            ProjectionElem::ConstantIndex { .. }
+            | ProjectionElem::Subtype(_)
+            | ProjectionElem::Subslice { .. } => {}
 
-                        self.validate_local(local)?;
-                    }
+            // Never recurse.
+            ProjectionElem::OpaqueCast(..) | ProjectionElem::Downcast(..) => {
+                return Err(Unpromotable);
+            }
 
-                    ProjectionElem::Field(..) => {
-                        let base_ty = place_base.ty(self.body, self.tcx).ty;
-                        if base_ty.is_union() {
-                            // No promotion of union field accesses.
-                            return Err(Unpromotable);
-                        }
-                    }
+            ProjectionElem::Deref => {
+                // When a static is used by-value, that gets desugared to `*STATIC_ADDR`,
+                // and we need to be able to promote this. So check if this deref matches
+                // that specific pattern.
+
+                // We need to make sure this is a `Deref` of a local with no further projections.
+                // Discussion can be found at
+                // https://github.com/rust-lang/rust/pull/74945#discussion_r463063247
+                if let Some(local) = place_base.as_local()
+                    && let TempState::Defined { location, .. } = self.temps[local]
+                    && let Left(def_stmt) = self.body.stmt_at(location)
+                    && let Some((_, Rvalue::Use(Operand::Constant(c)))) = def_stmt.kind.as_assign()
+                    && let Some(did) = c.check_static_ptr(self.tcx)
+                    // Evaluating a promoted may not read statics except if it got
+                    // promoted from a static (this is a CTFE check). So we
+                    // can only promote static accesses inside statics.
+                    && let Some(hir::ConstContext::Static(..)) = self.const_kind
+                    && !self.tcx.is_thread_local_static(did)
+                {
+                    // Recurse.
+                } else {
+                    return Err(Unpromotable);
                 }
+            }
+            ProjectionElem::Index(local) => {
+                // Only accept if we can predict the index and are indexing an array.
+                if let TempState::Defined { location: loc, .. } = self.temps[local]
+                    && let Left(statement) =  self.body.stmt_at(loc)
+                    && let Some((_, Rvalue::Use(Operand::Constant(c)))) = statement.kind.as_assign()
+                    && let Some(idx) = c.const_.try_eval_target_usize(self.tcx, self.param_env)
+                    // Determine the type of the thing we are indexing.
+                    && let ty::Array(_, len) = place_base.ty(self.body, self.tcx).ty.kind()
+                    // It's an array; determine its length.
+                    && let Some(len) = len.try_eval_target_usize(self.tcx, self.param_env)
+                    // If the index is in-bounds, go ahead.
+                    && idx < len
+                {
+                    self.validate_local(local)?;
+                    // Recurse.
+                } else {
+                    return Err(Unpromotable);
+                }
+            }
 
-                self.validate_place(place_base)
+            ProjectionElem::Field(..) => {
+                let base_ty = place_base.ty(self.body, self.tcx).ty;
+                if base_ty.is_union() {
+                    // No promotion of union field accesses.
+                    return Err(Unpromotable);
+                }
             }
         }
+
+        self.validate_place(place_base)
     }
 
     fn validate_operand(&mut self, operand: &Operand<'tcx>) -> Result<(), Unpromotable> {