diff options
| -rw-r--r-- | crates/hir_ty/src/infer/expr.rs | 74 | ||||
| -rw-r--r-- | crates/hir_ty/src/tests/simple.rs | 20 |
2 files changed, 87 insertions, 7 deletions
diff --git a/crates/hir_ty/src/infer/expr.rs b/crates/hir_ty/src/infer/expr.rs index be00d8ab376..43cff92f233 100644 --- a/crates/hir_ty/src/infer/expr.rs +++ b/crates/hir_ty/src/infer/expr.rs @@ -28,7 +28,7 @@ use crate::{ lower::{ const_or_path_to_chalk, generic_arg_to_chalk, lower_to_chalk_mutability, ParamLoweringMode, }, - mapping::from_chalk, + mapping::{from_chalk, ToChalk}, method_resolution, primitive::{self, UintTy}, static_lifetime, to_chalk_trait_id, @@ -279,14 +279,16 @@ impl<'a> InferenceContext<'a> { let callee_ty = self.infer_expr(*callee, &Expectation::none()); let mut derefs = Autoderef::new(&mut self.table, callee_ty.clone()); let mut res = None; + let mut derefed_callee = callee_ty.clone(); // manual loop to be able to access `derefs.table` while let Some((callee_deref_ty, _)) = derefs.next() { res = derefs.table.callable_sig(&callee_deref_ty, args.len()); if res.is_some() { + derefed_callee = callee_deref_ty; break; } } - let (param_tys, ret_ty): (Vec<Ty>, Ty) = match res { + let (param_tys, ret_ty) = match res { Some(res) => { let adjustments = auto_deref_adjust_steps(&derefs); self.write_expr_adj(*callee, adjustments); @@ -294,6 +296,7 @@ impl<'a> InferenceContext<'a> { } None => (Vec::new(), self.err_ty()), }; + let indices_to_skip = self.check_legacy_const_generics(derefed_callee, args); self.register_obligations_for_call(&callee_ty); let expected_inputs = self.expected_inputs_for_expected_output( @@ -302,7 +305,7 @@ impl<'a> InferenceContext<'a> { param_tys.clone(), ); - self.check_call_arguments(args, &expected_inputs, ¶m_tys); + self.check_call_arguments(args, &expected_inputs, ¶m_tys, &indices_to_skip); self.normalize_associated_types_in(ret_ty) } Expr::MethodCall { receiver, args, method_name, generic_args } => self @@ -952,7 +955,7 @@ impl<'a> InferenceContext<'a> { let expected_inputs = self.expected_inputs_for_expected_output(expected, ret_ty.clone(), param_tys.clone()); - self.check_call_arguments(args, &expected_inputs, ¶m_tys); + self.check_call_arguments(args, &expected_inputs, ¶m_tys, &[]); self.normalize_associated_types_in(ret_ty) } @@ -983,24 +986,40 @@ impl<'a> InferenceContext<'a> { } } - fn check_call_arguments(&mut self, args: &[ExprId], expected_inputs: &[Ty], param_tys: &[Ty]) { + fn check_call_arguments( + &mut self, + args: &[ExprId], + expected_inputs: &[Ty], + param_tys: &[Ty], + skip_indices: &[u32], + ) { // Quoting https://github.com/rust-lang/rust/blob/6ef275e6c3cb1384ec78128eceeb4963ff788dca/src/librustc_typeck/check/mod.rs#L3325 -- // We do this in a pretty awful way: first we type-check any arguments // that are not closures, then we type-check the closures. This is so // that we have more information about the types of arguments when we // type-check the functions. This isn't really the right way to do this. for &check_closures in &[false, true] { + let mut skip_indices = skip_indices.into_iter().copied().fuse().peekable(); let param_iter = param_tys.iter().cloned().chain(repeat(self.err_ty())); let expected_iter = expected_inputs .iter() .cloned() .chain(param_iter.clone().skip(expected_inputs.len())); - for ((&arg, param_ty), expected_ty) in args.iter().zip(param_iter).zip(expected_iter) { + for (idx, ((&arg, param_ty), expected_ty)) in + args.iter().zip(param_iter).zip(expected_iter).enumerate() + { let is_closure = matches!(&self.body[arg], Expr::Lambda { .. }); if is_closure != check_closures { continue; } + while skip_indices.peek().map_or(false, |i| *i < idx as u32) { + skip_indices.next(); + } + if skip_indices.peek().copied() == Some(idx as u32) { + continue; + } + // the difference between param_ty and expected here is that // expected is the parameter when the expected *return* type is // taken into account. So in `let _: &[i32] = identity(&[1, 2])` @@ -1140,6 +1159,49 @@ impl<'a> InferenceContext<'a> { } } + /// Returns the argument indices to skip. + fn check_legacy_const_generics(&mut self, callee: Ty, args: &[ExprId]) -> Vec<u32> { + let (func, subst) = match callee.kind(Interner) { + TyKind::FnDef(fn_id, subst) => { + let callable = CallableDefId::from_chalk(self.db, *fn_id); + let func = match callable { + CallableDefId::FunctionId(f) => f, + _ => return Vec::new(), + }; + (func, subst) + } + _ => return Vec::new(), + }; + + let data = self.db.function_data(func); + if data.legacy_const_generics_indices.is_empty() { + return Vec::new(); + } + + // only use legacy const generics if the param count matches with them + if data.params.len() + data.legacy_const_generics_indices.len() != args.len() { + return Vec::new(); + } + + // check legacy const parameters + for (subst_idx, arg_idx) in data.legacy_const_generics_indices.iter().copied().enumerate() { + let arg = match subst.at(Interner, subst_idx).constant(Interner) { + Some(c) => c, + None => continue, // not a const parameter? + }; + if arg_idx >= args.len() as u32 { + continue; + } + let _ty = arg.data(Interner).ty.clone(); + let expected = Expectation::none(); // FIXME use actual const ty, when that is lowered correctly + self.infer_expr(args[arg_idx as usize], &expected); + // FIXME: evaluate and unify with the const + } + let mut indices = data.legacy_const_generics_indices.clone(); + indices.sort(); + indices + } + fn builtin_binary_op_return_ty(&mut self, op: BinaryOp, lhs_ty: Ty, rhs_ty: Ty) -> Option<Ty> { let lhs_ty = self.resolve_ty_shallow(&lhs_ty); let rhs_ty = self.resolve_ty_shallow(&rhs_ty); diff --git a/crates/hir_ty/src/tests/simple.rs b/crates/hir_ty/src/tests/simple.rs index 0d050f7461b..de27c294f61 100644 --- a/crates/hir_ty/src/tests/simple.rs +++ b/crates/hir_ty/src/tests/simple.rs @@ -1,6 +1,6 @@ use expect_test::expect; -use super::{check_infer, check_types}; +use super::{check_infer, check_no_mismatches, check_types}; #[test] fn infer_box() { @@ -2624,3 +2624,21 @@ pub mod prelude { "#, ); } + +#[test] +fn legacy_const_generics() { + check_no_mismatches( + r#" +#[rustc_legacy_const_generics(1, 3)] +fn mixed<const N1: &'static str, const N2: bool>( + a: u8, + b: i8, +) {} + +fn f() { + mixed(0, "", -1, true); + mixed::<"", true>(0, -1); +} + "#, + ); +} |
