diff options
| author | Florian Diebold <flodiebold@gmail.com> | 2022-03-20 19:20:16 +0100 |
|---|---|---|
| committer | Florian Diebold <flodiebold@gmail.com> | 2022-03-21 16:46:01 +0100 |
| commit | 0689fdb650b43c7a5dc3bb27655b2df6879d8387 (patch) | |
| tree | dc2054f07e6b9a5be0e6b1a64a8bbc3d33d7f510 | |
| parent | ab3313b1cb5e9ff79ecef0fb188873c892c193f1 (diff) | |
| download | rust-0689fdb650b43c7a5dc3bb27655b2df6879d8387.tar.gz rust-0689fdb650b43c7a5dc3bb27655b2df6879d8387.zip | |
Add "add missing Ok/Some" fix
| -rw-r--r-- | crates/hir/src/lib.rs | 28 | ||||
| -rw-r--r-- | crates/ide_diagnostics/src/handlers/missing_ok_or_some_in_tail_expr.rs | 223 | ||||
| -rw-r--r-- | crates/ide_diagnostics/src/handlers/type_mismatch.rs | 228 |
3 files changed, 252 insertions, 227 deletions
diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index 035ae2d408b..3af56b743f3 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -1004,6 +1004,26 @@ impl Adt { Type::from_def(db, id.module(db.upcast()).krate(), id) } + /// Turns this ADT into a type with the given type parameters. This isn't + /// the greatest API, FIXME find a better one. + pub fn ty_with_args(self, db: &dyn HirDatabase, args: &[Type]) -> Type { + let id = AdtId::from(self); + let mut it = args.iter().map(|t| t.ty.clone()); + let ty = TyBuilder::def_ty(db, id.into()) + .fill(|x| { + let r = it.next().unwrap_or_else(|| TyKind::Error.intern(Interner)); + match x { + ParamKind::Type => GenericArgData::Ty(r).intern(Interner), + ParamKind::Const(ty) => { + unknown_const_as_generic(ty.clone()) + } + } + }) + .build(); + let krate = id.module(db.upcast()).krate(); + Type::new(db, krate, id, ty) + } + pub fn module(self, db: &dyn HirDatabase) -> Module { match self { Adt::Struct(s) => s.module(db), @@ -1019,6 +1039,14 @@ impl Adt { Adt::Enum(e) => e.name(db), } } + + pub fn as_enum(&self) -> Option<Enum> { + if let Self::Enum(v) = self { + Some(*v) + } else { + None + } + } } impl HasVisibility for Adt { diff --git a/crates/ide_diagnostics/src/handlers/missing_ok_or_some_in_tail_expr.rs b/crates/ide_diagnostics/src/handlers/missing_ok_or_some_in_tail_expr.rs deleted file mode 100644 index d5635ba8baf..00000000000 --- a/crates/ide_diagnostics/src/handlers/missing_ok_or_some_in_tail_expr.rs +++ /dev/null @@ -1,223 +0,0 @@ -use hir::{db::AstDatabase, TypeInfo}; -use ide_db::{ - assists::Assist, source_change::SourceChange, syntax_helpers::node_ext::for_each_tail_expr, -}; -use syntax::AstNode; -use text_edit::TextEdit; - -use crate::{fix, Diagnostic, DiagnosticsContext}; - -// Diagnostic: missing-ok-or-some-in-tail-expr -// -// This diagnostic is triggered if a block that should return `Result` returns a value not wrapped in `Ok`, -// or if a block that should return `Option` returns a value not wrapped in `Some`. -// -// Example: -// -// ```rust -// fn foo() -> Result<u8, ()> { -// 10 -// } -// ``` -pub(crate) fn missing_ok_or_some_in_tail_expr( - ctx: &DiagnosticsContext<'_>, - d: &hir::MissingOkOrSomeInTailExpr, -) -> Diagnostic { - Diagnostic::new( - "missing-ok-or-some-in-tail-expr", - format!("wrap return expression in {}", d.required), - ctx.sema.diagnostics_display_range(d.expr.clone().map(|it| it.into())).range, - ) - .with_fixes(fixes(ctx, d)) -} - -fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingOkOrSomeInTailExpr) -> Option<Vec<Assist>> { - let root = ctx.sema.db.parse_or_expand(d.expr.file_id)?; - let tail_expr = d.expr.value.to_node(&root); - let tail_expr_range = tail_expr.syntax().text_range(); - let mut builder = TextEdit::builder(); - for_each_tail_expr(&tail_expr, &mut |expr| { - if ctx.sema.type_of_expr(expr).map(TypeInfo::original).as_ref() != Some(&d.expected) { - builder.insert(expr.syntax().text_range().start(), format!("{}(", d.required)); - builder.insert(expr.syntax().text_range().end(), ")".to_string()); - } - }); - let source_change = - SourceChange::from_text_edit(d.expr.file_id.original_file(ctx.sema.db), builder.finish()); - let name = if d.required == "Ok" { "Wrap with Ok" } else { "Wrap with Some" }; - Some(vec![fix("wrap_tail_expr", name, source_change, tail_expr_range)]) -} - -#[cfg(test)] -mod tests { - use crate::tests::{check_diagnostics, check_fix}; - - #[test] - fn test_wrap_return_type_option() { - check_fix( - r#" -//- minicore: option, result -fn div(x: i32, y: i32) -> Option<i32> { - if y == 0 { - return None; - } - x / y$0 -} -"#, - r#" -fn div(x: i32, y: i32) -> Option<i32> { - if y == 0 { - return None; - } - Some(x / y) -} -"#, - ); - } - - #[test] - fn test_wrap_return_type_option_tails() { - check_fix( - r#" -//- minicore: option, result -fn div(x: i32, y: i32) -> Option<i32> { - if y == 0 { - 0 - } else if true { - 100 - } else { - None - }$0 -} -"#, - r#" -fn div(x: i32, y: i32) -> Option<i32> { - if y == 0 { - Some(0) - } else if true { - Some(100) - } else { - None - } -} -"#, - ); - } - - #[test] - fn test_wrap_return_type() { - check_fix( - r#" -//- minicore: option, result -fn div(x: i32, y: i32) -> Result<i32, ()> { - if y == 0 { - return Err(()); - } - x / y$0 -} -"#, - r#" -fn div(x: i32, y: i32) -> Result<i32, ()> { - if y == 0 { - return Err(()); - } - Ok(x / y) -} -"#, - ); - } - - #[test] - fn test_wrap_return_type_handles_generic_functions() { - check_fix( - r#" -//- minicore: option, result -fn div<T>(x: T) -> Result<T, i32> { - if x == 0 { - return Err(7); - } - $0x -} -"#, - r#" -fn div<T>(x: T) -> Result<T, i32> { - if x == 0 { - return Err(7); - } - Ok(x) -} -"#, - ); - } - - #[test] - fn test_wrap_return_type_handles_type_aliases() { - check_fix( - r#" -//- minicore: option, result -type MyResult<T> = Result<T, ()>; - -fn div(x: i32, y: i32) -> MyResult<i32> { - if y == 0 { - return Err(()); - } - x $0/ y -} -"#, - r#" -type MyResult<T> = Result<T, ()>; - -fn div(x: i32, y: i32) -> MyResult<i32> { - if y == 0 { - return Err(()); - } - Ok(x / y) -} -"#, - ); - } - - #[test] - fn test_in_const_and_static() { - check_fix( - r#" -//- minicore: option, result -static A: Option<()> = {($0)}; - "#, - r#" -static A: Option<()> = {Some(())}; - "#, - ); - check_fix( - r#" -//- minicore: option, result -const _: Option<()> = {($0)}; - "#, - r#" -const _: Option<()> = {Some(())}; - "#, - ); - } - - #[test] - fn test_wrap_return_type_not_applicable_when_expr_type_does_not_match_ok_type() { - check_diagnostics( - r#" -//- minicore: option, result -fn foo() -> Result<(), i32> { 0 } -"#, - ); - } - - #[test] - fn test_wrap_return_type_not_applicable_when_return_type_is_not_result_or_option() { - check_diagnostics( - r#" -//- minicore: option, result -enum SomeOtherEnum { Ok(i32), Err(String) } - -fn foo() -> SomeOtherEnum { 0 } -"#, - ); - } -} diff --git a/crates/ide_diagnostics/src/handlers/type_mismatch.rs b/crates/ide_diagnostics/src/handlers/type_mismatch.rs index 2f8bda9efa7..571605ef266 100644 --- a/crates/ide_diagnostics/src/handlers/type_mismatch.rs +++ b/crates/ide_diagnostics/src/handlers/type_mismatch.rs @@ -1,5 +1,8 @@ -use hir::{db::AstDatabase, HirDisplay, Type}; -use ide_db::source_change::SourceChange; +use hir::{db::AstDatabase, HirDisplay, Type, TypeInfo}; +use ide_db::{ + famous_defs::FamousDefs, source_change::SourceChange, + syntax_helpers::node_ext::for_each_tail_expr, +}; use syntax::{AstNode, TextRange}; use text_edit::TextEdit; @@ -30,6 +33,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch) -> Option<Vec<Assi let mut fixes = Vec::new(); add_reference(ctx, d, &mut fixes); + add_missing_ok_or_some(ctx, d, &mut fixes); if fixes.is_empty() { None @@ -38,7 +42,11 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch) -> Option<Vec<Assi } } -fn add_reference(ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch, acc: &mut Vec<Assist>) -> Option<()> { +fn add_reference( + ctx: &DiagnosticsContext<'_>, + d: &hir::TypeMismatch, + acc: &mut Vec<Assist>, +) -> Option<()> { let root = ctx.sema.db.parse_or_expand(d.expr.file_id)?; let expr_node = d.expr.value.to_node(&root); @@ -59,9 +67,52 @@ fn add_reference(ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch, acc: &mut Some(()) } +fn add_missing_ok_or_some( + ctx: &DiagnosticsContext<'_>, + d: &hir::TypeMismatch, + acc: &mut Vec<Assist>, +) -> Option<()> { + let root = ctx.sema.db.parse_or_expand(d.expr.file_id)?; + let tail_expr = d.expr.value.to_node(&root); + let tail_expr_range = tail_expr.syntax().text_range(); + let scope = ctx.sema.scope(tail_expr.syntax()); + + let expected_adt = d.expected.as_adt()?; + let expected_enum = expected_adt.as_enum()?; + + let famous_defs = FamousDefs(&ctx.sema, scope.krate()); + let core_result = famous_defs.core_result_Result(); + let core_option = famous_defs.core_option_Option(); + + if Some(expected_enum) != core_result && Some(expected_enum) != core_option { + return None; + } + + let variant_name = if Some(expected_enum) == core_result { "Ok" } else { "Some" }; + + let wrapped_actual_ty = expected_adt.ty_with_args(ctx.sema.db, &[d.actual.clone()]); + + if !d.expected.could_unify_with(ctx.sema.db, &wrapped_actual_ty) { + return None; + } + + let mut builder = TextEdit::builder(); + for_each_tail_expr(&tail_expr, &mut |expr| { + if ctx.sema.type_of_expr(expr).map(TypeInfo::adjusted).as_ref() != Some(&d.expected) { + builder.insert(expr.syntax().text_range().start(), format!("{}(", variant_name)); + builder.insert(expr.syntax().text_range().end(), ")".to_string()); + } + }); + let source_change = + SourceChange::from_text_edit(d.expr.file_id.original_file(ctx.sema.db), builder.finish()); + let name = format!("Wrap in {}", variant_name); + acc.push(fix("wrap_tail_expr", &name, source_change, tail_expr_range)); + Some(()) +} + #[cfg(test)] mod tests { - use crate::tests::{check_diagnostics, check_fix}; + use crate::tests::{check_diagnostics, check_fix, check_no_fix}; #[test] fn missing_reference() { @@ -217,4 +268,173 @@ fn main() { "#, ); } + + #[test] + fn test_wrap_return_type_option() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) -> Option<i32> { + if y == 0 { + return None; + } + x / y$0 +} +"#, + r#" +fn div(x: i32, y: i32) -> Option<i32> { + if y == 0 { + return None; + } + Some(x / y) +} +"#, + ); + } + + #[test] + fn test_wrap_return_type_option_tails() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) -> Option<i32> { + if y == 0 { + 0 + } else if true { + 100 + } else { + None + }$0 +} +"#, + r#" +fn div(x: i32, y: i32) -> Option<i32> { + if y == 0 { + Some(0) + } else if true { + Some(100) + } else { + None + } +} +"#, + ); + } + + #[test] + fn test_wrap_return_type() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) -> Result<i32, ()> { + if y == 0 { + return Err(()); + } + x / y$0 +} +"#, + r#" +fn div(x: i32, y: i32) -> Result<i32, ()> { + if y == 0 { + return Err(()); + } + Ok(x / y) +} +"#, + ); + } + + #[test] + fn test_wrap_return_type_handles_generic_functions() { + check_fix( + r#" +//- minicore: option, result +fn div<T>(x: T) -> Result<T, i32> { + if x == 0 { + return Err(7); + } + $0x +} +"#, + r#" +fn div<T>(x: T) -> Result<T, i32> { + if x == 0 { + return Err(7); + } + Ok(x) +} +"#, + ); + } + + #[test] + fn test_wrap_return_type_handles_type_aliases() { + check_fix( + r#" +//- minicore: option, result +type MyResult<T> = Result<T, ()>; + +fn div(x: i32, y: i32) -> MyResult<i32> { + if y == 0 { + return Err(()); + } + x $0/ y +} +"#, + r#" +type MyResult<T> = Result<T, ()>; + +fn div(x: i32, y: i32) -> MyResult<i32> { + if y == 0 { + return Err(()); + } + Ok(x / y) +} +"#, + ); + } + + #[test] + fn test_in_const_and_static() { + check_fix( + r#" +//- minicore: option, result +static A: Option<()> = {($0)}; + "#, + r#" +static A: Option<()> = {Some(())}; + "#, + ); + check_fix( + r#" +//- minicore: option, result +const _: Option<()> = {($0)}; + "#, + r#" +const _: Option<()> = {Some(())}; + "#, + ); + } + + #[test] + fn test_wrap_return_type_not_applicable_when_expr_type_does_not_match_ok_type() { + check_no_fix( + r#" +//- minicore: option, result +fn foo() -> Result<(), i32> { 0$0 } +"#, + ); + } + + #[test] + fn test_wrap_return_type_not_applicable_when_return_type_is_not_result_or_option() { + check_no_fix( + r#" +//- minicore: option, result +enum SomeOtherEnum { Ok(i32), Err(String) } + +fn foo() -> SomeOtherEnum { 0$0 } +"#, + ); + } } |
