about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/ide-assists/src/handlers/replace_or_with_or_else.rs75
1 files changed, 72 insertions, 3 deletions
diff --git a/crates/ide-assists/src/handlers/replace_or_with_or_else.rs b/crates/ide-assists/src/handlers/replace_or_with_or_else.rs
index b5b5798b8bc..96314263c97 100644
--- a/crates/ide-assists/src/handlers/replace_or_with_or_else.rs
+++ b/crates/ide-assists/src/handlers/replace_or_with_or_else.rs
@@ -1,6 +1,9 @@
-use ide_db::assists::{AssistId, AssistKind};
+use ide_db::{
+    assists::{AssistId, AssistKind},
+    famous_defs::FamousDefs,
+};
 use syntax::{
-    ast::{self, make, HasArgList},
+    ast::{self, make, Expr, HasArgList},
     AstNode,
 };
 
@@ -21,6 +24,9 @@ use crate::{AssistContext, Assists};
 // ```
 pub(crate) fn replace_or_with_or_else(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
     let call: ast::MethodCallExpr = ctx.find_node_at_offset()?;
+
+    is_option_or_result(call.receiver()?, ctx)?;
+
     let (name, arg_list) = (call.name_ref()?, call.arg_list()?);
 
     let replace = match &*name.text() {
@@ -76,6 +82,8 @@ pub(crate) fn replace_or_with_or_else(acc: &mut Assists, ctx: &AssistContext<'_>
 pub(crate) fn replace_or_else_with_or(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
     let call: ast::MethodCallExpr = ctx.find_node_at_offset()?;
 
+    is_option_or_result(call.receiver()?, ctx)?;
+
     let (name, arg_list) = (call.name_ref()?, call.arg_list()?);
 
     let replace = match &*name.text() {
@@ -115,9 +123,32 @@ pub(crate) fn replace_or_else_with_or(acc: &mut Assists, ctx: &AssistContext<'_>
     )
 }
 
+fn is_option_or_result(receiver: Expr, ctx: &AssistContext<'_>) -> Option<()> {
+    let ty = ctx.sema.type_of_expr(&receiver)?.adjusted().as_adt()?.as_enum()?;
+    let option_enum =
+        FamousDefs(&ctx.sema, ctx.sema.scope(receiver.syntax())?.krate()).core_option_Option();
+
+    if let Some(option_enum) = option_enum {
+        if ty == option_enum {
+            return Some(());
+        }
+    }
+
+    let result_enum =
+        FamousDefs(&ctx.sema, ctx.sema.scope(receiver.syntax())?.krate()).core_result_Result();
+
+    if let Some(result_enum) = result_enum {
+        if ty == result_enum {
+            return Some(());
+        }
+    }
+
+    None
+}
+
 #[cfg(test)]
 mod tests {
-    use crate::tests::check_assist;
+    use crate::tests::{check_assist, check_assist_not_applicable};
 
     use super::*;
 
@@ -126,6 +157,7 @@ mod tests {
         check_assist(
             replace_or_with_or_else,
             r#"
+//- minicore: option
 fn foo() {
     let foo = Some(1);
     return foo.unwrap_$0or(2);
@@ -145,6 +177,7 @@ fn foo() {
         check_assist(
             replace_or_with_or_else,
             r#"
+//- minicore: option
 fn foo() {
     let foo = Some(1);
     return foo.unwrap_$0or(x());
@@ -164,6 +197,7 @@ fn foo() {
         check_assist(
             replace_or_with_or_else,
             r#"
+//- minicore: option
 fn foo() {
     let foo = Some(1);
     return foo.unwrap_$0or({
@@ -195,6 +229,7 @@ fn foo() {
         check_assist(
             replace_or_else_with_or,
             r#"
+//- minicore: option
 fn foo() {
     let foo = Some(1);
     return foo.unwrap_$0or_else(|| 2);
@@ -214,6 +249,7 @@ fn foo() {
         check_assist(
             replace_or_else_with_or,
             r#"
+//- minicore: option
 fn foo() {
     let foo = Some(1);
     return foo.unwrap_$0or_else(x);
@@ -227,4 +263,37 @@ fn foo() {
 "#,
         )
     }
+
+    #[test]
+    fn replace_or_else_with_or_result() {
+        check_assist(
+            replace_or_else_with_or,
+            r#"
+//- minicore: result
+fn foo() {
+    let foo = Ok(1);
+    return foo.unwrap_$0or_else(x);
+}
+"#,
+            r#"
+fn foo() {
+    let foo = Ok(1);
+    return foo.unwrap_or(x());
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn replace_or_else_with_or_not_applicable() {
+        check_assist_not_applicable(
+            replace_or_else_with_or,
+            r#"
+fn foo() {
+    let foo = Ok(1);
+    return foo.unwrap_$0or_else(x);
+}
+"#,
+        )
+    }
 }