about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/hir/src/semantics.rs9
-rw-r--r--crates/syntax/src/ast.rs30
2 files changed, 31 insertions, 8 deletions
diff --git a/crates/hir/src/semantics.rs b/crates/hir/src/semantics.rs
index e0d26103915..b801cd785e0 100644
--- a/crates/hir/src/semantics.rs
+++ b/crates/hir/src/semantics.rs
@@ -1472,14 +1472,7 @@ impl<'db> SemanticsImpl<'db> {
     }
 
     fn is_inside_unsafe(&self, expr: &ast::Expr) -> bool {
-        let item_or_variant = |ancestor: SyntaxNode| {
-            if ast::Item::can_cast(ancestor.kind()) {
-                ast::Item::cast(ancestor).map(Either::Left)
-            } else {
-                ast::Variant::cast(ancestor).map(Either::Right)
-            }
-        };
-        let Some(enclosing_item) = expr.syntax().ancestors().find_map(item_or_variant) else { return false };
+        let Some(enclosing_item) = expr.syntax().ancestors().find_map(Either::<ast::Item, ast::Variant>::cast) else { return false };
 
         let def = match &enclosing_item {
             Either::Left(ast::Item::Fn(it)) if it.unsafe_token().is_some() => return true,
diff --git a/crates/syntax/src/ast.rs b/crates/syntax/src/ast.rs
index 10c04575833..385a4e0a3ce 100644
--- a/crates/syntax/src/ast.rs
+++ b/crates/syntax/src/ast.rs
@@ -13,6 +13,8 @@ pub mod prec;
 
 use std::marker::PhantomData;
 
+use itertools::Either;
+
 use crate::{
     syntax_node::{SyntaxNode, SyntaxNodeChildren, SyntaxToken},
     SyntaxKind,
@@ -98,6 +100,34 @@ impl<N: AstNode> Iterator for AstChildren<N> {
     }
 }
 
+impl<L, R> AstNode for Either<L, R>
+where
+    L: AstNode,
+    R: AstNode,
+{
+    fn can_cast(kind: SyntaxKind) -> bool
+    where
+        Self: Sized,
+    {
+        L::can_cast(kind) || R::can_cast(kind)
+    }
+
+    fn cast(syntax: SyntaxNode) -> Option<Self>
+    where
+        Self: Sized,
+    {
+        if L::can_cast(syntax.kind()) {
+            L::cast(syntax).map(Either::Left)
+        } else {
+            R::cast(syntax).map(Either::Right)
+        }
+    }
+
+    fn syntax(&self) -> &SyntaxNode {
+        self.as_ref().either(L::syntax, R::syntax)
+    }
+}
+
 mod support {
     use super::{AstChildren, AstNode, SyntaxKind, SyntaxNode, SyntaxToken};