about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors[bot] <26634292+bors[bot]@users.noreply.github.com>2020-10-15 18:02:27 +0000
committerGitHub <noreply@github.com>2020-10-15 18:02:27 +0000
commit0d45802d671f94cb768b93a64882733396cfbe2d (patch)
treeabb91645ee84216304065bf42a959223d040814c
parent1de202010948c94658235f7cfe9b25dda0c7ddf3 (diff)
parent0e9d1e17d6e45b44ec1a8f1430109cfc75e41241 (diff)
downloadrust-0d45802d671f94cb768b93a64882733396cfbe2d.tar.gz
rust-0d45802d671f94cb768b93a64882733396cfbe2d.zip
Merge #6220
6220: implement binary operator overloading type inference r=flodiebold a=ruabmbua

Extend type inference of *binary operator expression*, by adding support for operator overloads.

Before this merge request, the type inference of binary expressions could only resolve operations done on built-in primitive types. This merge requests adds a code path, which is executed in case the built-in inference could not get any results. It resolves the proper operator overload trait in *core::ops* via lang items, and then resolves the associated *Output* type.

```rust
struct V2([f32; 2]);

#[lang = "add"]
pub trait Add<Rhs = Self> {
    /// The resulting type after applying the `+` operator.
    type Output;

    /// Performs the `+` operation.
    #[must_use]
    fn add(self, rhs: Rhs) -> Self::Output;
}

impl Add<V2> for V2 {
    type Output = V2;

    fn add(self, rhs: V2) -> V2 {
        let x = self.0[0] + rhs.0[0];
        let y = self.0[1] + rhs.0[1];
        V2([x, y])
    }
}

fn test() {
    let va = V2([0.0, 1.0]);
    let vb = V2([0.0, 1.0]);

    let r = va + vb; // This infers to V2 now
}
```

There is a problem with operator overloads, which do not explicitly set the *Rhs* type parameter in the respective impl block. 

**Example:**

```rust
impl Add for V2 {
    type Output = V2;

    fn add(self, rhs: V2) -> V2 {
        let x = self.0[0] + rhs.0[0];
        let y = self.0[1] + rhs.0[1];
        V2([x, y])
    }
}
```

In this case, the trait solver does not realize, that the *Rhs* type parameter is actually self in the context of the impl block. This stops type inference in its tracks, and it can not resolve the associated *Output* type.

I guess we can still merge this back, because it increases the amount of resolved types, and does not regress anything (in the tests).

Somewhat blocked by https://github.com/rust-analyzer/rust-analyzer/issues/5685
Resolves  https://github.com/rust-analyzer/rust-analyzer/issues/5544

Co-authored-by: Roland Ruckerbauer <roland.rucky@gmail.com>
-rw-r--r--crates/hir_ty/src/infer.rs24
-rw-r--r--crates/hir_ty/src/infer/expr.rs18
-rw-r--r--crates/hir_ty/src/tests/simple.rs89
3 files changed, 126 insertions, 5 deletions
diff --git a/crates/hir_ty/src/infer.rs b/crates/hir_ty/src/infer.rs
index 9a7785c763d..644ebd42d36 100644
--- a/crates/hir_ty/src/infer.rs
+++ b/crates/hir_ty/src/infer.rs
@@ -22,7 +22,7 @@ use arena::map::ArenaMap;
 use hir_def::{
     body::Body,
     data::{ConstData, FunctionData, StaticData},
-    expr::{BindingAnnotation, ExprId, PatId},
+    expr::{ArithOp, BinaryOp, BindingAnnotation, ExprId, PatId},
     lang_item::LangItemTarget,
     path::{path, Path},
     resolver::{HasResolver, Resolver, TypeNs},
@@ -586,6 +586,28 @@ impl<'a> InferenceContext<'a> {
         self.db.trait_data(trait_).associated_type_by_name(&name![Output])
     }
 
+    fn resolve_binary_op_output(&self, bop: &BinaryOp) -> Option<TypeAliasId> {
+        let lang_item = match bop {
+            BinaryOp::ArithOp(aop) => match aop {
+                ArithOp::Add => "add",
+                ArithOp::Sub => "sub",
+                ArithOp::Mul => "mul",
+                ArithOp::Div => "div",
+                ArithOp::Shl => "shl",
+                ArithOp::Shr => "shr",
+                ArithOp::Rem => "rem",
+                ArithOp::BitXor => "bitxor",
+                ArithOp::BitOr => "bitor",
+                ArithOp::BitAnd => "bitand",
+            },
+            _ => return None,
+        };
+
+        let trait_ = self.resolve_lang_item(lang_item)?.as_trait();
+
+        self.db.trait_data(trait_?).associated_type_by_name(&name![Output])
+    }
+
     fn resolve_boxed_box(&self) -> Option<AdtId> {
         let struct_ = self.resolve_lang_item("owned_box")?.as_struct()?;
         Some(struct_.into())
diff --git a/crates/hir_ty/src/infer/expr.rs b/crates/hir_ty/src/infer/expr.rs
index 0a141b9cb94..8ac4cf89a0c 100644
--- a/crates/hir_ty/src/infer/expr.rs
+++ b/crates/hir_ty/src/infer/expr.rs
@@ -12,6 +12,7 @@ use hir_def::{
 };
 use hir_expand::name::{name, Name};
 use syntax::ast::RangeOp;
+use test_utils::mark;
 
 use crate::{
     autoderef, method_resolution, op,
@@ -531,13 +532,22 @@ impl<'a> InferenceContext<'a> {
                         _ => Expectation::none(),
                     };
                     let lhs_ty = self.infer_expr(*lhs, &lhs_expectation);
-                    // FIXME: find implementation of trait corresponding to operation
-                    // symbol and resolve associated `Output` type
                     let rhs_expectation = op::binary_op_rhs_expectation(*op, lhs_ty.clone());
                     let rhs_ty = self.infer_expr(*rhs, &Expectation::has_type(rhs_expectation));
 
-                    // FIXME: similar as above, return ty is often associated trait type
-                    op::binary_op_return_ty(*op, lhs_ty, rhs_ty)
+                    let ret = op::binary_op_return_ty(*op, lhs_ty.clone(), rhs_ty.clone());
+
+                    if ret == Ty::Unknown {
+                        mark::hit!(infer_expr_inner_binary_operator_overload);
+
+                        self.resolve_associated_type_with_params(
+                            lhs_ty,
+                            self.resolve_binary_op_output(op),
+                            &[rhs_ty],
+                        )
+                    } else {
+                        ret
+                    }
                 }
                 _ => Ty::Unknown,
             },
diff --git a/crates/hir_ty/src/tests/simple.rs b/crates/hir_ty/src/tests/simple.rs
index 5b07948f3da..4f72582b6bd 100644
--- a/crates/hir_ty/src/tests/simple.rs
+++ b/crates/hir_ty/src/tests/simple.rs
@@ -1,4 +1,5 @@
 use expect_test::expect;
+use test_utils::mark;
 
 use super::{check_infer, check_types};
 
@@ -2225,3 +2226,91 @@ fn generic_default_depending_on_other_type_arg_forward() {
         "#]],
     );
 }
+
+#[test]
+fn infer_operator_overload() {
+    mark::check!(infer_expr_inner_binary_operator_overload);
+
+    check_infer(
+        r#"
+        struct V2([f32; 2]);
+
+        #[lang = "add"]
+        pub trait Add<Rhs = Self> {
+            /// The resulting type after applying the `+` operator.
+            type Output;
+
+            /// Performs the `+` operation.
+            #[must_use]
+            fn add(self, rhs: Rhs) -> Self::Output;
+        }
+
+        impl Add<V2> for V2 {
+            type Output = V2;
+
+            fn add(self, rhs: V2) -> V2 {
+                let x = self.0[0] + rhs.0[0];
+                let y = self.0[1] + rhs.0[1];
+                V2([x, y])
+            }
+        }
+
+        fn test() {
+            let va = V2([0.0, 1.0]);
+            let vb = V2([0.0, 1.0]);
+
+            let r = va + vb;
+        }
+
+        "#,
+        expect![[r#"
+            207..211 'self': Self
+            213..216 'rhs': Rhs
+            299..303 'self': V2
+            305..308 'rhs': V2
+            320..422 '{     ...     }': V2
+            334..335 'x': f32
+            338..342 'self': V2
+            338..344 'self.0': [f32; _]
+            338..347 'self.0[0]': {unknown}
+            338..358 'self.0...s.0[0]': f32
+            345..346 '0': i32
+            350..353 'rhs': V2
+            350..355 'rhs.0': [f32; _]
+            350..358 'rhs.0[0]': {unknown}
+            356..357 '0': i32
+            372..373 'y': f32
+            376..380 'self': V2
+            376..382 'self.0': [f32; _]
+            376..385 'self.0[1]': {unknown}
+            376..396 'self.0...s.0[1]': f32
+            383..384 '1': i32
+            388..391 'rhs': V2
+            388..393 'rhs.0': [f32; _]
+            388..396 'rhs.0[1]': {unknown}
+            394..395 '1': i32
+            406..408 'V2': V2([f32; _]) -> V2
+            406..416 'V2([x, y])': V2
+            409..415 '[x, y]': [f32; _]
+            410..411 'x': f32
+            413..414 'y': f32
+            436..519 '{     ... vb; }': ()
+            446..448 'va': V2
+            451..453 'V2': V2([f32; _]) -> V2
+            451..465 'V2([0.0, 1.0])': V2
+            454..464 '[0.0, 1.0]': [f32; _]
+            455..458 '0.0': f32
+            460..463 '1.0': f32
+            475..477 'vb': V2
+            480..482 'V2': V2([f32; _]) -> V2
+            480..494 'V2([0.0, 1.0])': V2
+            483..493 '[0.0, 1.0]': [f32; _]
+            484..487 '0.0': f32
+            489..492 '1.0': f32
+            505..506 'r': V2
+            509..511 'va': V2
+            509..516 'va + vb': V2
+            514..516 'vb': V2
+        "#]],
+    );
+}