about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/hir_def/src/body/lower.rs73
-rw-r--r--crates/hir_def/src/body/scope.rs102
-rw-r--r--crates/hir_def/src/expr.rs16
-rw-r--r--crates/hir_def/src/macro_expansion_tests/mbe/tt_conversion.rs6
4 files changed, 87 insertions, 110 deletions
diff --git a/crates/hir_def/src/body/lower.rs b/crates/hir_def/src/body/lower.rs
index 7cbeef1488a..06ad7ce4cd0 100644
--- a/crates/hir_def/src/body/lower.rs
+++ b/crates/hir_def/src/body/lower.rs
@@ -28,7 +28,7 @@ use crate::{
     db::DefDatabase,
     expr::{
         dummy_expr_id, Array, BindingAnnotation, Expr, ExprId, Label, LabelId, Literal, MatchArm,
-        MatchGuard, Pat, PatId, RecordFieldPat, RecordLitField, Statement,
+        Pat, PatId, RecordFieldPat, RecordLitField, Statement,
     },
     intern::Interned,
     item_scope::BuiltinShadowMode,
@@ -155,9 +155,6 @@ impl ExprCollector<'_> {
     fn alloc_expr_desugared(&mut self, expr: Expr) -> ExprId {
         self.make_expr(expr, Err(SyntheticSyntax))
     }
-    fn unit(&mut self) -> ExprId {
-        self.alloc_expr_desugared(Expr::Tuple { exprs: Box::default() })
-    }
     fn missing_expr(&mut self) -> ExprId {
         self.alloc_expr_desugared(Expr::Missing)
     }
@@ -215,33 +212,15 @@ impl ExprCollector<'_> {
                     }
                 });
 
-                let condition = match e.condition() {
-                    None => self.missing_expr(),
-                    Some(condition) => match condition.pat() {
-                        None => self.collect_expr_opt(condition.expr()),
-                        // if let -- desugar to match
-                        Some(pat) => {
-                            let pat = self.collect_pat(pat);
-                            let match_expr = self.collect_expr_opt(condition.expr());
-                            let placeholder_pat = self.missing_pat();
-                            let arms = vec![
-                                MatchArm { pat, expr: then_branch, guard: None },
-                                MatchArm {
-                                    pat: placeholder_pat,
-                                    expr: else_branch.unwrap_or_else(|| self.unit()),
-                                    guard: None,
-                                },
-                            ]
-                            .into();
-                            return Some(
-                                self.alloc_expr(Expr::Match { expr: match_expr, arms }, syntax_ptr),
-                            );
-                        }
-                    },
-                };
+                let condition = self.collect_expr_opt(e.condition());
 
                 self.alloc_expr(Expr::If { condition, then_branch, else_branch }, syntax_ptr)
             }
+            ast::Expr::LetExpr(e) => {
+                let pat = self.collect_pat_opt(e.pat());
+                let expr = self.collect_expr_opt(e.expr());
+                self.alloc_expr(Expr::Let { pat, expr }, syntax_ptr)
+            }
             ast::Expr::BlockExpr(e) => match e.modifier() {
                 Some(ast::BlockModifier::Try(_)) => {
                     let body = self.collect_block(e);
@@ -282,31 +261,7 @@ impl ExprCollector<'_> {
                 let label = e.label().map(|label| self.collect_label(label));
                 let body = self.collect_block_opt(e.loop_body());
 
-                let condition = match e.condition() {
-                    None => self.missing_expr(),
-                    Some(condition) => match condition.pat() {
-                        None => self.collect_expr_opt(condition.expr()),
-                        // if let -- desugar to match
-                        Some(pat) => {
-                            cov_mark::hit!(infer_resolve_while_let);
-                            let pat = self.collect_pat(pat);
-                            let match_expr = self.collect_expr_opt(condition.expr());
-                            let placeholder_pat = self.missing_pat();
-                            let break_ =
-                                self.alloc_expr_desugared(Expr::Break { expr: None, label: None });
-                            let arms = vec![
-                                MatchArm { pat, expr: body, guard: None },
-                                MatchArm { pat: placeholder_pat, expr: break_, guard: None },
-                            ]
-                            .into();
-                            let match_expr =
-                                self.alloc_expr_desugared(Expr::Match { expr: match_expr, arms });
-                            return Some(
-                                self.alloc_expr(Expr::Loop { body: match_expr, label }, syntax_ptr),
-                            );
-                        }
-                    },
-                };
+                let condition = self.collect_expr_opt(e.condition());
 
                 self.alloc_expr(Expr::While { condition, body, label }, syntax_ptr)
             }
@@ -352,15 +307,9 @@ impl ExprCollector<'_> {
                             self.check_cfg(&arm).map(|()| MatchArm {
                                 pat: self.collect_pat_opt(arm.pat()),
                                 expr: self.collect_expr_opt(arm.expr()),
-                                guard: arm.guard().map(|guard| match guard.pat() {
-                                    Some(pat) => MatchGuard::IfLet {
-                                        pat: self.collect_pat(pat),
-                                        expr: self.collect_expr_opt(guard.expr()),
-                                    },
-                                    None => {
-                                        MatchGuard::If { expr: self.collect_expr_opt(guard.expr()) }
-                                    }
-                                }),
+                                guard: arm
+                                    .guard()
+                                    .map(|guard| self.collect_expr_opt(guard.condition())),
                             })
                         })
                         .collect()
diff --git a/crates/hir_def/src/body/scope.rs b/crates/hir_def/src/body/scope.rs
index 2658eece8e8..505d33fa482 100644
--- a/crates/hir_def/src/body/scope.rs
+++ b/crates/hir_def/src/body/scope.rs
@@ -8,7 +8,7 @@ use rustc_hash::FxHashMap;
 use crate::{
     body::Body,
     db::DefDatabase,
-    expr::{Expr, ExprId, LabelId, MatchGuard, Pat, PatId, Statement},
+    expr::{Expr, ExprId, LabelId, Pat, PatId, Statement},
     BlockId, DefWithBodyId,
 };
 
@@ -53,9 +53,9 @@ impl ExprScopes {
     fn new(body: &Body) -> ExprScopes {
         let mut scopes =
             ExprScopes { scopes: Arena::default(), scope_by_expr: FxHashMap::default() };
-        let root = scopes.root_scope();
+        let mut root = scopes.root_scope();
         scopes.add_params_bindings(body, root, &body.params);
-        compute_expr_scopes(body.body_expr, body, &mut scopes, root);
+        compute_expr_scopes(body.body_expr, body, &mut scopes, &mut root);
         scopes
     }
 
@@ -151,32 +151,32 @@ fn compute_block_scopes(
         match stmt {
             Statement::Let { pat, initializer, else_branch, .. } => {
                 if let Some(expr) = initializer {
-                    compute_expr_scopes(*expr, body, scopes, scope);
+                    compute_expr_scopes(*expr, body, scopes, &mut scope);
                 }
                 if let Some(expr) = else_branch {
-                    compute_expr_scopes(*expr, body, scopes, scope);
+                    compute_expr_scopes(*expr, body, scopes, &mut scope);
                 }
                 scope = scopes.new_scope(scope);
                 scopes.add_bindings(body, scope, *pat);
             }
             Statement::Expr { expr, .. } => {
-                compute_expr_scopes(*expr, body, scopes, scope);
+                compute_expr_scopes(*expr, body, scopes, &mut scope);
             }
         }
     }
     if let Some(expr) = tail {
-        compute_expr_scopes(expr, body, scopes, scope);
+        compute_expr_scopes(expr, body, scopes, &mut scope);
     }
 }
 
-fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope: ScopeId) {
+fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope: &mut ScopeId) {
     let make_label =
         |label: &Option<LabelId>| label.map(|label| (label, body.labels[label].name.clone()));
 
-    scopes.set_scope(expr, scope);
+    scopes.set_scope(expr, *scope);
     match &body[expr] {
         Expr::Block { statements, tail, id, label } => {
-            let scope = scopes.new_block_scope(scope, *id, make_label(label));
+            let scope = scopes.new_block_scope(*scope, *id, make_label(label));
             // Overwrite the old scope for the block expr, so that every block scope can be found
             // via the block itself (important for blocks that only contain items, no expressions).
             scopes.set_scope(expr, scope);
@@ -184,46 +184,49 @@ fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope
         }
         Expr::For { iterable, pat, body: body_expr, label } => {
             compute_expr_scopes(*iterable, body, scopes, scope);
-            let scope = scopes.new_labeled_scope(scope, make_label(label));
+            let mut scope = scopes.new_labeled_scope(*scope, make_label(label));
             scopes.add_bindings(body, scope, *pat);
-            compute_expr_scopes(*body_expr, body, scopes, scope);
+            compute_expr_scopes(*body_expr, body, scopes, &mut scope);
         }
         Expr::While { condition, body: body_expr, label } => {
-            let scope = scopes.new_labeled_scope(scope, make_label(label));
-            compute_expr_scopes(*condition, body, scopes, scope);
-            compute_expr_scopes(*body_expr, body, scopes, scope);
+            let mut scope = scopes.new_labeled_scope(*scope, make_label(label));
+            compute_expr_scopes(*condition, body, scopes, &mut scope);
+            compute_expr_scopes(*body_expr, body, scopes, &mut scope);
         }
         Expr::Loop { body: body_expr, label } => {
-            let scope = scopes.new_labeled_scope(scope, make_label(label));
-            compute_expr_scopes(*body_expr, body, scopes, scope);
+            let mut scope = scopes.new_labeled_scope(*scope, make_label(label));
+            compute_expr_scopes(*body_expr, body, scopes, &mut scope);
         }
         Expr::Lambda { args, body: body_expr, .. } => {
-            let scope = scopes.new_scope(scope);
+            let mut scope = scopes.new_scope(*scope);
             scopes.add_params_bindings(body, scope, args);
-            compute_expr_scopes(*body_expr, body, scopes, scope);
+            compute_expr_scopes(*body_expr, body, scopes, &mut scope);
         }
         Expr::Match { expr, arms } => {
             compute_expr_scopes(*expr, body, scopes, scope);
             for arm in arms.iter() {
-                let mut scope = scopes.new_scope(scope);
+                let mut scope = scopes.new_scope(*scope);
                 scopes.add_bindings(body, scope, arm.pat);
-                match arm.guard {
-                    Some(MatchGuard::If { expr: guard }) => {
-                        scopes.set_scope(guard, scope);
-                        compute_expr_scopes(guard, body, scopes, scope);
-                    }
-                    Some(MatchGuard::IfLet { pat, expr: guard }) => {
-                        scopes.set_scope(guard, scope);
-                        compute_expr_scopes(guard, body, scopes, scope);
-                        scope = scopes.new_scope(scope);
-                        scopes.add_bindings(body, scope, pat);
-                    }
-                    _ => {}
-                };
-                scopes.set_scope(arm.expr, scope);
-                compute_expr_scopes(arm.expr, body, scopes, scope);
+                if let Some(guard) = arm.guard {
+                    scope = scopes.new_scope(scope);
+                    compute_expr_scopes(guard, body, scopes, &mut scope);
+                }
+                compute_expr_scopes(arm.expr, body, scopes, &mut scope);
             }
         }
+        &Expr::If { condition, then_branch, else_branch } => {
+            let mut then_branch_scope = scopes.new_scope(*scope);
+            compute_expr_scopes(condition, body, scopes, &mut then_branch_scope);
+            compute_expr_scopes(then_branch, body, scopes, &mut then_branch_scope);
+            if let Some(else_branch) = else_branch {
+                compute_expr_scopes(else_branch, body, scopes, scope);
+            }
+        }
+        &Expr::Let { pat, expr } => {
+            compute_expr_scopes(expr, body, scopes, scope);
+            *scope = scopes.new_scope(*scope);
+            scopes.add_bindings(body, *scope, pat);
+        }
         e => e.walk_child_exprs(|e| compute_expr_scopes(e, body, scopes, scope)),
     };
 }
@@ -500,8 +503,7 @@ fn foo() {
     }
 
     #[test]
-    fn while_let_desugaring() {
-        cov_mark::check!(infer_resolve_while_let);
+    fn while_let_adds_binding() {
         do_check_local_name(
             r#"
 fn test() {
@@ -513,5 +515,31 @@ fn test() {
 "#,
             75,
         );
+        do_check_local_name(
+            r#"
+fn test() {
+    let foo: Option<f32> = None;
+    while (((let Option::Some(_) = foo))) && let Option::Some(spam) = foo {
+        spam$0
+    }
+}
+"#,
+            107,
+        );
+    }
+
+    #[test]
+    fn match_guard_if_let() {
+        do_check_local_name(
+            r#"
+fn test() {
+    let foo: Option<f32> = None;
+    match foo {
+        _ if let Option::Some(spam) = foo => spam$0,
+    }
+}
+"#,
+            93,
+        );
     }
 }
diff --git a/crates/hir_def/src/expr.rs b/crates/hir_def/src/expr.rs
index 6534f970ee6..4dca8238880 100644
--- a/crates/hir_def/src/expr.rs
+++ b/crates/hir_def/src/expr.rs
@@ -59,6 +59,10 @@ pub enum Expr {
         then_branch: ExprId,
         else_branch: Option<ExprId>,
     },
+    Let {
+        pat: PatId,
+        expr: ExprId,
+    },
     Block {
         id: BlockId,
         statements: Box<[Statement]>,
@@ -189,18 +193,11 @@ pub enum Array {
 #[derive(Debug, Clone, Eq, PartialEq)]
 pub struct MatchArm {
     pub pat: PatId,
-    pub guard: Option<MatchGuard>,
+    pub guard: Option<ExprId>,
     pub expr: ExprId,
 }
 
 #[derive(Debug, Clone, Eq, PartialEq)]
-pub enum MatchGuard {
-    If { expr: ExprId },
-
-    IfLet { pat: PatId, expr: ExprId },
-}
-
-#[derive(Debug, Clone, Eq, PartialEq)]
 pub struct RecordLitField {
     pub name: Name,
     pub expr: ExprId,
@@ -232,6 +229,9 @@ impl Expr {
                     f(else_branch);
                 }
             }
+            Expr::Let { expr, .. } => {
+                f(*expr);
+            }
             Expr::Block { statements, tail, .. } => {
                 for stmt in statements.iter() {
                     match stmt {
diff --git a/crates/hir_def/src/macro_expansion_tests/mbe/tt_conversion.rs b/crates/hir_def/src/macro_expansion_tests/mbe/tt_conversion.rs
index 5f4b7d6d0bc..84cc3f3872f 100644
--- a/crates/hir_def/src/macro_expansion_tests/mbe/tt_conversion.rs
+++ b/crates/hir_def/src/macro_expansion_tests/mbe/tt_conversion.rs
@@ -108,18 +108,18 @@ fn expansion_does_not_parse_as_expression() {
     check(
         r#"
 macro_rules! stmts {
-    () => { let _ = 0; }
+    () => { fn foo() {} }
 }
 
 fn f() { let _ = stmts!/*+errors*/(); }
 "#,
         expect![[r#"
 macro_rules! stmts {
-    () => { let _ = 0; }
+    () => { fn foo() {} }
 }
 
 fn f() { let _ = /* parse error: expected expression */
-let _ = 0;; }
+fn foo() {}; }
 "#]],
     )
 }