about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2023-07-28 07:17:38 +0000
committerbors <bors@rust-lang.org>2023-07-28 07:17:38 +0000
commit037844c8a032db23676c8caff59d623c4bad873a (patch)
treecff0ffa504fb148722faf3459259fc340e0d3e2b
parentbc1b0bfa7fd8188e207976485a841dc6c37b4f94 (diff)
parent074488b290732092d077c8271fdcc2c6a91ecede (diff)
downloadrust-037844c8a032db23676c8caff59d623c4bad873a.tar.gz
rust-037844c8a032db23676c8caff59d623c4bad873a.zip
Auto merge of #15271 - lowr:patch/re-castable, r=HKalbasi
Properly infer types with type casts

This PR reenables `Expectation::Castable` (previous attempt at #14104, reverted by #14120) and implements type cast checks, which enable us to infer a bit more.

Castable expectations are relatively weak -- they only influence the inference if we cannot infer the types by other means. Therefore, we need to defer possible type unification with the casted type until we type check all expressions of the body. This PR adds a struct and slots in `InferenceContext` for the deferred cast checks (c.f. [`CastCheck`] in `rustc_hir_typeck`).

I only implemented the bits that affect the inference result. It should be possible to return type adjustments for well-formed casts and report diagnostics for invalid casts, but I'm leaving them for future work for now.

Fixes #11571
Fixes #15246

[`CastCheck`]: https://github.com/rust-lang/rust/blob/da1d099f91ea387a2814a6244dd875a2048b486f/compiler/rustc_hir_typeck/src/cast.rs#L55
-rw-r--r--crates/hir-ty/src/infer.rs34
-rw-r--r--crates/hir-ty/src/infer/cast.rs46
-rw-r--r--crates/hir-ty/src/infer/expr.rs18
-rw-r--r--crates/hir-ty/src/tests/regression.rs20
-rw-r--r--crates/hir-ty/src/tests/simple.rs22
5 files changed, 112 insertions, 28 deletions
diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs
index 0a617dae7d4..b4915dbf0f9 100644
--- a/crates/hir-ty/src/infer.rs
+++ b/crates/hir-ty/src/infer.rs
@@ -13,6 +13,15 @@
 //! to certain types. To record this, we use the union-find implementation from
 //! the `ena` crate, which is extracted from rustc.
 
+mod cast;
+pub(crate) mod closure;
+mod coerce;
+mod expr;
+mod mutability;
+mod pat;
+mod path;
+pub(crate) mod unify;
+
 use std::{convert::identity, ops::Index};
 
 use chalk_ir::{
@@ -60,15 +69,8 @@ pub use coerce::could_coerce;
 #[allow(unreachable_pub)]
 pub use unify::could_unify;
 
-pub(crate) use self::closure::{CaptureKind, CapturedItem, CapturedItemWithoutTy};
-
-pub(crate) mod unify;
-mod path;
-mod expr;
-mod pat;
-mod coerce;
-pub(crate) mod closure;
-mod mutability;
+use cast::CastCheck;
+pub(crate) use closure::{CaptureKind, CapturedItem, CapturedItemWithoutTy};
 
 /// The entry point of type inference.
 pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc<InferenceResult> {
@@ -508,6 +510,8 @@ pub(crate) struct InferenceContext<'a> {
     diverges: Diverges,
     breakables: Vec<BreakableContext>,
 
+    deferred_cast_checks: Vec<CastCheck>,
+
     // fields related to closure capture
     current_captures: Vec<CapturedItemWithoutTy>,
     current_closure: Option<ClosureId>,
@@ -582,7 +586,8 @@ impl<'a> InferenceContext<'a> {
             resolver,
             diverges: Diverges::Maybe,
             breakables: Vec::new(),
-            current_captures: vec![],
+            deferred_cast_checks: Vec::new(),
+            current_captures: Vec::new(),
             current_closure: None,
             deferred_closures: FxHashMap::default(),
             closure_dependencies: FxHashMap::default(),
@@ -594,7 +599,7 @@ impl<'a> InferenceContext<'a> {
     // used this function for another workaround, mention it here. If you really need this function and believe that
     // there is no problem in it being `pub(crate)`, remove this comment.
     pub(crate) fn resolve_all(self) -> InferenceResult {
-        let InferenceContext { mut table, mut result, .. } = self;
+        let InferenceContext { mut table, mut result, deferred_cast_checks, .. } = self;
         // Destructure every single field so whenever new fields are added to `InferenceResult` we
         // don't forget to handle them here.
         let InferenceResult {
@@ -622,6 +627,13 @@ impl<'a> InferenceContext<'a> {
 
         table.fallback_if_possible();
 
+        // Comment from rustc:
+        // Even though coercion casts provide type hints, we check casts after fallback for
+        // backwards compatibility. This makes fallback a stronger type hint than a cast coercion.
+        for cast in deferred_cast_checks {
+            cast.check(&mut table);
+        }
+
         // FIXME resolve obligations as well (use Guidance if necessary)
         table.resolve_obligations_as_possible();
 
diff --git a/crates/hir-ty/src/infer/cast.rs b/crates/hir-ty/src/infer/cast.rs
new file mode 100644
index 00000000000..9e1c74b16fa
--- /dev/null
+++ b/crates/hir-ty/src/infer/cast.rs
@@ -0,0 +1,46 @@
+//! Type cast logic. Basically coercion + additional casts.
+
+use crate::{infer::unify::InferenceTable, Interner, Ty, TyExt, TyKind};
+
+#[derive(Clone, Debug)]
+pub(super) struct CastCheck {
+    expr_ty: Ty,
+    cast_ty: Ty,
+}
+
+impl CastCheck {
+    pub(super) fn new(expr_ty: Ty, cast_ty: Ty) -> Self {
+        Self { expr_ty, cast_ty }
+    }
+
+    pub(super) fn check(self, table: &mut InferenceTable<'_>) {
+        // FIXME: This function currently only implements the bits that influence the type
+        // inference. We should return the adjustments on success and report diagnostics on error.
+        let expr_ty = table.resolve_ty_shallow(&self.expr_ty);
+        let cast_ty = table.resolve_ty_shallow(&self.cast_ty);
+
+        if expr_ty.contains_unknown() || cast_ty.contains_unknown() {
+            return;
+        }
+
+        if table.coerce(&expr_ty, &cast_ty).is_ok() {
+            return;
+        }
+
+        if check_ref_to_ptr_cast(expr_ty, cast_ty, table) {
+            // Note that this type of cast is actually split into a coercion to a
+            // pointer type and a cast:
+            // &[T; N] -> *[T; N] -> *T
+            return;
+        }
+
+        // FIXME: Check other kinds of non-coercion casts and report error if any?
+    }
+}
+
+fn check_ref_to_ptr_cast(expr_ty: Ty, cast_ty: Ty, table: &mut InferenceTable<'_>) -> bool {
+    let Some((expr_inner_ty, _, _)) = expr_ty.as_reference() else { return false; };
+    let Some((cast_inner_ty, _)) = cast_ty.as_raw_ptr() else { return false; };
+    let TyKind::Array(expr_elt_ty, _) = expr_inner_ty.kind(Interner) else { return false; };
+    table.coerce(expr_elt_ty, cast_inner_ty).is_ok()
+}
diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs
index 72e6443beb7..501c0b4bd77 100644
--- a/crates/hir-ty/src/infer/expr.rs
+++ b/crates/hir-ty/src/infer/expr.rs
@@ -46,8 +46,8 @@ use crate::{
 };
 
 use super::{
-    coerce::auto_deref_adjust_steps, find_breakable, BreakableContext, Diverges, Expectation,
-    InferenceContext, InferenceDiagnostic, TypeMismatch,
+    cast::CastCheck, coerce::auto_deref_adjust_steps, find_breakable, BreakableContext, Diverges,
+    Expectation, InferenceContext, InferenceDiagnostic, TypeMismatch,
 };
 
 impl InferenceContext<'_> {
@@ -574,16 +574,8 @@ impl InferenceContext<'_> {
             }
             Expr::Cast { expr, type_ref } => {
                 let cast_ty = self.make_ty(type_ref);
-                // FIXME: propagate the "castable to" expectation
-                let inner_ty = self.infer_expr_no_expect(*expr);
-                match (inner_ty.kind(Interner), cast_ty.kind(Interner)) {
-                    (TyKind::Ref(_, _, inner), TyKind::Raw(_, cast)) => {
-                        // FIXME: record invalid cast diagnostic in case of mismatch
-                        self.unify(inner, cast);
-                    }
-                    // FIXME check the other kinds of cast...
-                    _ => (),
-                }
+                let expr_ty = self.infer_expr(*expr, &Expectation::Castable(cast_ty.clone()));
+                self.deferred_cast_checks.push(CastCheck::new(expr_ty, cast_ty.clone()));
                 cast_ty
             }
             Expr::Ref { expr, rawness, mutability } => {
@@ -1592,7 +1584,7 @@ impl InferenceContext<'_> {
         output: Ty,
         inputs: Vec<Ty>,
     ) -> Vec<Ty> {
-        if let Some(expected_ty) = expected_output.to_option(&mut self.table) {
+        if let Some(expected_ty) = expected_output.only_has_type(&mut self.table) {
             self.table.fudge_inference(|table| {
                 if table.try_unify(&expected_ty, &output).is_ok() {
                     table.resolve_with_fallback(inputs, &|var, kind, _, _| match kind {
diff --git a/crates/hir-ty/src/tests/regression.rs b/crates/hir-ty/src/tests/regression.rs
index 8b95110233f..375014d6c7f 100644
--- a/crates/hir-ty/src/tests/regression.rs
+++ b/crates/hir-ty/src/tests/regression.rs
@@ -1978,3 +1978,23 @@ fn x(a: [i32; 4]) {
         "#,
     );
 }
+
+#[test]
+fn dont_unify_on_casts() {
+    // #15246
+    check_types(
+        r#"
+fn unify(_: [bool; 1]) {}
+fn casted(_: *const bool) {}
+fn default<T>() -> T { loop {} }
+
+fn test() {
+    let foo = default();
+      //^^^ [bool; 1]
+
+    casted(&foo as *const _);
+    unify(foo);
+}
+"#,
+    );
+}
diff --git a/crates/hir-ty/src/tests/simple.rs b/crates/hir-ty/src/tests/simple.rs
index a0ff628435f..2ad7946c8ac 100644
--- a/crates/hir-ty/src/tests/simple.rs
+++ b/crates/hir-ty/src/tests/simple.rs
@@ -3513,7 +3513,6 @@ fn func() {
     );
 }
 
-// FIXME
 #[test]
 fn castable_to() {
     check_infer(
@@ -3538,10 +3537,10 @@ fn func() {
             120..122 '{}': ()
             138..184 '{     ...0]>; }': ()
             148..149 'x': Box<[i32; 0]>
-            152..160 'Box::new': fn new<[{unknown}; 0]>([{unknown}; 0]) -> Box<[{unknown}; 0]>
-            152..164 'Box::new([])': Box<[{unknown}; 0]>
+            152..160 'Box::new': fn new<[i32; 0]>([i32; 0]) -> Box<[i32; 0]>
+            152..164 'Box::new([])': Box<[i32; 0]>
             152..181 'Box::n...2; 0]>': Box<[i32; 0]>
-            161..163 '[]': [{unknown}; 0]
+            161..163 '[]': [i32; 0]
         "#]],
     );
 }
@@ -3578,6 +3577,21 @@ fn f<T>(t: Ark<T>) {
 }
 
 #[test]
+fn ref_to_array_to_ptr_cast() {
+    check_types(
+        r#"
+fn default<T>() -> T { loop {} }
+fn foo() {
+    let arr = [default()];
+      //^^^ [i32; 1]
+    let ref_to_arr = &arr;
+    let casted = ref_to_arr as *const i32;
+}
+"#,
+    );
+}
+
+#[test]
 fn const_dependent_on_local() {
     check_types(
         r#"