about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2023-01-17 12:39:53 +0000
committerbors <bors@rust-lang.org>2023-01-17 12:39:53 +0000
commit492b3deba7527ca4e0b5fbed2551891b905507b8 (patch)
tree244719c55d4d4cdbe11941264632b836d45cd021
parentfa874627f0adcd5a834b116c7a475b56611317c6 (diff)
parentc53064fb58ac170dbd93b9529d78fc324c1dae9f (diff)
downloadrust-492b3deba7527ca4e0b5fbed2551891b905507b8.tar.gz
rust-492b3deba7527ca4e0b5fbed2551891b905507b8.zip
Auto merge of #13971 - lowr:fix/more-precise-builtin-binop-types, r=Veykril
fix: more precise binop inference

While inferring binary operator expressions, Rust puts some extra constraints on the types of the operands for better inference. Relevant part in rustc is [this](https://github.com/rust-lang/rust/blob/159ba8a92c9e2fa4121f106176309521f4af87e9/compiler/rustc_hir_typeck/src/op.rs#L128-L152).

There are two things we currently fail to consider:
- we should enforce them only when both lhs and rhs type are builtin types that are applicable to the binop
- lhs and rhs types may be single reference to applicable builtin types

This PR basically ports [`enforce_builtin_binop_types()`](https://github.com/rust-lang/rust/blob/159ba8a92c9e2fa4121f106176309521f4af87e9/compiler/rustc_hir_typeck/src/op.rs#L159) and [`is_builtin_binop()`](https://github.com/rust-lang/rust/blob/159ba8a92c9e2fa4121f106176309521f4af87e9/compiler/rustc_hir_typeck/src/op.rs#LL927) to our inference context.
-rw-r--r--crates/hir-ty/src/chalk_ext.rs19
-rw-r--r--crates/hir-ty/src/infer.rs4
-rw-r--r--crates/hir-ty/src/infer/expr.rs223
-rw-r--r--crates/hir-ty/src/tests/traits.rs124
4 files changed, 247 insertions, 123 deletions
diff --git a/crates/hir-ty/src/chalk_ext.rs b/crates/hir-ty/src/chalk_ext.rs
index 996b42f5bd8..0244b6c653e 100644
--- a/crates/hir-ty/src/chalk_ext.rs
+++ b/crates/hir-ty/src/chalk_ext.rs
@@ -1,6 +1,6 @@
 //! Various extensions traits for Chalk types.
 
-use chalk_ir::{FloatTy, IntTy, Mutability, Scalar, UintTy};
+use chalk_ir::{FloatTy, IntTy, Mutability, Scalar, TyVariableKind, UintTy};
 use hir_def::{
     builtin_type::{BuiltinFloat, BuiltinInt, BuiltinType, BuiltinUint},
     generics::TypeOrConstParamData,
@@ -18,6 +18,8 @@ use crate::{
 
 pub trait TyExt {
     fn is_unit(&self) -> bool;
+    fn is_integral(&self) -> bool;
+    fn is_floating_point(&self) -> bool;
     fn is_never(&self) -> bool;
     fn is_unknown(&self) -> bool;
     fn is_ty_var(&self) -> bool;
@@ -51,6 +53,21 @@ impl TyExt for Ty {
         matches!(self.kind(Interner), TyKind::Tuple(0, _))
     }
 
+    fn is_integral(&self) -> bool {
+        matches!(
+            self.kind(Interner),
+            TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_))
+                | TyKind::InferenceVar(_, TyVariableKind::Integer)
+        )
+    }
+
+    fn is_floating_point(&self) -> bool {
+        matches!(
+            self.kind(Interner),
+            TyKind::Scalar(Scalar::Float(_)) | TyKind::InferenceVar(_, TyVariableKind::Float)
+        )
+    }
+
     fn is_never(&self) -> bool {
         matches!(self.kind(Interner), TyKind::Never)
     }
diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs
index 6b59f1c20da..0e177db7726 100644
--- a/crates/hir-ty/src/infer.rs
+++ b/crates/hir-ty/src/infer.rs
@@ -1041,10 +1041,6 @@ impl Expectation {
         }
     }
 
-    fn from_option(ty: Option<Ty>) -> Self {
-        ty.map_or(Expectation::None, Expectation::HasType)
-    }
-
     /// The following explanation is copied straight from rustc:
     /// Provides an expectation for an rvalue expression given an *optional*
     /// hint, which is not required for type safety (the resulting type might
diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs
index 3f78806bd77..6f347f6757b 100644
--- a/crates/hir-ty/src/infer/expr.rs
+++ b/crates/hir-ty/src/infer/expr.rs
@@ -10,8 +10,7 @@ use chalk_ir::{
 };
 use hir_def::{
     expr::{
-        ArithOp, Array, BinaryOp, ClosureKind, CmpOp, Expr, ExprId, LabelId, Literal, Statement,
-        UnaryOp,
+        ArithOp, Array, BinaryOp, ClosureKind, Expr, ExprId, LabelId, Literal, Statement, UnaryOp,
     },
     generics::TypeOrConstParamData,
     path::{GenericArg, GenericArgs},
@@ -1017,11 +1016,21 @@ impl<'a> InferenceContext<'a> {
         let (trait_, func) = match trait_func {
             Some(it) => it,
             None => {
-                let rhs_ty = self.builtin_binary_op_rhs_expectation(op, lhs_ty.clone());
-                let rhs_ty = self.infer_expr_coerce(rhs, &Expectation::from_option(rhs_ty));
-                return self
-                    .builtin_binary_op_return_ty(op, lhs_ty, rhs_ty)
-                    .unwrap_or_else(|| self.err_ty());
+                // HACK: `rhs_ty` is a general inference variable with no clue at all at this
+                // point. Passing `lhs_ty` as both operands just to check if `lhs_ty` is a builtin
+                // type applicable to `op`.
+                let ret_ty = if self.is_builtin_binop(&lhs_ty, &lhs_ty, op) {
+                    // Assume both operands are builtin so we can continue inference. No guarantee
+                    // on the correctness, rustc would complain as necessary lang items don't seem
+                    // to exist anyway.
+                    self.enforce_builtin_binop_types(&lhs_ty, &rhs_ty, op)
+                } else {
+                    self.err_ty()
+                };
+
+                self.infer_expr_coerce(rhs, &Expectation::has_type(rhs_ty));
+
+                return ret_ty;
             }
         };
 
@@ -1071,11 +1080,9 @@ impl<'a> InferenceContext<'a> {
 
         let ret_ty = self.normalize_associated_types_in(ret_ty);
 
-        // use knowledge of built-in binary ops, which can sometimes help inference
-        if let Some(builtin_rhs) = self.builtin_binary_op_rhs_expectation(op, lhs_ty.clone()) {
-            self.unify(&builtin_rhs, &rhs_ty);
-        }
-        if let Some(builtin_ret) = self.builtin_binary_op_return_ty(op, lhs_ty, rhs_ty) {
+        if self.is_builtin_binop(&lhs_ty, &rhs_ty, op) {
+            // use knowledge of built-in binary ops, which can sometimes help inference
+            let builtin_ret = self.enforce_builtin_binop_types(&lhs_ty, &rhs_ty, op);
             self.unify(&builtin_ret, &ret_ty);
         }
 
@@ -1477,92 +1484,124 @@ impl<'a> InferenceContext<'a> {
         indices
     }
 
-    fn builtin_binary_op_return_ty(&mut self, op: BinaryOp, lhs_ty: Ty, rhs_ty: Ty) -> Option<Ty> {
-        let lhs_ty = self.resolve_ty_shallow(&lhs_ty);
-        let rhs_ty = self.resolve_ty_shallow(&rhs_ty);
-        match op {
-            BinaryOp::LogicOp(_) | BinaryOp::CmpOp(_) => {
-                Some(TyKind::Scalar(Scalar::Bool).intern(Interner))
+    /// Dereferences a single level of immutable referencing.
+    fn deref_ty_if_possible(&mut self, ty: &Ty) -> Ty {
+        let ty = self.resolve_ty_shallow(ty);
+        match ty.kind(Interner) {
+            TyKind::Ref(Mutability::Not, _, inner) => self.resolve_ty_shallow(inner),
+            _ => ty,
+        }
+    }
+
+    /// Enforces expectations on lhs type and rhs type depending on the operator and returns the
+    /// output type of the binary op.
+    fn enforce_builtin_binop_types(&mut self, lhs: &Ty, rhs: &Ty, op: BinaryOp) -> Ty {
+        // Special-case a single layer of referencing, so that things like `5.0 + &6.0f32` work (See rust-lang/rust#57447).
+        let lhs = self.deref_ty_if_possible(lhs);
+        let rhs = self.deref_ty_if_possible(rhs);
+
+        let (op, is_assign) = match op {
+            BinaryOp::Assignment { op: Some(inner) } => (BinaryOp::ArithOp(inner), true),
+            _ => (op, false),
+        };
+
+        let output_ty = match op {
+            BinaryOp::LogicOp(_) => {
+                let bool_ = self.result.standard_types.bool_.clone();
+                self.unify(&lhs, &bool_);
+                self.unify(&rhs, &bool_);
+                bool_
             }
-            BinaryOp::Assignment { .. } => Some(TyBuilder::unit()),
+
             BinaryOp::ArithOp(ArithOp::Shl | ArithOp::Shr) => {
-                // all integer combinations are valid here
-                if matches!(
-                    lhs_ty.kind(Interner),
-                    TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_))
-                        | TyKind::InferenceVar(_, TyVariableKind::Integer)
-                ) && matches!(
-                    rhs_ty.kind(Interner),
-                    TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_))
-                        | TyKind::InferenceVar(_, TyVariableKind::Integer)
-                ) {
-                    Some(lhs_ty)
-                } else {
-                    None
-                }
+                // result type is same as LHS always
+                lhs
             }
-            BinaryOp::ArithOp(_) => match (lhs_ty.kind(Interner), rhs_ty.kind(Interner)) {
-                // (int, int) | (uint, uint) | (float, float)
-                (TyKind::Scalar(Scalar::Int(_)), TyKind::Scalar(Scalar::Int(_)))
-                | (TyKind::Scalar(Scalar::Uint(_)), TyKind::Scalar(Scalar::Uint(_)))
-                | (TyKind::Scalar(Scalar::Float(_)), TyKind::Scalar(Scalar::Float(_))) => {
-                    Some(rhs_ty)
-                }
-                // ({int}, int) | ({int}, uint)
-                (
-                    TyKind::InferenceVar(_, TyVariableKind::Integer),
-                    TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_)),
-                ) => Some(rhs_ty),
-                // (int, {int}) | (uint, {int})
-                (
-                    TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_)),
-                    TyKind::InferenceVar(_, TyVariableKind::Integer),
-                ) => Some(lhs_ty),
-                // ({float} | float)
-                (
-                    TyKind::InferenceVar(_, TyVariableKind::Float),
-                    TyKind::Scalar(Scalar::Float(_)),
-                ) => Some(rhs_ty),
-                // (float, {float})
-                (
-                    TyKind::Scalar(Scalar::Float(_)),
-                    TyKind::InferenceVar(_, TyVariableKind::Float),
-                ) => Some(lhs_ty),
-                // ({int}, {int}) | ({float}, {float})
-                (
-                    TyKind::InferenceVar(_, TyVariableKind::Integer),
-                    TyKind::InferenceVar(_, TyVariableKind::Integer),
-                )
-                | (
-                    TyKind::InferenceVar(_, TyVariableKind::Float),
-                    TyKind::InferenceVar(_, TyVariableKind::Float),
-                ) => Some(rhs_ty),
-                _ => None,
-            },
+
+            BinaryOp::ArithOp(_) => {
+                // LHS, RHS, and result will have the same type
+                self.unify(&lhs, &rhs);
+                lhs
+            }
+
+            BinaryOp::CmpOp(_) => {
+                // LHS and RHS will have the same type
+                self.unify(&lhs, &rhs);
+                self.result.standard_types.bool_.clone()
+            }
+
+            BinaryOp::Assignment { op: None } => {
+                stdx::never!("Simple assignment operator is not binary op.");
+                lhs
+            }
+
+            BinaryOp::Assignment { .. } => unreachable!("handled above"),
+        };
+
+        if is_assign {
+            self.result.standard_types.unit.clone()
+        } else {
+            output_ty
         }
     }
 
-    fn builtin_binary_op_rhs_expectation(&mut self, op: BinaryOp, lhs_ty: Ty) -> Option<Ty> {
-        Some(match op {
-            BinaryOp::LogicOp(..) => TyKind::Scalar(Scalar::Bool).intern(Interner),
-            BinaryOp::Assignment { op: None } => lhs_ty,
-            BinaryOp::CmpOp(CmpOp::Eq { .. }) => match self
-                .resolve_ty_shallow(&lhs_ty)
-                .kind(Interner)
-            {
-                TyKind::Scalar(_) | TyKind::Str => lhs_ty,
-                TyKind::InferenceVar(_, TyVariableKind::Integer | TyVariableKind::Float) => lhs_ty,
-                _ => return None,
-            },
-            BinaryOp::ArithOp(ArithOp::Shl | ArithOp::Shr) => return None,
-            BinaryOp::CmpOp(CmpOp::Ord { .. })
-            | BinaryOp::Assignment { op: Some(_) }
-            | BinaryOp::ArithOp(_) => match self.resolve_ty_shallow(&lhs_ty).kind(Interner) {
-                TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_) | Scalar::Float(_)) => lhs_ty,
-                TyKind::InferenceVar(_, TyVariableKind::Integer | TyVariableKind::Float) => lhs_ty,
-                _ => return None,
-            },
-        })
+    fn is_builtin_binop(&mut self, lhs: &Ty, rhs: &Ty, op: BinaryOp) -> bool {
+        // Special-case a single layer of referencing, so that things like `5.0 + &6.0f32` work (See rust-lang/rust#57447).
+        let lhs = self.deref_ty_if_possible(lhs);
+        let rhs = self.deref_ty_if_possible(rhs);
+
+        let op = match op {
+            BinaryOp::Assignment { op: Some(inner) } => BinaryOp::ArithOp(inner),
+            _ => op,
+        };
+
+        match op {
+            BinaryOp::LogicOp(_) => true,
+
+            BinaryOp::ArithOp(ArithOp::Shl | ArithOp::Shr) => {
+                lhs.is_integral() && rhs.is_integral()
+            }
+
+            BinaryOp::ArithOp(
+                ArithOp::Add | ArithOp::Sub | ArithOp::Mul | ArithOp::Div | ArithOp::Rem,
+            ) => {
+                lhs.is_integral() && rhs.is_integral()
+                    || lhs.is_floating_point() && rhs.is_floating_point()
+            }
+
+            BinaryOp::ArithOp(ArithOp::BitAnd | ArithOp::BitOr | ArithOp::BitXor) => {
+                lhs.is_integral() && rhs.is_integral()
+                    || lhs.is_floating_point() && rhs.is_floating_point()
+                    || matches!(
+                        (lhs.kind(Interner), rhs.kind(Interner)),
+                        (TyKind::Scalar(Scalar::Bool), TyKind::Scalar(Scalar::Bool))
+                    )
+            }
+
+            BinaryOp::CmpOp(_) => {
+                let is_scalar = |kind| {
+                    matches!(
+                        kind,
+                        &TyKind::Scalar(_)
+                            | TyKind::FnDef(..)
+                            | TyKind::Function(_)
+                            | TyKind::Raw(..)
+                            | TyKind::InferenceVar(
+                                _,
+                                TyVariableKind::Integer | TyVariableKind::Float
+                            )
+                    )
+                };
+                is_scalar(lhs.kind(Interner)) && is_scalar(rhs.kind(Interner))
+            }
+
+            BinaryOp::Assignment { op: None } => {
+                stdx::never!("Simple assignment operator is not binary op.");
+                false
+            }
+
+            BinaryOp::Assignment { .. } => unreachable!("handled above"),
+        }
     }
 
     fn with_breakable_ctx<T>(
diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs
index d01fe063285..4c560702a1b 100644
--- a/crates/hir-ty/src/tests/traits.rs
+++ b/crates/hir-ty/src/tests/traits.rs
@@ -3507,14 +3507,9 @@ trait Request {
 fn bin_op_adt_with_rhs_primitive() {
     check_infer_with_mismatches(
         r#"
-#[lang = "add"]
-pub trait Add<Rhs = Self> {
-    type Output;
-    fn add(self, rhs: Rhs) -> Self::Output;
-}
-
+//- minicore: add
 struct Wrapper(u32);
-impl Add<u32> for Wrapper {
+impl core::ops::Add<u32> for Wrapper {
     type Output = Self;
     fn add(self, rhs: u32) -> Wrapper {
         Wrapper(rhs)
@@ -3527,30 +3522,107 @@ fn main(){
 
 }"#,
         expect![[r#"
-            72..76 'self': Self
-            78..81 'rhs': Rhs
-            192..196 'self': Wrapper
-            198..201 'rhs': u32
-            219..247 '{     ...     }': Wrapper
-            229..236 'Wrapper': Wrapper(u32) -> Wrapper
-            229..241 'Wrapper(rhs)': Wrapper
-            237..240 'rhs': u32
-            259..345 '{     ...um;  }': ()
-            269..276 'wrapped': Wrapper
-            279..286 'Wrapper': Wrapper(u32) -> Wrapper
-            279..290 'Wrapper(10)': Wrapper
-            287..289 '10': u32
-            300..303 'num': u32
-            311..312 '2': u32
-            322..325 'res': Wrapper
-            328..335 'wrapped': Wrapper
-            328..341 'wrapped + num': Wrapper
-            338..341 'num': u32
+            95..99 'self': Wrapper
+            101..104 'rhs': u32
+            122..150 '{     ...     }': Wrapper
+            132..139 'Wrapper': Wrapper(u32) -> Wrapper
+            132..144 'Wrapper(rhs)': Wrapper
+            140..143 'rhs': u32
+            162..248 '{     ...um;  }': ()
+            172..179 'wrapped': Wrapper
+            182..189 'Wrapper': Wrapper(u32) -> Wrapper
+            182..193 'Wrapper(10)': Wrapper
+            190..192 '10': u32
+            203..206 'num': u32
+            214..215 '2': u32
+            225..228 'res': Wrapper
+            231..238 'wrapped': Wrapper
+            231..244 'wrapped + num': Wrapper
+            241..244 'num': u32
         "#]],
     )
 }
 
 #[test]
+fn builtin_binop_expectation_works_on_single_reference() {
+    check_types(
+        r#"
+//- minicore: add
+use core::ops::Add;
+impl Add<i32> for i32 { type Output = i32 }
+impl Add<&i32> for i32 { type Output = i32 }
+impl Add<u32> for u32 { type Output = u32 }
+impl Add<&u32> for u32 { type Output = u32 }
+
+struct V<T>;
+impl<T> V<T> {
+    fn default() -> Self { loop {} }
+    fn get(&self, _: &T) -> &T { loop {} }
+}
+
+fn take_u32(_: u32) {}
+fn minimized() {
+    let v = V::default();
+    let p = v.get(&0);
+      //^ &u32
+    take_u32(42 + p);
+}
+"#,
+    );
+}
+
+#[test]
+fn no_builtin_binop_expectation_for_general_ty_var() {
+    // FIXME: Ideally type mismatch should be reported on `take_u32(42 - p)`.
+    check_types(
+        r#"
+//- minicore: add
+use core::ops::Add;
+impl Add<i32> for i32 { type Output = i32; }
+impl Add<&i32> for i32 { type Output = i32; }
+// This is needed to prevent chalk from giving unique solution to `i32: Add<&?0>` after applying
+// fallback to integer type variable for `42`.
+impl Add<&()> for i32 { type Output = (); }
+
+struct V<T>;
+impl<T> V<T> {
+    fn default() -> Self { loop {} }
+    fn get(&self) -> &T { loop {} }
+}
+
+fn take_u32(_: u32) {}
+fn minimized() {
+    let v = V::default();
+    let p = v.get();
+      //^ &{unknown}
+    take_u32(42 + p);
+}
+"#,
+    );
+}
+
+#[test]
+fn no_builtin_binop_expectation_for_non_builtin_types() {
+    check_no_mismatches(
+        r#"
+//- minicore: default, eq
+struct S;
+impl Default for S { fn default() -> Self { S } }
+impl Default for i32 { fn default() -> Self { 0 } }
+impl PartialEq<S> for i32 { fn eq(&self, _: &S) -> bool { true } }
+impl PartialEq<i32> for i32 { fn eq(&self, _: &S) -> bool { true } }
+
+fn take_s(_: S) {}
+fn test() {
+    let s = Default::default();
+    let _eq = 0 == s;
+    take_s(s);
+}
+"#,
+    )
+}
+
+#[test]
 fn array_length() {
     check_infer(
         r#"