about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFlorian Diebold <flodiebold@gmail.com>2022-03-20 19:20:16 +0100
committerFlorian Diebold <flodiebold@gmail.com>2022-03-21 16:46:01 +0100
commit0689fdb650b43c7a5dc3bb27655b2df6879d8387 (patch)
treedc2054f07e6b9a5be0e6b1a64a8bbc3d33d7f510
parentab3313b1cb5e9ff79ecef0fb188873c892c193f1 (diff)
downloadrust-0689fdb650b43c7a5dc3bb27655b2df6879d8387.tar.gz
rust-0689fdb650b43c7a5dc3bb27655b2df6879d8387.zip
Add "add missing Ok/Some" fix
-rw-r--r--crates/hir/src/lib.rs28
-rw-r--r--crates/ide_diagnostics/src/handlers/missing_ok_or_some_in_tail_expr.rs223
-rw-r--r--crates/ide_diagnostics/src/handlers/type_mismatch.rs228
3 files changed, 252 insertions, 227 deletions
diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs
index 035ae2d408b..3af56b743f3 100644
--- a/crates/hir/src/lib.rs
+++ b/crates/hir/src/lib.rs
@@ -1004,6 +1004,26 @@ impl Adt {
         Type::from_def(db, id.module(db.upcast()).krate(), id)
     }
 
+    /// Turns this ADT into a type with the given type parameters. This isn't
+    /// the greatest API, FIXME find a better one.
+    pub fn ty_with_args(self, db: &dyn HirDatabase, args: &[Type]) -> Type {
+        let id = AdtId::from(self);
+        let mut it = args.iter().map(|t| t.ty.clone());
+        let ty = TyBuilder::def_ty(db, id.into())
+            .fill(|x| {
+                let r = it.next().unwrap_or_else(|| TyKind::Error.intern(Interner));
+                match x {
+                    ParamKind::Type => GenericArgData::Ty(r).intern(Interner),
+                    ParamKind::Const(ty) => {
+                        unknown_const_as_generic(ty.clone())
+                    }
+                }
+            })
+            .build();
+        let krate = id.module(db.upcast()).krate();
+        Type::new(db, krate, id, ty)
+    }
+
     pub fn module(self, db: &dyn HirDatabase) -> Module {
         match self {
             Adt::Struct(s) => s.module(db),
@@ -1019,6 +1039,14 @@ impl Adt {
             Adt::Enum(e) => e.name(db),
         }
     }
+
+    pub fn as_enum(&self) -> Option<Enum> {
+        if let Self::Enum(v) = self {
+            Some(*v)
+        } else {
+            None
+        }
+    }
 }
 
 impl HasVisibility for Adt {
diff --git a/crates/ide_diagnostics/src/handlers/missing_ok_or_some_in_tail_expr.rs b/crates/ide_diagnostics/src/handlers/missing_ok_or_some_in_tail_expr.rs
deleted file mode 100644
index d5635ba8baf..00000000000
--- a/crates/ide_diagnostics/src/handlers/missing_ok_or_some_in_tail_expr.rs
+++ /dev/null
@@ -1,223 +0,0 @@
-use hir::{db::AstDatabase, TypeInfo};
-use ide_db::{
-    assists::Assist, source_change::SourceChange, syntax_helpers::node_ext::for_each_tail_expr,
-};
-use syntax::AstNode;
-use text_edit::TextEdit;
-
-use crate::{fix, Diagnostic, DiagnosticsContext};
-
-// Diagnostic: missing-ok-or-some-in-tail-expr
-//
-// This diagnostic is triggered if a block that should return `Result` returns a value not wrapped in `Ok`,
-// or if a block that should return `Option` returns a value not wrapped in `Some`.
-//
-// Example:
-//
-// ```rust
-// fn foo() -> Result<u8, ()> {
-//     10
-// }
-// ```
-pub(crate) fn missing_ok_or_some_in_tail_expr(
-    ctx: &DiagnosticsContext<'_>,
-    d: &hir::MissingOkOrSomeInTailExpr,
-) -> Diagnostic {
-    Diagnostic::new(
-        "missing-ok-or-some-in-tail-expr",
-        format!("wrap return expression in {}", d.required),
-        ctx.sema.diagnostics_display_range(d.expr.clone().map(|it| it.into())).range,
-    )
-    .with_fixes(fixes(ctx, d))
-}
-
-fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingOkOrSomeInTailExpr) -> Option<Vec<Assist>> {
-    let root = ctx.sema.db.parse_or_expand(d.expr.file_id)?;
-    let tail_expr = d.expr.value.to_node(&root);
-    let tail_expr_range = tail_expr.syntax().text_range();
-    let mut builder = TextEdit::builder();
-    for_each_tail_expr(&tail_expr, &mut |expr| {
-        if ctx.sema.type_of_expr(expr).map(TypeInfo::original).as_ref() != Some(&d.expected) {
-            builder.insert(expr.syntax().text_range().start(), format!("{}(", d.required));
-            builder.insert(expr.syntax().text_range().end(), ")".to_string());
-        }
-    });
-    let source_change =
-        SourceChange::from_text_edit(d.expr.file_id.original_file(ctx.sema.db), builder.finish());
-    let name = if d.required == "Ok" { "Wrap with Ok" } else { "Wrap with Some" };
-    Some(vec![fix("wrap_tail_expr", name, source_change, tail_expr_range)])
-}
-
-#[cfg(test)]
-mod tests {
-    use crate::tests::{check_diagnostics, check_fix};
-
-    #[test]
-    fn test_wrap_return_type_option() {
-        check_fix(
-            r#"
-//- minicore: option, result
-fn div(x: i32, y: i32) -> Option<i32> {
-    if y == 0 {
-        return None;
-    }
-    x / y$0
-}
-"#,
-            r#"
-fn div(x: i32, y: i32) -> Option<i32> {
-    if y == 0 {
-        return None;
-    }
-    Some(x / y)
-}
-"#,
-        );
-    }
-
-    #[test]
-    fn test_wrap_return_type_option_tails() {
-        check_fix(
-            r#"
-//- minicore: option, result
-fn div(x: i32, y: i32) -> Option<i32> {
-    if y == 0 {
-        0
-    } else if true {
-        100
-    } else {
-        None
-    }$0
-}
-"#,
-            r#"
-fn div(x: i32, y: i32) -> Option<i32> {
-    if y == 0 {
-        Some(0)
-    } else if true {
-        Some(100)
-    } else {
-        None
-    }
-}
-"#,
-        );
-    }
-
-    #[test]
-    fn test_wrap_return_type() {
-        check_fix(
-            r#"
-//- minicore: option, result
-fn div(x: i32, y: i32) -> Result<i32, ()> {
-    if y == 0 {
-        return Err(());
-    }
-    x / y$0
-}
-"#,
-            r#"
-fn div(x: i32, y: i32) -> Result<i32, ()> {
-    if y == 0 {
-        return Err(());
-    }
-    Ok(x / y)
-}
-"#,
-        );
-    }
-
-    #[test]
-    fn test_wrap_return_type_handles_generic_functions() {
-        check_fix(
-            r#"
-//- minicore: option, result
-fn div<T>(x: T) -> Result<T, i32> {
-    if x == 0 {
-        return Err(7);
-    }
-    $0x
-}
-"#,
-            r#"
-fn div<T>(x: T) -> Result<T, i32> {
-    if x == 0 {
-        return Err(7);
-    }
-    Ok(x)
-}
-"#,
-        );
-    }
-
-    #[test]
-    fn test_wrap_return_type_handles_type_aliases() {
-        check_fix(
-            r#"
-//- minicore: option, result
-type MyResult<T> = Result<T, ()>;
-
-fn div(x: i32, y: i32) -> MyResult<i32> {
-    if y == 0 {
-        return Err(());
-    }
-    x $0/ y
-}
-"#,
-            r#"
-type MyResult<T> = Result<T, ()>;
-
-fn div(x: i32, y: i32) -> MyResult<i32> {
-    if y == 0 {
-        return Err(());
-    }
-    Ok(x / y)
-}
-"#,
-        );
-    }
-
-    #[test]
-    fn test_in_const_and_static() {
-        check_fix(
-            r#"
-//- minicore: option, result
-static A: Option<()> = {($0)};
-            "#,
-            r#"
-static A: Option<()> = {Some(())};
-            "#,
-        );
-        check_fix(
-            r#"
-//- minicore: option, result
-const _: Option<()> = {($0)};
-            "#,
-            r#"
-const _: Option<()> = {Some(())};
-            "#,
-        );
-    }
-
-    #[test]
-    fn test_wrap_return_type_not_applicable_when_expr_type_does_not_match_ok_type() {
-        check_diagnostics(
-            r#"
-//- minicore: option, result
-fn foo() -> Result<(), i32> { 0 }
-"#,
-        );
-    }
-
-    #[test]
-    fn test_wrap_return_type_not_applicable_when_return_type_is_not_result_or_option() {
-        check_diagnostics(
-            r#"
-//- minicore: option, result
-enum SomeOtherEnum { Ok(i32), Err(String) }
-
-fn foo() -> SomeOtherEnum { 0 }
-"#,
-        );
-    }
-}
diff --git a/crates/ide_diagnostics/src/handlers/type_mismatch.rs b/crates/ide_diagnostics/src/handlers/type_mismatch.rs
index 2f8bda9efa7..571605ef266 100644
--- a/crates/ide_diagnostics/src/handlers/type_mismatch.rs
+++ b/crates/ide_diagnostics/src/handlers/type_mismatch.rs
@@ -1,5 +1,8 @@
-use hir::{db::AstDatabase, HirDisplay, Type};
-use ide_db::source_change::SourceChange;
+use hir::{db::AstDatabase, HirDisplay, Type, TypeInfo};
+use ide_db::{
+    famous_defs::FamousDefs, source_change::SourceChange,
+    syntax_helpers::node_ext::for_each_tail_expr,
+};
 use syntax::{AstNode, TextRange};
 use text_edit::TextEdit;
 
@@ -30,6 +33,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch) -> Option<Vec<Assi
     let mut fixes = Vec::new();
 
     add_reference(ctx, d, &mut fixes);
+    add_missing_ok_or_some(ctx, d, &mut fixes);
 
     if fixes.is_empty() {
         None
@@ -38,7 +42,11 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch) -> Option<Vec<Assi
     }
 }
 
-fn add_reference(ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch, acc: &mut Vec<Assist>) -> Option<()> {
+fn add_reference(
+    ctx: &DiagnosticsContext<'_>,
+    d: &hir::TypeMismatch,
+    acc: &mut Vec<Assist>,
+) -> Option<()> {
     let root = ctx.sema.db.parse_or_expand(d.expr.file_id)?;
     let expr_node = d.expr.value.to_node(&root);
 
@@ -59,9 +67,52 @@ fn add_reference(ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch, acc: &mut
     Some(())
 }
 
+fn add_missing_ok_or_some(
+    ctx: &DiagnosticsContext<'_>,
+    d: &hir::TypeMismatch,
+    acc: &mut Vec<Assist>,
+) -> Option<()> {
+    let root = ctx.sema.db.parse_or_expand(d.expr.file_id)?;
+    let tail_expr = d.expr.value.to_node(&root);
+    let tail_expr_range = tail_expr.syntax().text_range();
+    let scope = ctx.sema.scope(tail_expr.syntax());
+
+    let expected_adt = d.expected.as_adt()?;
+    let expected_enum = expected_adt.as_enum()?;
+
+    let famous_defs = FamousDefs(&ctx.sema, scope.krate());
+    let core_result = famous_defs.core_result_Result();
+    let core_option = famous_defs.core_option_Option();
+
+    if Some(expected_enum) != core_result && Some(expected_enum) != core_option {
+        return None;
+    }
+
+    let variant_name = if Some(expected_enum) == core_result { "Ok" } else { "Some" };
+
+    let wrapped_actual_ty = expected_adt.ty_with_args(ctx.sema.db, &[d.actual.clone()]);
+
+    if !d.expected.could_unify_with(ctx.sema.db, &wrapped_actual_ty) {
+        return None;
+    }
+
+    let mut builder = TextEdit::builder();
+    for_each_tail_expr(&tail_expr, &mut |expr| {
+        if ctx.sema.type_of_expr(expr).map(TypeInfo::adjusted).as_ref() != Some(&d.expected) {
+            builder.insert(expr.syntax().text_range().start(), format!("{}(", variant_name));
+            builder.insert(expr.syntax().text_range().end(), ")".to_string());
+        }
+    });
+    let source_change =
+        SourceChange::from_text_edit(d.expr.file_id.original_file(ctx.sema.db), builder.finish());
+    let name = format!("Wrap in {}", variant_name);
+    acc.push(fix("wrap_tail_expr", &name, source_change, tail_expr_range));
+    Some(())
+}
+
 #[cfg(test)]
 mod tests {
-    use crate::tests::{check_diagnostics, check_fix};
+    use crate::tests::{check_diagnostics, check_fix, check_no_fix};
 
     #[test]
     fn missing_reference() {
@@ -217,4 +268,173 @@ fn main() {
             "#,
         );
     }
+
+    #[test]
+    fn test_wrap_return_type_option() {
+        check_fix(
+            r#"
+//- minicore: option, result
+fn div(x: i32, y: i32) -> Option<i32> {
+    if y == 0 {
+        return None;
+    }
+    x / y$0
+}
+"#,
+            r#"
+fn div(x: i32, y: i32) -> Option<i32> {
+    if y == 0 {
+        return None;
+    }
+    Some(x / y)
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn test_wrap_return_type_option_tails() {
+        check_fix(
+            r#"
+//- minicore: option, result
+fn div(x: i32, y: i32) -> Option<i32> {
+    if y == 0 {
+        0
+    } else if true {
+        100
+    } else {
+        None
+    }$0
+}
+"#,
+            r#"
+fn div(x: i32, y: i32) -> Option<i32> {
+    if y == 0 {
+        Some(0)
+    } else if true {
+        Some(100)
+    } else {
+        None
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn test_wrap_return_type() {
+        check_fix(
+            r#"
+//- minicore: option, result
+fn div(x: i32, y: i32) -> Result<i32, ()> {
+    if y == 0 {
+        return Err(());
+    }
+    x / y$0
+}
+"#,
+            r#"
+fn div(x: i32, y: i32) -> Result<i32, ()> {
+    if y == 0 {
+        return Err(());
+    }
+    Ok(x / y)
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn test_wrap_return_type_handles_generic_functions() {
+        check_fix(
+            r#"
+//- minicore: option, result
+fn div<T>(x: T) -> Result<T, i32> {
+    if x == 0 {
+        return Err(7);
+    }
+    $0x
+}
+"#,
+            r#"
+fn div<T>(x: T) -> Result<T, i32> {
+    if x == 0 {
+        return Err(7);
+    }
+    Ok(x)
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn test_wrap_return_type_handles_type_aliases() {
+        check_fix(
+            r#"
+//- minicore: option, result
+type MyResult<T> = Result<T, ()>;
+
+fn div(x: i32, y: i32) -> MyResult<i32> {
+    if y == 0 {
+        return Err(());
+    }
+    x $0/ y
+}
+"#,
+            r#"
+type MyResult<T> = Result<T, ()>;
+
+fn div(x: i32, y: i32) -> MyResult<i32> {
+    if y == 0 {
+        return Err(());
+    }
+    Ok(x / y)
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn test_in_const_and_static() {
+        check_fix(
+            r#"
+//- minicore: option, result
+static A: Option<()> = {($0)};
+            "#,
+            r#"
+static A: Option<()> = {Some(())};
+            "#,
+        );
+        check_fix(
+            r#"
+//- minicore: option, result
+const _: Option<()> = {($0)};
+            "#,
+            r#"
+const _: Option<()> = {Some(())};
+            "#,
+        );
+    }
+
+    #[test]
+    fn test_wrap_return_type_not_applicable_when_expr_type_does_not_match_ok_type() {
+        check_no_fix(
+            r#"
+//- minicore: option, result
+fn foo() -> Result<(), i32> { 0$0 }
+"#,
+        );
+    }
+
+    #[test]
+    fn test_wrap_return_type_not_applicable_when_return_type_is_not_result_or_option() {
+        check_no_fix(
+            r#"
+//- minicore: option, result
+enum SomeOtherEnum { Ok(i32), Err(String) }
+
+fn foo() -> SomeOtherEnum { 0$0 }
+"#,
+        );
+    }
 }