about summary refs log tree commit diff
diff options
context:
space:
mode:
authorhkalbasi <hamidrezakalbasi@protonmail.com>2023-03-17 19:10:25 +0330
committerhkalbasi <hamidrezakalbasi@protonmail.com>2023-03-17 19:10:25 +0330
commit9ad83deeccd8cb3d375b5558eb7e3d339b1a4e0b (patch)
treee533c936fd7fed5481422d1edcfa82ed61ac9868
parenteb4939e217960ee77d79ec436a39f3cead646de4 (diff)
downloadrust-9ad83deeccd8cb3d375b5558eb7e3d339b1a4e0b.tar.gz
rust-9ad83deeccd8cb3d375b5558eb7e3d339b1a4e0b.zip
Support overloaded index MIR lowering
-rw-r--r--crates/hir-ty/src/consteval/tests.rs46
-rw-r--r--crates/hir-ty/src/infer/mutability.rs15
-rw-r--r--crates/hir-ty/src/mir/lower/as_place.rs63
-rw-r--r--crates/ide-diagnostics/src/handlers/mutability_errors.rs48
4 files changed, 169 insertions, 3 deletions
diff --git a/crates/hir-ty/src/consteval/tests.rs b/crates/hir-ty/src/consteval/tests.rs
index a658bfa0c9a..944912b6c26 100644
--- a/crates/hir-ty/src/consteval/tests.rs
+++ b/crates/hir-ty/src/consteval/tests.rs
@@ -2,7 +2,8 @@ use base_db::fixture::WithFixture;
 use hir_def::db::DefDatabase;
 
 use crate::{
-    consteval::try_const_usize, db::HirDatabase, test_db::TestDB, Const, ConstScalar, Interner,
+    consteval::try_const_usize, db::HirDatabase, mir::pad16, test_db::TestDB, Const, ConstScalar,
+    Interner,
 };
 
 use super::{
@@ -30,7 +31,12 @@ fn check_number(ra_fixture: &str, answer: i128) {
     match &r.data(Interner).value {
         chalk_ir::ConstValue::Concrete(c) => match &c.interned {
             ConstScalar::Bytes(b, _) => {
-                assert_eq!(b, &answer.to_le_bytes()[0..b.len()]);
+                assert_eq!(
+                    b,
+                    &answer.to_le_bytes()[0..b.len()],
+                    "Bytes differ. In decimal form: actual = {}, expected = {answer}",
+                    i128::from_le_bytes(pad16(b, true))
+                );
             }
             x => panic!("Expected number but found {:?}", x),
         },
@@ -216,6 +222,42 @@ fn overloaded_deref_autoref() {
 }
 
 #[test]
+fn overloaded_index() {
+    check_number(
+        r#"
+    //- minicore: index
+    struct Foo;
+
+    impl core::ops::Index<usize> for Foo {
+        type Output = i32;
+        fn index(&self, index: usize) -> &i32 {
+            if index == 7 {
+                &700
+            } else {
+                &1000
+            }
+        }
+    }
+
+    impl core::ops::IndexMut<usize> for Foo {
+        fn index_mut(&mut self, index: usize) -> &mut i32 {
+            if index == 7 {
+                &mut 7
+            } else {
+                &mut 10
+            }
+        }
+    }
+
+    const GOAL: i32 = {
+        (Foo[2]) + (Foo[7]) + (*&Foo[2]) + (*&Foo[7]) + (*&mut Foo[2]) + (*&mut Foo[7])
+    };
+    "#,
+        3417,
+    );
+}
+
+#[test]
 fn function_call() {
     check_number(
         r#"
diff --git a/crates/hir-ty/src/infer/mutability.rs b/crates/hir-ty/src/infer/mutability.rs
index 8e3d71788f2..784725da935 100644
--- a/crates/hir-ty/src/infer/mutability.rs
+++ b/crates/hir-ty/src/infer/mutability.rs
@@ -95,6 +95,21 @@ impl<'a> InferenceContext<'a> {
                 self.infer_mut_not_expr_iter(fields.iter().map(|x| x.expr).chain(*spread))
             }
             &Expr::Index { base, index } => {
+                if let Some((f, _)) = self.result.method_resolutions.get_mut(&tgt_expr) {
+                    if mutability == Mutability::Mut {
+                        if let Some(index_trait) = self
+                            .db
+                            .lang_item(self.table.trait_env.krate, LangItem::IndexMut)
+                            .and_then(|l| l.as_trait())
+                        {
+                            if let Some(index_fn) =
+                                self.db.trait_data(index_trait).method_by_name(&name![index_mut])
+                            {
+                                *f = index_fn;
+                            }
+                        }
+                    }
+                }
                 self.infer_mut_expr(base, mutability);
                 self.infer_mut_expr(index, Mutability::Not);
             }
diff --git a/crates/hir-ty/src/mir/lower/as_place.rs b/crates/hir-ty/src/mir/lower/as_place.rs
index c6f4f66ada0..425904850ba 100644
--- a/crates/hir-ty/src/mir/lower/as_place.rs
+++ b/crates/hir-ty/src/mir/lower/as_place.rs
@@ -1,6 +1,7 @@
 //! MIR lowering for places
 
 use super::*;
+use hir_def::FunctionId;
 use hir_expand::name;
 
 macro_rules! not_supported {
@@ -193,7 +194,24 @@ impl MirLowerCtx<'_> {
                 if index_ty != TyBuilder::usize()
                     || !matches!(base_ty.kind(Interner), TyKind::Array(..) | TyKind::Slice(..))
                 {
-                    not_supported!("overloaded index");
+                    let Some(index_fn) = self.infer.method_resolution(expr_id) else {
+                        return Err(MirLowerError::UnresolvedMethod);
+                    };
+                    let Some((base_place, current)) = self.lower_expr_as_place(current, *base, true)? else {
+                        return Ok(None);
+                    };
+                    let Some((index_operand, current)) = self.lower_expr_to_some_operand(*index, current)? else {
+                        return Ok(None);
+                    };
+                    return self.lower_overloaded_index(
+                        current,
+                        base_place,
+                        self.expr_ty_after_adjustments(*base),
+                        self.expr_ty(expr_id),
+                        index_operand,
+                        expr_id.into(),
+                        index_fn,
+                    );
                 }
                 let Some((mut p_base, current)) =
                     self.lower_expr_as_place(current, *base, true)? else {
@@ -210,6 +228,49 @@ impl MirLowerCtx<'_> {
         }
     }
 
+    fn lower_overloaded_index(
+        &mut self,
+        current: BasicBlockId,
+        place: Place,
+        base_ty: Ty,
+        result_ty: Ty,
+        index_operand: Operand,
+        span: MirSpan,
+        index_fn: (FunctionId, Substitution),
+    ) -> Result<Option<(Place, BasicBlockId)>> {
+        let is_mutable = 'b: {
+            if let Some(index_mut_trait) = self.resolve_lang_item(LangItem::IndexMut)?.as_trait() {
+                if let Some(index_mut_fn) =
+                    self.db.trait_data(index_mut_trait).method_by_name(&name![index_mut])
+                {
+                    break 'b index_mut_fn == index_fn.0;
+                }
+            }
+            false
+        };
+        let (mutability, borrow_kind) = match is_mutable {
+            true => (Mutability::Mut, BorrowKind::Mut { allow_two_phase_borrow: false }),
+            false => (Mutability::Not, BorrowKind::Shared),
+        };
+        let base_ref = TyKind::Ref(mutability, static_lifetime(), base_ty).intern(Interner);
+        let result_ref = TyKind::Ref(mutability, static_lifetime(), result_ty).intern(Interner);
+        let ref_place: Place = self.temp(base_ref)?.into();
+        self.push_assignment(current, ref_place.clone(), Rvalue::Ref(borrow_kind, place), span);
+        let mut result: Place = self.temp(result_ref)?.into();
+        let index_fn_op = Operand::const_zst(
+            TyKind::FnDef(
+                self.db.intern_callable_def(CallableDefId::FunctionId(index_fn.0)).into(),
+                index_fn.1,
+            )
+            .intern(Interner),
+        );
+        let Some(current) = self.lower_call(index_fn_op, vec![Operand::Copy(ref_place), index_operand], result.clone(), current, false)? else {
+            return Ok(None);
+        };
+        result.projection.push(ProjectionElem::Deref);
+        Ok(Some((result, current)))
+    }
+
     fn lower_overloaded_deref(
         &mut self,
         current: BasicBlockId,
diff --git a/crates/ide-diagnostics/src/handlers/mutability_errors.rs b/crates/ide-diagnostics/src/handlers/mutability_errors.rs
index 83c61f73db6..17a70f5701b 100644
--- a/crates/ide-diagnostics/src/handlers/mutability_errors.rs
+++ b/crates/ide-diagnostics/src/handlers/mutability_errors.rs
@@ -565,6 +565,54 @@ fn f(x: [(i32, u8); 10]) {
     }
 
     #[test]
+    fn overloaded_index() {
+        check_diagnostics(
+            r#"
+//- minicore: index
+use core::ops::{Index, IndexMut};
+
+struct Foo;
+impl Index<usize> for Foo {
+    type Output = (i32, u8);
+    fn index(&self, index: usize) -> &(i32, u8) {
+        &(5, 2)
+    }
+}
+impl IndexMut<usize> for Foo {
+    fn index_mut(&mut self, index: usize) -> &mut (i32, u8) {
+        &mut (5, 2)
+    }
+}
+fn f() {
+    let mut x = Foo;
+      //^^^^^ 💡 weak: variable does not need to be mutable
+    let y = &x[2];
+    let x = Foo;
+    let y = &mut x[2];
+               //^^^^ 💡 error: cannot mutate immutable variable `x`
+    let mut x = &mut Foo;
+      //^^^^^ 💡 weak: variable does not need to be mutable
+    let y: &mut (i32, u8) = &mut x[2];
+    let x = Foo;
+    let ref mut y = x[7];
+                  //^^^^ 💡 error: cannot mutate immutable variable `x`
+    let (ref mut y, _) = x[3];
+                       //^^^^ 💡 error: cannot mutate immutable variable `x`
+    match x[10] {
+        //^^^^^ 💡 error: cannot mutate immutable variable `x`
+        (ref y, _) => (),
+        (_, ref mut y) => (),
+    }
+    let mut x = Foo;
+    let mut i = 5;
+      //^^^^^ 💡 weak: variable does not need to be mutable
+    let y = &mut x[i];
+}
+"#,
+        );
+    }
+
+    #[test]
     fn overloaded_deref() {
         check_diagnostics(
             r#"