about summary refs log tree commit diff
diff options
context:
space:
mode:
authorLukas Wirth <lukastw97@gmail.com>2023-03-04 12:48:57 +0100
committerLukas Wirth <lukastw97@gmail.com>2023-03-04 12:48:57 +0100
commit24ba1bed040f7d8f483b250dbd4e49383823f644 (patch)
tree5a07a0dd5de431a79ee6df0eb5e13204de48b9f6
parent73e2505cfa8be6838f5151b272f1d24869b2a3d6 (diff)
downloadrust-24ba1bed040f7d8f483b250dbd4e49383823f644.tar.gz
rust-24ba1bed040f7d8f483b250dbd4e49383823f644.zip
Set expectation for no-semi expression statements to unit
-rw-r--r--crates/hir-def/src/resolver.rs4
-rw-r--r--crates/hir-ty/src/infer/expr.rs73
-rw-r--r--crates/hir-ty/src/infer/path.rs25
-rw-r--r--crates/hir-ty/src/tests/diagnostics.rs21
-rw-r--r--crates/hir-ty/src/tests/regression.rs4
5 files changed, 81 insertions, 46 deletions
diff --git a/crates/hir-def/src/resolver.rs b/crates/hir-def/src/resolver.rs
index 0b9c136c7eb..664db292a7f 100644
--- a/crates/hir-def/src/resolver.rs
+++ b/crates/hir-def/src/resolver.rs
@@ -294,8 +294,8 @@ impl Resolver {
             }
         }
 
-        if let res @ Some(_) = self.module_scope.resolve_path_in_value_ns(db, path) {
-            return res;
+        if let Some(res) = self.module_scope.resolve_path_in_value_ns(db, path) {
+            return Some(res);
         }
 
         // If a path of the shape `u16::from_le_bytes` failed to resolve at all, then we fall back
diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs
index 02024e1ea78..81e97a9b0bf 100644
--- a/crates/hir-ty/src/infer/expr.rs
+++ b/crates/hir-ty/src/infer/expr.rs
@@ -130,7 +130,7 @@ impl<'a> InferenceContext<'a> {
                 );
                 let ty = match label {
                     Some(_) => {
-                        let break_ty = self.table.new_type_var();
+                        let break_ty = expected.coercion_target_type(&mut self.table);
                         let (breaks, ty) = self.with_breakable_ctx(
                             BreakableKind::Block,
                             Some(break_ty.clone()),
@@ -403,37 +403,47 @@ impl<'a> InferenceContext<'a> {
             Expr::Match { expr, arms } => {
                 let input_ty = self.infer_expr(*expr, &Expectation::none());
 
-                let expected = expected.adjust_for_branches(&mut self.table);
-
-                let result_ty = if arms.is_empty() {
+                if arms.is_empty() {
+                    self.diverges = Diverges::Always;
                     self.result.standard_types.never.clone()
                 } else {
-                    expected.coercion_target_type(&mut self.table)
-                };
-                let mut coerce = CoerceMany::new(result_ty);
-
-                let matchee_diverges = self.diverges;
-                let mut all_arms_diverge = Diverges::Always;
-
-                for arm in arms.iter() {
-                    self.diverges = Diverges::Maybe;
-                    let input_ty = self.resolve_ty_shallow(&input_ty);
-                    self.infer_top_pat(arm.pat, &input_ty);
-                    if let Some(guard_expr) = arm.guard {
-                        self.infer_expr(
-                            guard_expr,
-                            &Expectation::HasType(self.result.standard_types.bool_.clone()),
-                        );
+                    let matchee_diverges = mem::replace(&mut self.diverges, Diverges::Maybe);
+                    let mut all_arms_diverge = Diverges::Always;
+                    for arm in arms.iter() {
+                        let input_ty = self.resolve_ty_shallow(&input_ty);
+                        self.infer_top_pat(arm.pat, &input_ty);
                     }
 
-                    let arm_ty = self.infer_expr_inner(arm.expr, &expected);
-                    all_arms_diverge &= self.diverges;
-                    coerce.coerce(self, Some(arm.expr), &arm_ty);
-                }
+                    let expected = expected.adjust_for_branches(&mut self.table);
+                    let result_ty = match &expected {
+                        // We don't coerce to `()` so that if the match expression is a
+                        // statement it's branches can have any consistent type.
+                        Expectation::HasType(ty) if *ty != self.result.standard_types.unit => {
+                            ty.clone()
+                        }
+                        _ => self.table.new_type_var(),
+                    };
+                    let mut coerce = CoerceMany::new(result_ty);
+
+                    for arm in arms.iter() {
+                        if let Some(guard_expr) = arm.guard {
+                            self.diverges = Diverges::Maybe;
+                            self.infer_expr(
+                                guard_expr,
+                                &Expectation::HasType(self.result.standard_types.bool_.clone()),
+                            );
+                        }
+                        self.diverges = Diverges::Maybe;
 
-                self.diverges = matchee_diverges | all_arms_diverge;
+                        let arm_ty = self.infer_expr_inner(arm.expr, &expected);
+                        all_arms_diverge &= self.diverges;
+                        coerce.coerce(self, Some(arm.expr), &arm_ty);
+                    }
 
-                coerce.complete(self)
+                    self.diverges = matchee_diverges | all_arms_diverge;
+
+                    coerce.complete(self)
+                }
             }
             Expr::Path(p) => {
                 // FIXME this could be more efficient...
@@ -1179,8 +1189,15 @@ impl<'a> InferenceContext<'a> {
                         self.diverges = previous_diverges;
                     }
                 }
-                Statement::Expr { expr, .. } => {
-                    self.infer_expr(*expr, &Expectation::none());
+                &Statement::Expr { expr, has_semi } => {
+                    self.infer_expr(
+                        expr,
+                        &if has_semi {
+                            Expectation::none()
+                        } else {
+                            Expectation::HasType(self.result.standard_types.unit.clone())
+                        },
+                    );
                 }
             }
         }
diff --git a/crates/hir-ty/src/infer/path.rs b/crates/hir-ty/src/infer/path.rs
index 0a8527afbd0..b3867623f37 100644
--- a/crates/hir-ty/src/infer/path.rs
+++ b/crates/hir-ty/src/infer/path.rs
@@ -40,20 +40,14 @@ impl<'a> InferenceContext<'a> {
         id: ExprOrPatId,
     ) -> Option<Ty> {
         let (value, self_subst) = if let Some(type_ref) = path.type_anchor() {
-            if path.segments().is_empty() {
-                // This can't actually happen syntax-wise
-                return None;
-            }
+            let Some(last) = path.segments().last() else { return None };
             let ty = self.make_ty(type_ref);
             let remaining_segments_for_ty = path.segments().take(path.segments().len() - 1);
             let ctx = crate::lower::TyLoweringContext::new(self.db, resolver);
             let (ty, _) = ctx.lower_ty_relative_path(ty, None, remaining_segments_for_ty);
-            self.resolve_ty_assoc_item(
-                ty,
-                path.segments().last().expect("path had at least one segment").name,
-                id,
-            )?
+            self.resolve_ty_assoc_item(ty, last.name, id)?
         } else {
+            // FIXME: report error, unresolved first path segment
             let value_or_partial =
                 resolver.resolve_path_in_value_ns(self.db.upcast(), path.mod_path())?;
 
@@ -66,10 +60,13 @@ impl<'a> InferenceContext<'a> {
         };
 
         let typable: ValueTyDefId = match value {
-            ValueNs::LocalBinding(pat) => {
-                let ty = self.result.type_of_pat.get(pat)?.clone();
-                return Some(ty);
-            }
+            ValueNs::LocalBinding(pat) => match self.result.type_of_pat.get(pat) {
+                Some(ty) => return Some(ty.clone()),
+                None => {
+                    never!("uninferred pattern?");
+                    return None;
+                }
+            },
             ValueNs::FunctionId(it) => it.into(),
             ValueNs::ConstId(it) => it.into(),
             ValueNs::StaticId(it) => it.into(),
@@ -91,7 +88,7 @@ impl<'a> InferenceContext<'a> {
                     let ty = self.db.value_ty(struct_id.into()).substitute(Interner, &substs);
                     return Some(ty);
                 } else {
-                    // FIXME: diagnostic, invalid Self reference
+                    // FIXME: report error, invalid Self reference
                     return None;
                 }
             }
diff --git a/crates/hir-ty/src/tests/diagnostics.rs b/crates/hir-ty/src/tests/diagnostics.rs
index f00fa972948..1876be303ad 100644
--- a/crates/hir-ty/src/tests/diagnostics.rs
+++ b/crates/hir-ty/src/tests/diagnostics.rs
@@ -73,3 +73,24 @@ fn test(x: bool) -> &'static str {
 "#,
     );
 }
+
+#[test]
+fn non_unit_block_expr_stmt_no_semi() {
+    check(
+        r#"
+fn test(x: bool) {
+    if x {
+        "notok"
+      //^^^^^^^ expected (), got &str
+    } else {
+        "ok"
+      //^^^^ expected (), got &str
+    }
+    match x { true => true, false => 0 }
+  //^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected (), got bool
+                                   //^ expected bool, got i32
+    ()
+}
+"#,
+    );
+}
diff --git a/crates/hir-ty/src/tests/regression.rs b/crates/hir-ty/src/tests/regression.rs
index de6ae7fff8f..5fc2f46d560 100644
--- a/crates/hir-ty/src/tests/regression.rs
+++ b/crates/hir-ty/src/tests/regression.rs
@@ -1015,9 +1015,9 @@ fn cfg_tail() {
             20..31 '{ "first" }': ()
             22..29 '"first"': &str
             72..190 '{     ...] 13 }': ()
-            78..88 '{ "fake" }': &str
+            78..88 '{ "fake" }': ()
             80..86 '"fake"': &str
-            93..103 '{ "fake" }': &str
+            93..103 '{ "fake" }': ()
             95..101 '"fake"': &str
             108..120 '{ "second" }': ()
             110..118 '"second"': &str