about summary refs log tree commit diff
diff options
context:
space:
mode:
authorRyo Yoshida <low.ryoshida@gmail.com>2023-07-12 23:46:23 +0900
committerRyo Yoshida <low.ryoshida@gmail.com>2023-07-13 00:14:15 +0900
commit074488b290732092d077c8271fdcc2c6a91ecede (patch)
treed7fc5524e60697a37ee8d23522e21f34d2bd407c
parent75ac37f317269798da87d42738d79cade3a28ab9 (diff)
downloadrust-074488b290732092d077c8271fdcc2c6a91ecede.tar.gz
rust-074488b290732092d077c8271fdcc2c6a91ecede.zip
Properly infer types with type casts
-rw-r--r--crates/hir-ty/src/infer.rs34
-rw-r--r--crates/hir-ty/src/infer/cast.rs46
-rw-r--r--crates/hir-ty/src/infer/expr.rs18
-rw-r--r--crates/hir-ty/src/tests/regression.rs20
-rw-r--r--crates/hir-ty/src/tests/simple.rs22
5 files changed, 112 insertions, 28 deletions
diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs
index 0a617dae7d4..b4915dbf0f9 100644
--- a/crates/hir-ty/src/infer.rs
+++ b/crates/hir-ty/src/infer.rs
@@ -13,6 +13,15 @@
 //! to certain types. To record this, we use the union-find implementation from
 //! the `ena` crate, which is extracted from rustc.
 
+mod cast;
+pub(crate) mod closure;
+mod coerce;
+mod expr;
+mod mutability;
+mod pat;
+mod path;
+pub(crate) mod unify;
+
 use std::{convert::identity, ops::Index};
 
 use chalk_ir::{
@@ -60,15 +69,8 @@ pub use coerce::could_coerce;
 #[allow(unreachable_pub)]
 pub use unify::could_unify;
 
-pub(crate) use self::closure::{CaptureKind, CapturedItem, CapturedItemWithoutTy};
-
-pub(crate) mod unify;
-mod path;
-mod expr;
-mod pat;
-mod coerce;
-pub(crate) mod closure;
-mod mutability;
+use cast::CastCheck;
+pub(crate) use closure::{CaptureKind, CapturedItem, CapturedItemWithoutTy};
 
 /// The entry point of type inference.
 pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc<InferenceResult> {
@@ -508,6 +510,8 @@ pub(crate) struct InferenceContext<'a> {
     diverges: Diverges,
     breakables: Vec<BreakableContext>,
 
+    deferred_cast_checks: Vec<CastCheck>,
+
     // fields related to closure capture
     current_captures: Vec<CapturedItemWithoutTy>,
     current_closure: Option<ClosureId>,
@@ -582,7 +586,8 @@ impl<'a> InferenceContext<'a> {
             resolver,
             diverges: Diverges::Maybe,
             breakables: Vec::new(),
-            current_captures: vec![],
+            deferred_cast_checks: Vec::new(),
+            current_captures: Vec::new(),
             current_closure: None,
             deferred_closures: FxHashMap::default(),
             closure_dependencies: FxHashMap::default(),
@@ -594,7 +599,7 @@ impl<'a> InferenceContext<'a> {
     // used this function for another workaround, mention it here. If you really need this function and believe that
     // there is no problem in it being `pub(crate)`, remove this comment.
     pub(crate) fn resolve_all(self) -> InferenceResult {
-        let InferenceContext { mut table, mut result, .. } = self;
+        let InferenceContext { mut table, mut result, deferred_cast_checks, .. } = self;
         // Destructure every single field so whenever new fields are added to `InferenceResult` we
         // don't forget to handle them here.
         let InferenceResult {
@@ -622,6 +627,13 @@ impl<'a> InferenceContext<'a> {
 
         table.fallback_if_possible();
 
+        // Comment from rustc:
+        // Even though coercion casts provide type hints, we check casts after fallback for
+        // backwards compatibility. This makes fallback a stronger type hint than a cast coercion.
+        for cast in deferred_cast_checks {
+            cast.check(&mut table);
+        }
+
         // FIXME resolve obligations as well (use Guidance if necessary)
         table.resolve_obligations_as_possible();
 
diff --git a/crates/hir-ty/src/infer/cast.rs b/crates/hir-ty/src/infer/cast.rs
new file mode 100644
index 00000000000..9e1c74b16fa
--- /dev/null
+++ b/crates/hir-ty/src/infer/cast.rs
@@ -0,0 +1,46 @@
+//! Type cast logic. Basically coercion + additional casts.
+
+use crate::{infer::unify::InferenceTable, Interner, Ty, TyExt, TyKind};
+
+#[derive(Clone, Debug)]
+pub(super) struct CastCheck {
+    expr_ty: Ty,
+    cast_ty: Ty,
+}
+
+impl CastCheck {
+    pub(super) fn new(expr_ty: Ty, cast_ty: Ty) -> Self {
+        Self { expr_ty, cast_ty }
+    }
+
+    pub(super) fn check(self, table: &mut InferenceTable<'_>) {
+        // FIXME: This function currently only implements the bits that influence the type
+        // inference. We should return the adjustments on success and report diagnostics on error.
+        let expr_ty = table.resolve_ty_shallow(&self.expr_ty);
+        let cast_ty = table.resolve_ty_shallow(&self.cast_ty);
+
+        if expr_ty.contains_unknown() || cast_ty.contains_unknown() {
+            return;
+        }
+
+        if table.coerce(&expr_ty, &cast_ty).is_ok() {
+            return;
+        }
+
+        if check_ref_to_ptr_cast(expr_ty, cast_ty, table) {
+            // Note that this type of cast is actually split into a coercion to a
+            // pointer type and a cast:
+            // &[T; N] -> *[T; N] -> *T
+            return;
+        }
+
+        // FIXME: Check other kinds of non-coercion casts and report error if any?
+    }
+}
+
+fn check_ref_to_ptr_cast(expr_ty: Ty, cast_ty: Ty, table: &mut InferenceTable<'_>) -> bool {
+    let Some((expr_inner_ty, _, _)) = expr_ty.as_reference() else { return false; };
+    let Some((cast_inner_ty, _)) = cast_ty.as_raw_ptr() else { return false; };
+    let TyKind::Array(expr_elt_ty, _) = expr_inner_ty.kind(Interner) else { return false; };
+    table.coerce(expr_elt_ty, cast_inner_ty).is_ok()
+}
diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs
index 4b14345aa39..9475eed44e4 100644
--- a/crates/hir-ty/src/infer/expr.rs
+++ b/crates/hir-ty/src/infer/expr.rs
@@ -46,8 +46,8 @@ use crate::{
 };
 
 use super::{
-    coerce::auto_deref_adjust_steps, find_breakable, BreakableContext, Diverges, Expectation,
-    InferenceContext, InferenceDiagnostic, TypeMismatch,
+    cast::CastCheck, coerce::auto_deref_adjust_steps, find_breakable, BreakableContext, Diverges,
+    Expectation, InferenceContext, InferenceDiagnostic, TypeMismatch,
 };
 
 impl InferenceContext<'_> {
@@ -574,16 +574,8 @@ impl InferenceContext<'_> {
             }
             Expr::Cast { expr, type_ref } => {
                 let cast_ty = self.make_ty(type_ref);
-                // FIXME: propagate the "castable to" expectation
-                let inner_ty = self.infer_expr_no_expect(*expr);
-                match (inner_ty.kind(Interner), cast_ty.kind(Interner)) {
-                    (TyKind::Ref(_, _, inner), TyKind::Raw(_, cast)) => {
-                        // FIXME: record invalid cast diagnostic in case of mismatch
-                        self.unify(inner, cast);
-                    }
-                    // FIXME check the other kinds of cast...
-                    _ => (),
-                }
+                let expr_ty = self.infer_expr(*expr, &Expectation::Castable(cast_ty.clone()));
+                self.deferred_cast_checks.push(CastCheck::new(expr_ty, cast_ty.clone()));
                 cast_ty
             }
             Expr::Ref { expr, rawness, mutability } => {
@@ -1592,7 +1584,7 @@ impl InferenceContext<'_> {
         output: Ty,
         inputs: Vec<Ty>,
     ) -> Vec<Ty> {
-        if let Some(expected_ty) = expected_output.to_option(&mut self.table) {
+        if let Some(expected_ty) = expected_output.only_has_type(&mut self.table) {
             self.table.fudge_inference(|table| {
                 if table.try_unify(&expected_ty, &output).is_ok() {
                     table.resolve_with_fallback(inputs, &|var, kind, _, _| match kind {
diff --git a/crates/hir-ty/src/tests/regression.rs b/crates/hir-ty/src/tests/regression.rs
index 8b95110233f..375014d6c7f 100644
--- a/crates/hir-ty/src/tests/regression.rs
+++ b/crates/hir-ty/src/tests/regression.rs
@@ -1978,3 +1978,23 @@ fn x(a: [i32; 4]) {
         "#,
     );
 }
+
+#[test]
+fn dont_unify_on_casts() {
+    // #15246
+    check_types(
+        r#"
+fn unify(_: [bool; 1]) {}
+fn casted(_: *const bool) {}
+fn default<T>() -> T { loop {} }
+
+fn test() {
+    let foo = default();
+      //^^^ [bool; 1]
+
+    casted(&foo as *const _);
+    unify(foo);
+}
+"#,
+    );
+}
diff --git a/crates/hir-ty/src/tests/simple.rs b/crates/hir-ty/src/tests/simple.rs
index a0ff628435f..2ad7946c8ac 100644
--- a/crates/hir-ty/src/tests/simple.rs
+++ b/crates/hir-ty/src/tests/simple.rs
@@ -3513,7 +3513,6 @@ fn func() {
     );
 }
 
-// FIXME
 #[test]
 fn castable_to() {
     check_infer(
@@ -3538,10 +3537,10 @@ fn func() {
             120..122 '{}': ()
             138..184 '{     ...0]>; }': ()
             148..149 'x': Box<[i32; 0]>
-            152..160 'Box::new': fn new<[{unknown}; 0]>([{unknown}; 0]) -> Box<[{unknown}; 0]>
-            152..164 'Box::new([])': Box<[{unknown}; 0]>
+            152..160 'Box::new': fn new<[i32; 0]>([i32; 0]) -> Box<[i32; 0]>
+            152..164 'Box::new([])': Box<[i32; 0]>
             152..181 'Box::n...2; 0]>': Box<[i32; 0]>
-            161..163 '[]': [{unknown}; 0]
+            161..163 '[]': [i32; 0]
         "#]],
     );
 }
@@ -3578,6 +3577,21 @@ fn f<T>(t: Ark<T>) {
 }
 
 #[test]
+fn ref_to_array_to_ptr_cast() {
+    check_types(
+        r#"
+fn default<T>() -> T { loop {} }
+fn foo() {
+    let arr = [default()];
+      //^^^ [i32; 1]
+    let ref_to_arr = &arr;
+    let casted = ref_to_arr as *const i32;
+}
+"#,
+    );
+}
+
+#[test]
 fn const_dependent_on_local() {
     check_types(
         r#"