about summary refs log tree commit diff
diff options
context:
space:
mode:
authorCelina G. Val <celinval@amazon.com>2024-04-08 15:41:00 -0700
committerCelina G. Val <celinval@amazon.com>2024-04-08 15:47:37 -0700
commit0a4f4a3e29c5b5739db40aa111c115e1977539d1 (patch)
treec8b7f9cd9bbd1fb337b5c1fddd1cecbd06eadefa
parentea40fa210b87a322d2259852c149ffa212a3a0da (diff)
downloadrust-0a4f4a3e29c5b5739db40aa111c115e1977539d1.tar.gz
rust-0a4f4a3e29c5b5739db40aa111c115e1977539d1.zip
Remove unimplemented!() from BinOp::ty() function
To reduce redundancy, we now internalize the BinOp instead of
duplicating the `ty()` function body.
-rw-r--r--compiler/rustc_smir/src/rustc_internal/internal.rs34
-rw-r--r--compiler/rustc_smir/src/rustc_smir/context.rs11
-rw-r--r--compiler/stable_mir/src/compiler_interface.rs5
-rw-r--r--compiler/stable_mir/src/mir/body.rs38
-rw-r--r--tests/ui-fulldeps/stable-mir/check_binop.rs147
5 files changed, 196 insertions, 39 deletions
diff --git a/compiler/rustc_smir/src/rustc_internal/internal.rs b/compiler/rustc_smir/src/rustc_internal/internal.rs
index e8cc41cc886..79808f20b48 100644
--- a/compiler/rustc_smir/src/rustc_internal/internal.rs
+++ b/compiler/rustc_smir/src/rustc_internal/internal.rs
@@ -10,7 +10,7 @@ use rustc_span::Symbol;
 use stable_mir::abi::Layout;
 use stable_mir::mir::alloc::AllocId;
 use stable_mir::mir::mono::{Instance, MonoItem, StaticDef};
-use stable_mir::mir::{Mutability, Place, ProjectionElem, Safety};
+use stable_mir::mir::{BinOp, Mutability, Place, ProjectionElem, Safety};
 use stable_mir::ty::{
     Abi, AdtDef, Binder, BoundRegionKind, BoundTyKind, BoundVariableKind, ClosureKind, Const,
     DynKind, ExistentialPredicate, ExistentialProjection, ExistentialTraitRef, FloatTy, FnSig,
@@ -535,6 +535,38 @@ impl RustcInternal for ProjectionElem {
     }
 }
 
+impl RustcInternal for BinOp {
+    type T<'tcx> = rustc_middle::mir::BinOp;
+
+    fn internal<'tcx>(&self, _tables: &mut Tables<'_>, _tcx: TyCtxt<'tcx>) -> Self::T<'tcx> {
+        match self {
+            BinOp::Add => rustc_middle::mir::BinOp::Add,
+            BinOp::AddUnchecked => rustc_middle::mir::BinOp::AddUnchecked,
+            BinOp::Sub => rustc_middle::mir::BinOp::Sub,
+            BinOp::SubUnchecked => rustc_middle::mir::BinOp::SubUnchecked,
+            BinOp::Mul => rustc_middle::mir::BinOp::Mul,
+            BinOp::MulUnchecked => rustc_middle::mir::BinOp::MulUnchecked,
+            BinOp::Div => rustc_middle::mir::BinOp::Div,
+            BinOp::Rem => rustc_middle::mir::BinOp::Rem,
+            BinOp::BitXor => rustc_middle::mir::BinOp::BitXor,
+            BinOp::BitAnd => rustc_middle::mir::BinOp::BitAnd,
+            BinOp::BitOr => rustc_middle::mir::BinOp::BitOr,
+            BinOp::Shl => rustc_middle::mir::BinOp::Shl,
+            BinOp::ShlUnchecked => rustc_middle::mir::BinOp::ShlUnchecked,
+            BinOp::Shr => rustc_middle::mir::BinOp::Shr,
+            BinOp::ShrUnchecked => rustc_middle::mir::BinOp::ShrUnchecked,
+            BinOp::Eq => rustc_middle::mir::BinOp::Eq,
+            BinOp::Lt => rustc_middle::mir::BinOp::Lt,
+            BinOp::Le => rustc_middle::mir::BinOp::Le,
+            BinOp::Ne => rustc_middle::mir::BinOp::Ne,
+            BinOp::Ge => rustc_middle::mir::BinOp::Ge,
+            BinOp::Gt => rustc_middle::mir::BinOp::Gt,
+            BinOp::Cmp => rustc_middle::mir::BinOp::Cmp,
+            BinOp::Offset => rustc_middle::mir::BinOp::Offset,
+        }
+    }
+}
+
 impl<T> RustcInternal for &T
 where
     T: RustcInternal,
diff --git a/compiler/rustc_smir/src/rustc_smir/context.rs b/compiler/rustc_smir/src/rustc_smir/context.rs
index 7c12168b809..61bbedf9eec 100644
--- a/compiler/rustc_smir/src/rustc_smir/context.rs
+++ b/compiler/rustc_smir/src/rustc_smir/context.rs
@@ -19,7 +19,7 @@ use stable_mir::abi::{FnAbi, Layout, LayoutShape};
 use stable_mir::compiler_interface::Context;
 use stable_mir::mir::alloc::GlobalAlloc;
 use stable_mir::mir::mono::{InstanceDef, StaticDef};
-use stable_mir::mir::{Body, Place};
+use stable_mir::mir::{BinOp, Body, Place};
 use stable_mir::target::{MachineInfo, MachineSize};
 use stable_mir::ty::{
     AdtDef, AdtKind, Allocation, ClosureDef, ClosureKind, Const, FieldDef, FnDef, ForeignDef,
@@ -668,6 +668,15 @@ impl<'tcx> Context for TablesWrapper<'tcx> {
         let tcx = tables.tcx;
         format!("{:?}", place.internal(&mut *tables, tcx))
     }
+
+    fn binop_ty(&self, bin_op: BinOp, rhs: Ty, lhs: Ty) -> Ty {
+        let mut tables = self.0.borrow_mut();
+        let tcx = tables.tcx;
+        let rhs_internal = rhs.internal(&mut *tables, tcx);
+        let lhs_internal = lhs.internal(&mut *tables, tcx);
+        let ty = bin_op.internal(&mut *tables, tcx).ty(tcx, rhs_internal, lhs_internal);
+        ty.stable(&mut *tables)
+    }
 }
 
 pub struct TablesWrapper<'tcx>(pub RefCell<Tables<'tcx>>);
diff --git a/compiler/stable_mir/src/compiler_interface.rs b/compiler/stable_mir/src/compiler_interface.rs
index 8ed34fab54d..94c552199bc 100644
--- a/compiler/stable_mir/src/compiler_interface.rs
+++ b/compiler/stable_mir/src/compiler_interface.rs
@@ -8,7 +8,7 @@ use std::cell::Cell;
 use crate::abi::{FnAbi, Layout, LayoutShape};
 use crate::mir::alloc::{AllocId, GlobalAlloc};
 use crate::mir::mono::{Instance, InstanceDef, StaticDef};
-use crate::mir::{Body, Place};
+use crate::mir::{BinOp, Body, Place};
 use crate::target::MachineInfo;
 use crate::ty::{
     AdtDef, AdtKind, Allocation, ClosureDef, ClosureKind, Const, FieldDef, FnDef, ForeignDef,
@@ -211,6 +211,9 @@ pub trait Context {
 
     /// Get a debug string representation of a place.
     fn place_pretty(&self, place: &Place) -> String;
+
+    /// Get the resulting type of binary operation.
+    fn binop_ty(&self, bin_op: BinOp, rhs: Ty, lhs: Ty) -> Ty;
 }
 
 // A thread local variable that stores a pointer to the tables mapping between TyCtxt
diff --git a/compiler/stable_mir/src/mir/body.rs b/compiler/stable_mir/src/mir/body.rs
index 8f77a19fc0e..79c3906b817 100644
--- a/compiler/stable_mir/src/mir/body.rs
+++ b/compiler/stable_mir/src/mir/body.rs
@@ -1,3 +1,4 @@
+use crate::compiler_interface::with;
 use crate::mir::pretty::function_body;
 use crate::ty::{
     AdtDef, ClosureDef, Const, CoroutineDef, GenericArgs, Movability, Region, RigidTy, Ty, TyKind,
@@ -337,42 +338,7 @@ impl BinOp {
     /// Return the type of this operation for the given input Ty.
     /// This function does not perform type checking, and it currently doesn't handle SIMD.
     pub fn ty(&self, lhs_ty: Ty, rhs_ty: Ty) -> Ty {
-        match self {
-            BinOp::Add
-            | BinOp::AddUnchecked
-            | BinOp::Sub
-            | BinOp::SubUnchecked
-            | BinOp::Mul
-            | BinOp::MulUnchecked
-            | BinOp::Div
-            | BinOp::Rem
-            | BinOp::BitXor
-            | BinOp::BitAnd
-            | BinOp::BitOr => {
-                assert_eq!(lhs_ty, rhs_ty);
-                assert!(lhs_ty.kind().is_primitive());
-                lhs_ty
-            }
-            BinOp::Shl | BinOp::ShlUnchecked | BinOp::Shr | BinOp::ShrUnchecked => {
-                assert!(lhs_ty.kind().is_primitive());
-                assert!(rhs_ty.kind().is_primitive());
-                lhs_ty
-            }
-            BinOp::Offset => {
-                assert!(lhs_ty.kind().is_raw_ptr());
-                assert!(rhs_ty.kind().is_integral());
-                lhs_ty
-            }
-            BinOp::Eq | BinOp::Lt | BinOp::Le | BinOp::Ne | BinOp::Ge | BinOp::Gt => {
-                assert_eq!(lhs_ty, rhs_ty);
-                let lhs_kind = lhs_ty.kind();
-                assert!(lhs_kind.is_primitive() || lhs_kind.is_raw_ptr() || lhs_kind.is_fn_ptr());
-                Ty::bool_ty()
-            }
-            BinOp::Cmp => {
-                unimplemented!("Should cmp::Ordering be a RigidTy?");
-            }
-        }
+        with(|ctx| ctx.binop_ty(*self, lhs_ty, rhs_ty))
     }
 }
 
diff --git a/tests/ui-fulldeps/stable-mir/check_binop.rs b/tests/ui-fulldeps/stable-mir/check_binop.rs
new file mode 100644
index 00000000000..3b52d88de3c
--- /dev/null
+++ b/tests/ui-fulldeps/stable-mir/check_binop.rs
@@ -0,0 +1,147 @@
+//@ run-pass
+//! Test information regarding binary operations.
+
+//@ ignore-stage1
+//@ ignore-cross-compile
+//@ ignore-remote
+//@ ignore-windows-gnu mingw has troubles with linking https://github.com/rust-lang/rust/pull/116837
+
+#![feature(rustc_private)]
+
+extern crate rustc_hir;
+#[macro_use]
+extern crate rustc_smir;
+extern crate rustc_driver;
+extern crate rustc_interface;
+extern crate stable_mir;
+
+use rustc_smir::rustc_internal;
+use stable_mir::mir::mono::Instance;
+use stable_mir::mir::visit::{Location, MirVisitor};
+use stable_mir::mir::{LocalDecl, Rvalue, Statement, StatementKind, Terminator, TerminatorKind};
+use stable_mir::ty::{RigidTy, TyKind};
+use std::collections::HashSet;
+use std::convert::TryFrom;
+use std::io::Write;
+use std::ops::ControlFlow;
+
+/// This function tests that we can correctly get type information from binary operations.
+fn test_binops() -> ControlFlow<()> {
+    // Find items in the local crate.
+    let items = stable_mir::all_local_items();
+    let mut instances =
+        items.into_iter().map(|item| Instance::try_from(item).unwrap()).collect::<Vec<_>>();
+    while let Some(instance) = instances.pop() {
+        // The test below shouldn't have recursion in it.
+        let Some(body) = instance.body() else {
+            continue;
+        };
+        let mut visitor = Visitor { locals: body.locals(), calls: Default::default() };
+        visitor.visit_body(&body);
+        instances.extend(visitor.calls.into_iter());
+    }
+    ControlFlow::Continue(())
+}
+
+struct Visitor<'a> {
+    locals: &'a [LocalDecl],
+    calls: HashSet<Instance>,
+}
+
+impl<'a> MirVisitor for Visitor<'a> {
+    fn visit_statement(&mut self, stmt: &Statement, _loc: Location) {
+        match &stmt.kind {
+            StatementKind::Assign(place, Rvalue::BinaryOp(op, rhs, lhs)) => {
+                let ret_ty = place.ty(self.locals).unwrap();
+                let op_ty = op.ty(rhs.ty(self.locals).unwrap(), lhs.ty(self.locals).unwrap());
+                assert_eq!(ret_ty, op_ty, "Operation type should match the assigned place type");
+            }
+            _ => {}
+        }
+    }
+
+    fn visit_terminator(&mut self, term: &Terminator, _loc: Location) {
+        match &term.kind {
+            TerminatorKind::Call { func, .. } => {
+                let TyKind::RigidTy(RigidTy::FnDef(def, args)) =
+                    func.ty(self.locals).unwrap().kind()
+                    else {
+                        return;
+                    };
+                self.calls.insert(Instance::resolve(def, &args).unwrap());
+            }
+            _ => {}
+        }
+    }
+}
+
+/// This test will generate and analyze a dummy crate using the stable mir.
+/// For that, it will first write the dummy crate into a file.
+/// Then it will create a `StableMir` using custom arguments and then
+/// it will run the compiler.
+fn main() {
+    let path = "binop_input.rs";
+    generate_input(&path).unwrap();
+    let args = vec!["rustc".to_string(), "--crate-type=lib".to_string(), path.to_string()];
+    run!(args, test_binops).unwrap();
+}
+
+fn generate_input(path: &str) -> std::io::Result<()> {
+    let mut file = std::fs::File::create(path)?;
+    write!(
+        file,
+        r#"
+        macro_rules! binop_int {{
+            ($fn:ident, $typ:ty) => {{
+                pub fn $fn(lhs: $typ, rhs: $typ) {{
+                    let eq = lhs == rhs;
+                    let lt = lhs < rhs;
+                    let le = lhs <= rhs;
+
+                    let sum = lhs + rhs;
+                    let mult = lhs * sum;
+                    let shift = mult << 2;
+                    let bit_or = shift | rhs;
+                    let cmp = lhs.cmp(&bit_or);
+
+                    // Try to avoid the results above being pruned
+                    std::hint::black_box(((eq, lt, le), cmp));
+                }}
+            }}
+        }}
+
+        binop_int!(binop_u8, u8);
+        binop_int!(binop_i64, i64);
+
+        pub fn binop_bool(lhs: bool, rhs: bool) {{
+            let eq = lhs == rhs;
+            let or = lhs | eq;
+            let lt = lhs < or;
+            let cmp = lhs.cmp(&rhs);
+
+            // Try to avoid the results above being pruned
+            std::hint::black_box((lt, cmp));
+        }}
+
+        pub fn binop_char(lhs: char, rhs: char) {{
+            let eq = lhs == rhs;
+            let lt = lhs < rhs;
+            let cmp = lhs.cmp(&rhs);
+
+            // Try to avoid the results above being pruned
+            std::hint::black_box(([eq, lt], cmp));
+        }}
+
+        pub fn binop_ptr(lhs: *const char, rhs: *const char) {{
+            let eq = lhs == rhs;
+            let lt = lhs < rhs;
+            let cmp = lhs.cmp(&rhs);
+            let off = unsafe {{ lhs.offset(2) }};
+
+            // Try to avoid the results above being pruned
+            std::hint::black_box(([eq, lt], cmp, off));
+        }}
+        "#
+    )?;
+    Ok(())
+}