about summary refs log tree commit diff
diff options
context:
space:
mode:
authorLukas Wirth <lukastw97@gmail.com>2025-02-03 14:03:50 +0000
committerGitHub <noreply@github.com>2025-02-03 14:03:50 +0000
commit93b72cefca52b4e1ea9dfb37700f7c8d87cd416c (patch)
tree88bf6642f3269578f9bb8798e8e48d7621fa8010
parent08d4b033a4386a5cfab77c4f7ca43e52022565f1 (diff)
parent806ec889633d725898fbb0fa67dbbd25ed88a01c (diff)
downloadrust-93b72cefca52b4e1ea9dfb37700f7c8d87cd416c.tar.gz
rust-93b72cefca52b4e1ea9dfb37700f7c8d87cd416c.zip
Merge pull request #19066 from alibektas/slice_pattern_type_inference
fix: try to infer array type from slice pattern
-rw-r--r--src/tools/rust-analyzer/crates/hir-ty/src/infer.rs2
-rw-r--r--src/tools/rust-analyzer/crates/hir-ty/src/infer/expr.rs23
-rw-r--r--src/tools/rust-analyzer/crates/hir-ty/src/infer/pat.rs125
-rw-r--r--src/tools/rust-analyzer/crates/hir-ty/src/lib.rs17
-rw-r--r--src/tools/rust-analyzer/crates/hir-ty/src/tests/simple.rs47
5 files changed, 183 insertions, 31 deletions
diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/infer.rs b/src/tools/rust-analyzer/crates/hir-ty/src/infer.rs
index f4a018e2eec..617ebba8811 100644
--- a/src/tools/rust-analyzer/crates/hir-ty/src/infer.rs
+++ b/src/tools/rust-analyzer/crates/hir-ty/src/infer.rs
@@ -946,7 +946,7 @@ impl<'a> InferenceContext<'a> {
             let ty = self.insert_type_vars(ty);
             let ty = self.normalize_associated_types_in(ty);
 
-            self.infer_top_pat(*pat, &ty);
+            self.infer_top_pat(*pat, &ty, None);
             if ty
                 .data(Interner)
                 .flags
diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/infer/expr.rs b/src/tools/rust-analyzer/crates/hir-ty/src/infer/expr.rs
index b951443897c..86e5afdb509 100644
--- a/src/tools/rust-analyzer/crates/hir-ty/src/infer/expr.rs
+++ b/src/tools/rust-analyzer/crates/hir-ty/src/infer/expr.rs
@@ -43,9 +43,9 @@ use crate::{
     primitive::{self, UintTy},
     static_lifetime, to_chalk_trait_id,
     traits::FnTrait,
-    Adjust, Adjustment, AdtId, AutoBorrow, Binders, CallableDefId, CallableSig, FnAbi, FnPointer,
-    FnSig, FnSubst, Interner, Rawness, Scalar, Substitution, TraitEnvironment, TraitRef, Ty,
-    TyBuilder, TyExt, TyKind,
+    Adjust, Adjustment, AdtId, AutoBorrow, Binders, CallableDefId, CallableSig, DeclContext,
+    DeclOrigin, FnAbi, FnPointer, FnSig, FnSubst, Interner, Rawness, Scalar, Substitution,
+    TraitEnvironment, TraitRef, Ty, TyBuilder, TyExt, TyKind,
 };
 
 use super::{
@@ -334,7 +334,11 @@ impl InferenceContext<'_> {
                     ExprIsRead::No
                 };
                 let input_ty = self.infer_expr(expr, &Expectation::none(), child_is_read);
-                self.infer_top_pat(pat, &input_ty);
+                self.infer_top_pat(
+                    pat,
+                    &input_ty,
+                    Some(DeclContext { origin: DeclOrigin::LetExpr }),
+                );
                 self.result.standard_types.bool_.clone()
             }
             Expr::Block { statements, tail, label, id } => {
@@ -461,7 +465,7 @@ impl InferenceContext<'_> {
 
                 // Now go through the argument patterns
                 for (arg_pat, arg_ty) in args.iter().zip(&sig_tys) {
-                    self.infer_top_pat(*arg_pat, arg_ty);
+                    self.infer_top_pat(*arg_pat, arg_ty, None);
                 }
 
                 // FIXME: lift these out into a struct
@@ -582,7 +586,7 @@ impl InferenceContext<'_> {
                     let mut all_arms_diverge = Diverges::Always;
                     for arm in arms.iter() {
                         let input_ty = self.resolve_ty_shallow(&input_ty);
-                        self.infer_top_pat(arm.pat, &input_ty);
+                        self.infer_top_pat(arm.pat, &input_ty, None);
                     }
 
                     let expected = expected.adjust_for_branches(&mut self.table);
@@ -927,7 +931,7 @@ impl InferenceContext<'_> {
                     let resolver_guard =
                         self.resolver.update_to_inner_scope(self.db.upcast(), self.owner, tgt_expr);
                     self.inside_assignment = true;
-                    self.infer_top_pat(target, &rhs_ty);
+                    self.infer_top_pat(target, &rhs_ty, None);
                     self.inside_assignment = false;
                     self.resolver.reset_to_guard(resolver_guard);
                 }
@@ -1632,8 +1636,11 @@ impl InferenceContext<'_> {
                                 decl_ty
                             };
 
-                            this.infer_top_pat(*pat, &ty);
+                            let decl = DeclContext {
+                                origin: DeclOrigin::LocalDecl { has_else: else_branch.is_some() },
+                            };
 
+                            this.infer_top_pat(*pat, &ty, Some(decl));
                             if let Some(expr) = else_branch {
                                 let previous_diverges =
                                     mem::replace(&mut this.diverges, Diverges::Maybe);
diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/infer/pat.rs b/src/tools/rust-analyzer/crates/hir-ty/src/infer/pat.rs
index ca8d5bae5e5..5ff22bea34d 100644
--- a/src/tools/rust-analyzer/crates/hir-ty/src/infer/pat.rs
+++ b/src/tools/rust-analyzer/crates/hir-ty/src/infer/pat.rs
@@ -6,20 +6,21 @@ use hir_def::{
     expr_store::Body,
     hir::{Binding, BindingAnnotation, BindingId, Expr, ExprId, Literal, Pat, PatId},
     path::Path,
+    HasModule,
 };
 use hir_expand::name::Name;
 use stdx::TupleExt;
 
 use crate::{
-    consteval::{try_const_usize, usize_const},
+    consteval::{self, try_const_usize, usize_const},
     infer::{
         coerce::CoerceNever, expr::ExprIsRead, BindingMode, Expectation, InferenceContext,
         TypeMismatch,
     },
     lower::lower_to_chalk_mutability,
     primitive::UintTy,
-    static_lifetime, InferenceDiagnostic, Interner, Mutability, Scalar, Substitution, Ty,
-    TyBuilder, TyExt, TyKind,
+    static_lifetime, DeclContext, DeclOrigin, InferenceDiagnostic, Interner, Mutability, Scalar,
+    Substitution, Ty, TyBuilder, TyExt, TyKind,
 };
 
 impl InferenceContext<'_> {
@@ -34,6 +35,7 @@ impl InferenceContext<'_> {
         id: PatId,
         ellipsis: Option<u32>,
         subs: &[PatId],
+        decl: Option<DeclContext>,
     ) -> Ty {
         let (ty, def) = self.resolve_variant(id.into(), path, true);
         let var_data = def.map(|it| it.variant_data(self.db.upcast()));
@@ -92,13 +94,13 @@ impl InferenceContext<'_> {
                         }
                     };
 
-                    self.infer_pat(subpat, &expected_ty, default_bm);
+                    self.infer_pat(subpat, &expected_ty, default_bm, decl);
                 }
             }
             None => {
                 let err_ty = self.err_ty();
                 for &inner in subs {
-                    self.infer_pat(inner, &err_ty, default_bm);
+                    self.infer_pat(inner, &err_ty, default_bm, decl);
                 }
             }
         }
@@ -114,6 +116,7 @@ impl InferenceContext<'_> {
         default_bm: BindingMode,
         id: PatId,
         subs: impl ExactSizeIterator<Item = (Name, PatId)>,
+        decl: Option<DeclContext>,
     ) -> Ty {
         let (ty, def) = self.resolve_variant(id.into(), path, false);
         if let Some(variant) = def {
@@ -162,13 +165,13 @@ impl InferenceContext<'_> {
                         }
                     };
 
-                    self.infer_pat(inner, &expected_ty, default_bm);
+                    self.infer_pat(inner, &expected_ty, default_bm, decl);
                 }
             }
             None => {
                 let err_ty = self.err_ty();
                 for (_, inner) in subs {
-                    self.infer_pat(inner, &err_ty, default_bm);
+                    self.infer_pat(inner, &err_ty, default_bm, decl);
                 }
             }
         }
@@ -185,6 +188,7 @@ impl InferenceContext<'_> {
         default_bm: BindingMode,
         ellipsis: Option<u32>,
         subs: &[PatId],
+        decl: Option<DeclContext>,
     ) -> Ty {
         let expected = self.resolve_ty_shallow(expected);
         let expectations = match expected.as_tuple() {
@@ -209,12 +213,12 @@ impl InferenceContext<'_> {
 
         // Process pre
         for (ty, pat) in inner_tys.iter_mut().zip(pre) {
-            *ty = self.infer_pat(*pat, ty, default_bm);
+            *ty = self.infer_pat(*pat, ty, default_bm, decl);
         }
 
         // Process post
         for (ty, pat) in inner_tys.iter_mut().skip(pre.len() + n_uncovered_patterns).zip(post) {
-            *ty = self.infer_pat(*pat, ty, default_bm);
+            *ty = self.infer_pat(*pat, ty, default_bm, decl);
         }
 
         TyKind::Tuple(inner_tys.len(), Substitution::from_iter(Interner, inner_tys))
@@ -223,11 +227,17 @@ impl InferenceContext<'_> {
 
     /// The resolver needs to be updated to the surrounding expression when inside assignment
     /// (because there, `Pat::Path` can refer to a variable).
-    pub(super) fn infer_top_pat(&mut self, pat: PatId, expected: &Ty) {
-        self.infer_pat(pat, expected, BindingMode::default());
+    pub(super) fn infer_top_pat(&mut self, pat: PatId, expected: &Ty, decl: Option<DeclContext>) {
+        self.infer_pat(pat, expected, BindingMode::default(), decl);
     }
 
-    fn infer_pat(&mut self, pat: PatId, expected: &Ty, mut default_bm: BindingMode) -> Ty {
+    fn infer_pat(
+        &mut self,
+        pat: PatId,
+        expected: &Ty,
+        mut default_bm: BindingMode,
+        decl: Option<DeclContext>,
+    ) -> Ty {
         let mut expected = self.resolve_ty_shallow(expected);
 
         if matches!(&self.body[pat], Pat::Ref { .. }) || self.inside_assignment {
@@ -261,11 +271,11 @@ impl InferenceContext<'_> {
 
         let ty = match &self.body[pat] {
             Pat::Tuple { args, ellipsis } => {
-                self.infer_tuple_pat_like(&expected, default_bm, *ellipsis, args)
+                self.infer_tuple_pat_like(&expected, default_bm, *ellipsis, args, decl)
             }
             Pat::Or(pats) => {
                 for pat in pats.iter() {
-                    self.infer_pat(*pat, &expected, default_bm);
+                    self.infer_pat(*pat, &expected, default_bm, decl);
                 }
                 expected.clone()
             }
@@ -274,6 +284,7 @@ impl InferenceContext<'_> {
                 lower_to_chalk_mutability(mutability),
                 &expected,
                 default_bm,
+                decl,
             ),
             Pat::TupleStruct { path: p, args: subpats, ellipsis } => self
                 .infer_tuple_struct_pat_like(
@@ -283,10 +294,11 @@ impl InferenceContext<'_> {
                     pat,
                     *ellipsis,
                     subpats,
+                    decl,
                 ),
             Pat::Record { path: p, args: fields, ellipsis: _ } => {
                 let subs = fields.iter().map(|f| (f.name.clone(), f.pat));
-                self.infer_record_pat_like(p.as_deref(), &expected, default_bm, pat, subs)
+                self.infer_record_pat_like(p.as_deref(), &expected, default_bm, pat, subs, decl)
             }
             Pat::Path(path) => {
                 let ty = self.infer_path(path, pat.into()).unwrap_or_else(|| self.err_ty());
@@ -319,10 +331,10 @@ impl InferenceContext<'_> {
                 }
             }
             Pat::Bind { id, subpat } => {
-                return self.infer_bind_pat(pat, *id, default_bm, *subpat, &expected);
+                return self.infer_bind_pat(pat, *id, default_bm, *subpat, &expected, decl);
             }
             Pat::Slice { prefix, slice, suffix } => {
-                self.infer_slice_pat(&expected, prefix, slice, suffix, default_bm)
+                self.infer_slice_pat(&expected, prefix, slice, suffix, default_bm, decl)
             }
             Pat::Wild => expected.clone(),
             Pat::Range { .. } => {
@@ -345,7 +357,7 @@ impl InferenceContext<'_> {
                         _ => (self.result.standard_types.unknown.clone(), None),
                     };
 
-                    let inner_ty = self.infer_pat(*inner, &inner_ty, default_bm);
+                    let inner_ty = self.infer_pat(*inner, &inner_ty, default_bm, decl);
                     let mut b = TyBuilder::adt(self.db, box_adt).push(inner_ty);
 
                     if let Some(alloc_ty) = alloc_ty {
@@ -420,6 +432,7 @@ impl InferenceContext<'_> {
         mutability: Mutability,
         expected: &Ty,
         default_bm: BindingMode,
+        decl: Option<DeclContext>,
     ) -> Ty {
         let (expectation_type, expectation_lt) = match expected.as_reference() {
             Some((inner_ty, lifetime, _exp_mut)) => (inner_ty.clone(), lifetime.clone()),
@@ -433,7 +446,7 @@ impl InferenceContext<'_> {
                 (inner_ty, inner_lt)
             }
         };
-        let subty = self.infer_pat(inner_pat, &expectation_type, default_bm);
+        let subty = self.infer_pat(inner_pat, &expectation_type, default_bm, decl);
         TyKind::Ref(mutability, expectation_lt, subty).intern(Interner)
     }
 
@@ -444,6 +457,7 @@ impl InferenceContext<'_> {
         default_bm: BindingMode,
         subpat: Option<PatId>,
         expected: &Ty,
+        decl: Option<DeclContext>,
     ) -> Ty {
         let Binding { mode, .. } = self.body.bindings[binding];
         let mode = if mode == BindingAnnotation::Unannotated {
@@ -454,7 +468,7 @@ impl InferenceContext<'_> {
         self.result.binding_modes.insert(pat, mode);
 
         let inner_ty = match subpat {
-            Some(subpat) => self.infer_pat(subpat, expected, default_bm),
+            Some(subpat) => self.infer_pat(subpat, expected, default_bm, decl),
             None => expected.clone(),
         };
         let inner_ty = self.insert_type_vars_shallow(inner_ty);
@@ -478,14 +492,28 @@ impl InferenceContext<'_> {
         slice: &Option<PatId>,
         suffix: &[PatId],
         default_bm: BindingMode,
+        decl: Option<DeclContext>,
     ) -> Ty {
+        let expected = self.resolve_ty_shallow(expected);
+
+        // If `expected` is an infer ty, we try to equate it to an array if the given pattern
+        // allows it. See issue #16609
+        if self.pat_is_irrefutable(decl) && expected.is_ty_var() {
+            if let Some(resolved_array_ty) =
+                self.try_resolve_slice_ty_to_array_ty(prefix, suffix, slice)
+            {
+                self.unify(&expected, &resolved_array_ty);
+            }
+        }
+
+        let expected = self.resolve_ty_shallow(&expected);
         let elem_ty = match expected.kind(Interner) {
             TyKind::Array(st, _) | TyKind::Slice(st) => st.clone(),
             _ => self.err_ty(),
         };
 
         for &pat_id in prefix.iter().chain(suffix.iter()) {
-            self.infer_pat(pat_id, &elem_ty, default_bm);
+            self.infer_pat(pat_id, &elem_ty, default_bm, decl);
         }
 
         if let &Some(slice_pat_id) = slice {
@@ -499,7 +527,7 @@ impl InferenceContext<'_> {
                 _ => TyKind::Slice(elem_ty.clone()),
             }
             .intern(Interner);
-            self.infer_pat(slice_pat_id, &rest_pat_ty, default_bm);
+            self.infer_pat(slice_pat_id, &rest_pat_ty, default_bm, decl);
         }
 
         match expected.kind(Interner) {
@@ -553,6 +581,59 @@ impl InferenceContext<'_> {
             | Pat::Expr(_) => false,
         }
     }
+
+    fn try_resolve_slice_ty_to_array_ty(
+        &mut self,
+        before: &[PatId],
+        suffix: &[PatId],
+        slice: &Option<PatId>,
+    ) -> Option<Ty> {
+        if !slice.is_none() {
+            return None;
+        }
+
+        let len = before.len() + suffix.len();
+        let size =
+            consteval::usize_const(self.db, Some(len as u128), self.owner.krate(self.db.upcast()));
+
+        let elem_ty = self.table.new_type_var();
+        let array_ty = TyKind::Array(elem_ty.clone(), size).intern(Interner);
+        Some(array_ty)
+    }
+
+    /// Used to determine whether we can infer the expected type in the slice pattern to be of type array.
+    /// This is only possible if we're in an irrefutable pattern. If we were to allow this in refutable
+    /// patterns we wouldn't e.g. report ambiguity in the following situation:
+    ///
+    /// ```ignore(rust)
+    ///    struct Zeroes;
+    ///    const ARR: [usize; 2] = [0; 2];
+    ///    const ARR2: [usize; 2] = [2; 2];
+    ///
+    ///    impl Into<&'static [usize; 2]> for Zeroes {
+    ///        fn into(self) -> &'static [usize; 2] {
+    ///            &ARR
+    ///        }
+    ///    }
+    ///
+    ///    impl Into<&'static [usize]> for Zeroes {
+    ///        fn into(self) -> &'static [usize] {
+    ///            &ARR2
+    ///        }
+    ///    }
+    ///
+    ///    fn main() {
+    ///        let &[a, b]: &[usize] = Zeroes.into() else {
+    ///           ..
+    ///        };
+    ///    }
+    /// ```
+    ///
+    /// If we're in an irrefutable pattern we prefer the array impl candidate given that
+    /// the slice impl candidate would be rejected anyway (if no ambiguity existed).
+    fn pat_is_irrefutable(&self, decl_ctxt: Option<DeclContext>) -> bool {
+        matches!(decl_ctxt, Some(DeclContext { origin: DeclOrigin::LocalDecl { has_else: false } }))
+    }
 }
 
 pub(super) fn contains_explicit_ref_binding(body: &Body, pat_id: PatId) -> bool {
diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/lib.rs b/src/tools/rust-analyzer/crates/hir-ty/src/lib.rs
index 4b159b7541e..55d81875a2b 100644
--- a/src/tools/rust-analyzer/crates/hir-ty/src/lib.rs
+++ b/src/tools/rust-analyzer/crates/hir-ty/src/lib.rs
@@ -1049,3 +1049,20 @@ pub fn known_const_to_ast(
     }
     Some(make::expr_const_value(konst.display(db, edition).to_string().as_str()))
 }
+
+#[derive(Debug, Copy, Clone)]
+pub(crate) enum DeclOrigin {
+    LetExpr,
+    /// from `let x = ..`
+    LocalDecl {
+        has_else: bool,
+    },
+}
+
+/// Provides context for checking patterns in declarations. More specifically this
+/// allows us to infer array types if the pattern is irrefutable and allows us to infer
+/// the size of the array. See issue rust-lang/rust#76342.
+#[derive(Debug, Copy, Clone)]
+pub(crate) struct DeclContext {
+    pub(crate) origin: DeclOrigin,
+}
diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/tests/simple.rs b/src/tools/rust-analyzer/crates/hir-ty/src/tests/simple.rs
index 15636604570..50a1ecd006d 100644
--- a/src/tools/rust-analyzer/crates/hir-ty/src/tests/simple.rs
+++ b/src/tools/rust-analyzer/crates/hir-ty/src/tests/simple.rs
@@ -3814,3 +3814,50 @@ async fn foo(a: (), b: i32) -> u32 {
         "#,
     );
 }
+
+#[test]
+fn irrefutable_slices() {
+    check_infer(
+        r#"
+//- minicore: from
+struct A;
+
+impl From<A> for [u8; 2] {
+    fn from(a: A) -> Self {
+        [0; 2]
+    }
+}
+impl From<A> for [u8; 3] {
+    fn from(a: A) -> Self {
+        [0; 3]
+    }
+}
+
+
+fn main() {
+    let a = A;
+    let [b, c] = a.into();
+}
+"#,
+        expect![[r#"
+            50..51 'a': A
+            64..86 '{     ...     }': [u8; 2]
+            74..80 '[0; 2]': [u8; 2]
+            75..76 '0': u8
+            78..79 '2': usize
+            128..129 'a': A
+            142..164 '{     ...     }': [u8; 3]
+            152..158 '[0; 3]': [u8; 3]
+            153..154 '0': u8
+            156..157 '3': usize
+            179..224 '{     ...o(); }': ()
+            189..190 'a': A
+            193..194 'A': A
+            204..210 '[b, c]': [u8; 2]
+            205..206 'b': u8
+            208..209 'c': u8
+            213..214 'a': A
+            213..221 'a.into()': [u8; 2]
+        "#]],
+    );
+}