about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/ra_assists/src/handlers/change_return_type_to_result.rs2
-rw-r--r--crates/ra_hir_def/src/body/lower.rs30
-rw-r--r--crates/ra_syntax/src/ast.rs2
-rw-r--r--crates/ra_syntax/src/ast/generated.rs41
-rw-r--r--crates/ra_syntax/src/ast/generated/nodes.rs26
-rw-r--r--xtask/src/codegen/gen_syntax.rs54
-rw-r--r--xtask/src/codegen/rust.ungram2
7 files changed, 94 insertions, 63 deletions
diff --git a/crates/ra_assists/src/handlers/change_return_type_to_result.rs b/crates/ra_assists/src/handlers/change_return_type_to_result.rs
index 167e162d804..4b73c41dacb 100644
--- a/crates/ra_assists/src/handlers/change_return_type_to_result.rs
+++ b/crates/ra_assists/src/handlers/change_return_type_to_result.rs
@@ -74,6 +74,7 @@ impl TailReturnCollector {
             let expr = match &stmt {
                 ast::Stmt::ExprStmt(stmt) => stmt.expr(),
                 ast::Stmt::LetStmt(stmt) => stmt.initializer(),
+                ast::Stmt::Item(_) => continue,
             };
             if let Some(expr) = &expr {
                 self.handle_exprs(expr, collect_break);
@@ -94,6 +95,7 @@ impl TailReturnCollector {
                         let expr_stmt = match &expr_stmt {
                             ast::Stmt::ExprStmt(stmt) => stmt.expr(),
                             ast::Stmt::LetStmt(stmt) => stmt.initializer(),
+                            ast::Stmt::Item(_) => None,
                         };
                         if let Some(expr) = &expr_stmt {
                             self.handle_exprs(expr, collect_break);
diff --git a/crates/ra_hir_def/src/body/lower.rs b/crates/ra_hir_def/src/body/lower.rs
index 827ced4ad21..5816bf5664f 100644
--- a/crates/ra_hir_def/src/body/lower.rs
+++ b/crates/ra_hir_def/src/body/lower.rs
@@ -10,7 +10,7 @@ use hir_expand::{
 use ra_arena::Arena;
 use ra_syntax::{
     ast::{
-        self, ArgListOwner, ArrayExprKind, LiteralKind, LoopBodyOwner, ModuleItemOwner, NameOwner,
+        self, ArgListOwner, ArrayExprKind, LiteralKind, LoopBodyOwner, NameOwner,
         SlicePatComponents,
     },
     AstNode, AstPtr,
@@ -601,14 +601,20 @@ impl ExprCollector<'_> {
         self.collect_block_items(&block);
         let statements = block
             .statements()
-            .map(|s| match s {
-                ast::Stmt::LetStmt(stmt) => {
-                    let pat = self.collect_pat_opt(stmt.pat());
-                    let type_ref = stmt.ty().map(|it| TypeRef::from_ast(&self.ctx(), it));
-                    let initializer = stmt.initializer().map(|e| self.collect_expr(e));
-                    Statement::Let { pat, type_ref, initializer }
-                }
-                ast::Stmt::ExprStmt(stmt) => Statement::Expr(self.collect_expr_opt(stmt.expr())),
+            .filter_map(|s| {
+                let stmt = match s {
+                    ast::Stmt::LetStmt(stmt) => {
+                        let pat = self.collect_pat_opt(stmt.pat());
+                        let type_ref = stmt.ty().map(|it| TypeRef::from_ast(&self.ctx(), it));
+                        let initializer = stmt.initializer().map(|e| self.collect_expr(e));
+                        Statement::Let { pat, type_ref, initializer }
+                    }
+                    ast::Stmt::ExprStmt(stmt) => {
+                        Statement::Expr(self.collect_expr_opt(stmt.expr()))
+                    }
+                    ast::Stmt::Item(_) => return None,
+                };
+                Some(stmt)
             })
             .collect();
         let tail = block.expr().map(|e| self.collect_expr(e));
@@ -620,7 +626,11 @@ impl ExprCollector<'_> {
         let container = ContainerId::DefWithBodyId(self.def);
 
         let items = block
-            .items()
+            .statements()
+            .filter_map(|stmt| match stmt {
+                ast::Stmt::Item(it) => Some(it),
+                ast::Stmt::LetStmt(_) | ast::Stmt::ExprStmt(_) => None,
+            })
             .filter_map(|item| {
                 let (def, name): (ModuleDefId, Option<ast::Name>) = match item {
                     ast::Item::Fn(def) => {
diff --git a/crates/ra_syntax/src/ast.rs b/crates/ra_syntax/src/ast.rs
index 8a0e3d27b21..d536bb1e7d6 100644
--- a/crates/ra_syntax/src/ast.rs
+++ b/crates/ra_syntax/src/ast.rs
@@ -17,7 +17,7 @@ use crate::{
 
 pub use self::{
     expr_ext::{ArrayExprKind, BinOp, Effect, ElseBranch, LiteralKind, PrefixOp, RangeOp},
-    generated::{nodes::*, tokens::*},
+    generated::*,
     node_ext::{
         AttrKind, FieldKind, NameOrNameRef, PathSegmentKind, SelfParamKind, SlicePatComponents,
         StructKind, TypeBoundKind, VisibilityKind,
diff --git a/crates/ra_syntax/src/ast/generated.rs b/crates/ra_syntax/src/ast/generated.rs
index f5199e09f21..4a6f41ee71f 100644
--- a/crates/ra_syntax/src/ast/generated.rs
+++ b/crates/ra_syntax/src/ast/generated.rs
@@ -1,6 +1,41 @@
 //! This file is actually hand-written, but the submodules are indeed generated.
-
 #[rustfmt::skip]
-pub(super) mod nodes;
+mod nodes;
 #[rustfmt::skip]
-pub(super) mod tokens;
+mod tokens;
+
+use crate::{
+    AstNode,
+    SyntaxKind::{self, *},
+    SyntaxNode,
+};
+
+pub use {nodes::*, tokens::*};
+
+// Stmt is the only nested enum, so it's easier to just hand-write it
+impl AstNode for Stmt {
+    fn can_cast(kind: SyntaxKind) -> bool {
+        match kind {
+            LET_STMT | EXPR_STMT => true,
+            _ => Item::can_cast(kind),
+        }
+    }
+    fn cast(syntax: SyntaxNode) -> Option<Self> {
+        let res = match syntax.kind() {
+            LET_STMT => Stmt::LetStmt(LetStmt { syntax }),
+            EXPR_STMT => Stmt::ExprStmt(ExprStmt { syntax }),
+            _ => {
+                let item = Item::cast(syntax)?;
+                Stmt::Item(item)
+            }
+        };
+        Some(res)
+    }
+    fn syntax(&self) -> &SyntaxNode {
+        match self {
+            Stmt::LetStmt(it) => &it.syntax,
+            Stmt::ExprStmt(it) => &it.syntax,
+            Stmt::Item(it) => it.syntax(),
+        }
+    }
+}
diff --git a/crates/ra_syntax/src/ast/generated/nodes.rs b/crates/ra_syntax/src/ast/generated/nodes.rs
index 286be1032b8..763fd20f40a 100644
--- a/crates/ra_syntax/src/ast/generated/nodes.rs
+++ b/crates/ra_syntax/src/ast/generated/nodes.rs
@@ -348,7 +348,6 @@ pub struct BlockExpr {
     pub(crate) syntax: SyntaxNode,
 }
 impl ast::AttrsOwner for BlockExpr {}
-impl ast::ModuleItemOwner for BlockExpr {}
 impl BlockExpr {
     pub fn label(&self) -> Option<Label> { support::child(&self.syntax) }
     pub fn l_curly_token(&self) -> Option<SyntaxToken> { support::token(&self.syntax, T!['{']) }
@@ -1395,8 +1394,8 @@ impl ast::AttrsOwner for GenericParam {}
 pub enum Stmt {
     LetStmt(LetStmt),
     ExprStmt(ExprStmt),
+    Item(Item),
 }
-impl ast::AttrsOwner for Stmt {}
 impl AstNode for SourceFile {
     fn can_cast(kind: SyntaxKind) -> bool { kind == SOURCE_FILE }
     fn cast(syntax: SyntaxNode) -> Option<Self> {
@@ -3380,27 +3379,8 @@ impl From<LetStmt> for Stmt {
 impl From<ExprStmt> for Stmt {
     fn from(node: ExprStmt) -> Stmt { Stmt::ExprStmt(node) }
 }
-impl AstNode for Stmt {
-    fn can_cast(kind: SyntaxKind) -> bool {
-        match kind {
-            LET_STMT | EXPR_STMT => true,
-            _ => false,
-        }
-    }
-    fn cast(syntax: SyntaxNode) -> Option<Self> {
-        let res = match syntax.kind() {
-            LET_STMT => Stmt::LetStmt(LetStmt { syntax }),
-            EXPR_STMT => Stmt::ExprStmt(ExprStmt { syntax }),
-            _ => return None,
-        };
-        Some(res)
-    }
-    fn syntax(&self) -> &SyntaxNode {
-        match self {
-            Stmt::LetStmt(it) => &it.syntax,
-            Stmt::ExprStmt(it) => &it.syntax,
-        }
-    }
+impl From<Item> for Stmt {
+    fn from(node: Item) -> Stmt { Stmt::Item(node) }
 }
 impl std::fmt::Display for Item {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
diff --git a/xtask/src/codegen/gen_syntax.rs b/xtask/src/codegen/gen_syntax.rs
index d6a72ccc06c..d9f35851335 100644
--- a/xtask/src/codegen/gen_syntax.rs
+++ b/xtask/src/codegen/gen_syntax.rs
@@ -153,25 +153,10 @@ fn generate_nodes(kinds: KindsSrc<'_>, grammar: &AstSrc) -> Result<String> {
                 quote!(impl ast::#trait_name for #name {})
             });
 
-            (
-                quote! {
-                    #[pretty_doc_comment_placeholder_workaround]
-                    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
-                    pub enum #name {
-                        #(#variants(#variants),)*
-                    }
-
-                    #(#traits)*
-                },
+            let ast_node = if en.name == "Stmt" {
+                quote! {}
+            } else {
                 quote! {
-                    #(
-                    impl From<#variants> for #name {
-                        fn from(node: #variants) -> #name {
-                            #name::#variants(node)
-                        }
-                    }
-                    )*
-
                     impl AstNode for #name {
                         fn can_cast(kind: SyntaxKind) -> bool {
                             match kind {
@@ -196,6 +181,28 @@ fn generate_nodes(kinds: KindsSrc<'_>, grammar: &AstSrc) -> Result<String> {
                             }
                         }
                     }
+                }
+            };
+
+            (
+                quote! {
+                    #[pretty_doc_comment_placeholder_workaround]
+                    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
+                    pub enum #name {
+                        #(#variants(#variants),)*
+                    }
+
+                    #(#traits)*
+                },
+                quote! {
+                    #(
+                        impl From<#variants> for #name {
+                            fn from(node: #variants) -> #name {
+                                #name::#variants(node)
+                            }
+                        }
+                    )*
+                    #ast_node
                 },
             )
         })
@@ -497,13 +504,7 @@ fn lower(grammar: &Grammar) -> AstSrc {
     let mut res = AstSrc::default();
     res.tokens = vec!["Whitespace".into(), "Comment".into(), "String".into(), "RawString".into()];
 
-    let nodes = grammar
-        .iter()
-        .filter(|&node| match grammar[node].rule {
-            Rule::Node(it) if it == node => false,
-            _ => true,
-        })
-        .collect::<Vec<_>>();
+    let nodes = grammar.iter().collect::<Vec<_>>();
 
     for &node in &nodes {
         let name = grammar[node].name.clone();
@@ -693,6 +694,9 @@ fn extract_struct_trait(node: &mut AstNodeSrc, trait_name: &str, methods: &[&str
 
 fn extract_enum_traits(ast: &mut AstSrc) {
     for enm in &mut ast.enums {
+        if enm.name == "Stmt" {
+            continue;
+        }
         let nodes = &ast.nodes;
         let mut variant_traits = enm
             .variants
diff --git a/xtask/src/codegen/rust.ungram b/xtask/src/codegen/rust.ungram
index 8271509cf34..17de36d7a21 100644
--- a/xtask/src/codegen/rust.ungram
+++ b/xtask/src/codegen/rust.ungram
@@ -197,6 +197,7 @@ Attr =
 Stmt =
   LetStmt
 | ExprStmt
+| Item
 
 LetStmt =
   Attr* 'let' Pat (':' Type)?
@@ -316,7 +317,6 @@ Label =
 BlockExpr =
   Attr* Label
   '{'
-    Item*
     statements:Stmt*
     Expr?
   '}'