about summary refs log tree commit diff
diff options
context:
space:
mode:
authorCelina G. Val <celinval@amazon.com>2023-11-29 12:49:32 -0800
committerCelina G. Val <celinval@amazon.com>2023-11-30 11:45:34 -0800
commit3e0b2fac5dd02378d8a5fe2c8b0c6affed404db0 (patch)
treed7bcab5328dceaee757ba99c7bffc1c9a8f5124e
parent07921b50ba6dcb5b2984a1dba039a38d85bffba2 (diff)
downloadrust-3e0b2fac5dd02378d8a5fe2c8b0c6affed404db0.tar.gz
rust-3e0b2fac5dd02378d8a5fe2c8b0c6affed404db0.zip
Change SwitchTarget representation
The new structure encodes its invariant, which reduces the likelihood
of having an inconsistent representation. It is also more intuitive and
user friendly.

I encapsulated the structure for now in case we decide to change it back.
-rw-r--r--compiler/rustc_smir/src/rustc_smir/convert/mir.rs11
-rw-r--r--compiler/stable_mir/src/mir/body.rs85
-rw-r--r--compiler/stable_mir/src/mir/pretty.rs15
-rw-r--r--compiler/stable_mir/src/mir/visit.rs5
4 files changed, 72 insertions, 44 deletions
diff --git a/compiler/rustc_smir/src/rustc_smir/convert/mir.rs b/compiler/rustc_smir/src/rustc_smir/convert/mir.rs
index e9b835f90db..165b8c50e70 100644
--- a/compiler/rustc_smir/src/rustc_smir/convert/mir.rs
+++ b/compiler/rustc_smir/src/rustc_smir/convert/mir.rs
@@ -579,13 +579,12 @@ impl<'tcx> Stable<'tcx> for mir::TerminatorKind<'tcx> {
             mir::TerminatorKind::SwitchInt { discr, targets } => TerminatorKind::SwitchInt {
                 discr: discr.stable(tables),
                 targets: {
-                    let (value_vec, mut target_vec): (Vec<_>, Vec<_>) =
-                        targets.iter().map(|(value, target)| (value, target.as_usize())).unzip();
-                    // We need to push otherwise as last element to ensure it's same as in MIR.
-                    target_vec.push(targets.otherwise().as_usize());
-                    stable_mir::mir::SwitchTargets { value: value_vec, targets: target_vec }
+                    let branches = targets.iter().map(|(val, target)| (val, target.as_usize()));
+                    stable_mir::mir::SwitchTargets::new(
+                        branches.collect(),
+                        targets.otherwise().as_usize(),
+                    )
                 },
-                otherwise: targets.otherwise().as_usize(),
             },
             mir::TerminatorKind::UnwindResume => TerminatorKind::Resume,
             mir::TerminatorKind::UnwindTerminate(_) => TerminatorKind::Abort,
diff --git a/compiler/stable_mir/src/mir/body.rs b/compiler/stable_mir/src/mir/body.rs
index 02a28687676..8ccc6732a8d 100644
--- a/compiler/stable_mir/src/mir/body.rs
+++ b/compiler/stable_mir/src/mir/body.rs
@@ -3,7 +3,7 @@ use crate::ty::{
     AdtDef, ClosureDef, Const, CoroutineDef, GenericArgs, Movability, Region, RigidTy, Ty, TyKind,
 };
 use crate::{Error, Opaque, Span, Symbol};
-use std::{io, slice};
+use std::io;
 /// The SMIR representation of a single function.
 #[derive(Clone, Debug)]
 pub struct Body {
@@ -23,6 +23,8 @@ pub struct Body {
     pub(super) var_debug_info: Vec<VarDebugInfo>,
 }
 
+pub type BasicBlockIdx = usize;
+
 impl Body {
     /// Constructs a `Body`.
     ///
@@ -114,22 +116,21 @@ pub struct Terminator {
 }
 
 impl Terminator {
-    pub fn successors(&self) -> Successors<'_> {
+    pub fn successors(&self) -> Successors {
         self.kind.successors()
     }
 }
 
-pub type Successors<'a> = impl Iterator<Item = usize> + 'a;
+pub type Successors = Vec<BasicBlockIdx>;
 
 #[derive(Clone, Debug, Eq, PartialEq)]
 pub enum TerminatorKind {
     Goto {
-        target: usize,
+        target: BasicBlockIdx,
     },
     SwitchInt {
         discr: Operand,
         targets: SwitchTargets,
-        otherwise: usize,
     },
     Resume,
     Abort,
@@ -137,43 +138,42 @@ pub enum TerminatorKind {
     Unreachable,
     Drop {
         place: Place,
-        target: usize,
+        target: BasicBlockIdx,
         unwind: UnwindAction,
     },
     Call {
         func: Operand,
         args: Vec<Operand>,
         destination: Place,
-        target: Option<usize>,
+        target: Option<BasicBlockIdx>,
         unwind: UnwindAction,
     },
     Assert {
         cond: Operand,
         expected: bool,
         msg: AssertMessage,
-        target: usize,
+        target: BasicBlockIdx,
         unwind: UnwindAction,
     },
-    CoroutineDrop,
     InlineAsm {
         template: String,
         operands: Vec<InlineAsmOperand>,
         options: String,
         line_spans: String,
-        destination: Option<usize>,
+        destination: Option<BasicBlockIdx>,
         unwind: UnwindAction,
     },
 }
 
 impl TerminatorKind {
-    pub fn successors(&self) -> Successors<'_> {
+    pub fn successors(&self) -> Successors {
         use self::TerminatorKind::*;
         match *self {
-            Call { target: Some(t), unwind: UnwindAction::Cleanup(ref u), .. }
-            | Drop { target: t, unwind: UnwindAction::Cleanup(ref u), .. }
-            | Assert { target: t, unwind: UnwindAction::Cleanup(ref u), .. }
-            | InlineAsm { destination: Some(t), unwind: UnwindAction::Cleanup(ref u), .. } => {
-                Some(t).into_iter().chain(slice::from_ref(u).into_iter().copied())
+            Call { target: Some(t), unwind: UnwindAction::Cleanup(u), .. }
+            | Drop { target: t, unwind: UnwindAction::Cleanup(u), .. }
+            | Assert { target: t, unwind: UnwindAction::Cleanup(u), .. }
+            | InlineAsm { destination: Some(t), unwind: UnwindAction::Cleanup(u), .. } => {
+                vec![t, u]
             }
             Goto { target: t }
             | Call { target: None, unwind: UnwindAction::Cleanup(t), .. }
@@ -182,21 +182,18 @@ impl TerminatorKind {
             | Assert { target: t, unwind: _, .. }
             | InlineAsm { destination: None, unwind: UnwindAction::Cleanup(t), .. }
             | InlineAsm { destination: Some(t), unwind: _, .. } => {
-                Some(t).into_iter().chain((&[]).into_iter().copied())
+                vec![t]
             }
 
-            CoroutineDrop
-            | Return
+            Return
             | Resume
             | Abort
             | Unreachable
             | Call { target: None, unwind: _, .. }
             | InlineAsm { destination: None, unwind: _, .. } => {
-                None.into_iter().chain((&[]).into_iter().copied())
-            }
-            SwitchInt { ref targets, .. } => {
-                None.into_iter().chain(targets.targets.iter().copied())
+                vec![]
             }
+            SwitchInt { ref targets, .. } => targets.all_targets(),
         }
     }
 
@@ -205,7 +202,6 @@ impl TerminatorKind {
             TerminatorKind::Goto { .. }
             | TerminatorKind::Return
             | TerminatorKind::Unreachable
-            | TerminatorKind::CoroutineDrop
             | TerminatorKind::Resume
             | TerminatorKind::Abort
             | TerminatorKind::SwitchInt { .. } => None,
@@ -231,7 +227,7 @@ pub enum UnwindAction {
     Continue,
     Unreachable,
     Terminate,
-    Cleanup(usize),
+    Cleanup(BasicBlockIdx),
 }
 
 #[derive(Clone, Debug, Eq, PartialEq)]
@@ -662,10 +658,45 @@ pub struct Constant {
     pub literal: Const,
 }
 
+/// The possible branch sites of a [TerminatorKind::SwitchInt].
 #[derive(Clone, Debug, Eq, PartialEq)]
 pub struct SwitchTargets {
-    pub value: Vec<u128>,
-    pub targets: Vec<usize>,
+    /// The conditional branches where the first element represents the value that guards this
+    /// branch, and the second element is the branch target.
+    branches: Vec<(u128, BasicBlockIdx)>,
+    /// The `otherwise` branch which will be taken in case none of the conditional branches are
+    /// satisfied.
+    otherwise: BasicBlockIdx,
+}
+
+impl SwitchTargets {
+    /// All possible targets including the `otherwise` target.
+    pub fn all_targets(&self) -> Successors {
+        Some(self.otherwise)
+            .into_iter()
+            .chain(self.branches.iter().map(|(_, target)| *target))
+            .collect()
+    }
+
+    /// The `otherwise` branch target.
+    pub fn otherwise(&self) -> BasicBlockIdx {
+        self.otherwise
+    }
+
+    /// The conditional targets which are only taken if the pattern matches the given value.
+    pub fn branches(&self) -> impl Iterator<Item = (u128, BasicBlockIdx)> + '_ {
+        self.branches.iter().copied()
+    }
+
+    /// The number of targets including `otherwise`.
+    pub fn len(&self) -> usize {
+        self.branches.len() + 1
+    }
+
+    /// Create a new SwitchTargets from the given branches and `otherwise` target.
+    pub fn new(branches: Vec<(u128, BasicBlockIdx)>, otherwise: BasicBlockIdx) -> SwitchTargets {
+        SwitchTargets { branches, otherwise }
+    }
 }
 
 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
diff --git a/compiler/stable_mir/src/mir/pretty.rs b/compiler/stable_mir/src/mir/pretty.rs
index 759c3b148df..3a0eed521dc 100644
--- a/compiler/stable_mir/src/mir/pretty.rs
+++ b/compiler/stable_mir/src/mir/pretty.rs
@@ -76,7 +76,8 @@ pub fn pretty_statement(statement: &StatementKind) -> String {
 
 pub fn pretty_terminator<W: io::Write>(terminator: &TerminatorKind, w: &mut W) -> io::Result<()> {
     write!(w, "{}", pretty_terminator_head(terminator))?;
-    let successor_count = terminator.successors().count();
+    let successors = terminator.successors();
+    let successor_count = successors.len();
     let labels = pretty_successor_labels(terminator);
 
     let show_unwind = !matches!(terminator.unwind(), None | Some(UnwindAction::Cleanup(_)));
@@ -98,12 +99,12 @@ pub fn pretty_terminator<W: io::Write>(terminator: &TerminatorKind, w: &mut W) -
             Ok(())
         }
         (1, false) => {
-            write!(w, " -> {:?}", terminator.successors().next().unwrap())?;
+            write!(w, " -> {:?}", successors[0])?;
             Ok(())
         }
         _ => {
             write!(w, " -> [")?;
-            for (i, target) in terminator.successors().enumerate() {
+            for (i, target) in successors.iter().enumerate() {
                 if i > 0 {
                     write!(w, ", ")?;
                 }
@@ -157,7 +158,6 @@ pub fn pretty_terminator_head(terminator: &TerminatorKind) -> String {
             pretty.push_str(")");
             pretty
         }
-        CoroutineDrop => format!("        coroutine_drop"),
         InlineAsm { .. } => todo!(),
     }
 }
@@ -165,12 +165,11 @@ pub fn pretty_terminator_head(terminator: &TerminatorKind) -> String {
 pub fn pretty_successor_labels(terminator: &TerminatorKind) -> Vec<String> {
     use self::TerminatorKind::*;
     match terminator {
-        Resume | Abort | Return | Unreachable | CoroutineDrop => vec![],
+        Resume | Abort | Return | Unreachable => vec![],
         Goto { .. } => vec!["".to_string()],
         SwitchInt { targets, .. } => targets
-            .value
-            .iter()
-            .map(|target| format!("{}", target))
+            .branches()
+            .map(|(val, _target)| format!("{val}"))
             .chain(iter::once("otherwise".into()))
             .collect(),
         Drop { unwind: UnwindAction::Cleanup(_), .. } => vec!["return".into(), "unwind".into()],
diff --git a/compiler/stable_mir/src/mir/visit.rs b/compiler/stable_mir/src/mir/visit.rs
index 69bf6ca72b0..0c44781c463 100644
--- a/compiler/stable_mir/src/mir/visit.rs
+++ b/compiler/stable_mir/src/mir/visit.rs
@@ -237,8 +237,7 @@ pub trait MirVisitor {
             TerminatorKind::Goto { .. }
             | TerminatorKind::Resume
             | TerminatorKind::Abort
-            | TerminatorKind::Unreachable
-            | TerminatorKind::CoroutineDrop => {}
+            | TerminatorKind::Unreachable => {}
             TerminatorKind::Assert { cond, expected: _, msg, target: _, unwind: _ } => {
                 self.visit_operand(cond, location);
                 self.visit_assert_msg(msg, location);
@@ -268,7 +267,7 @@ pub trait MirVisitor {
                 let local = RETURN_LOCAL;
                 self.visit_local(&local, PlaceContext::NON_MUTATING, location);
             }
-            TerminatorKind::SwitchInt { discr, targets: _, otherwise: _ } => {
+            TerminatorKind::SwitchInt { discr, targets: _ } => {
                 self.visit_operand(discr, location);
             }
         }