about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2022-11-27 12:18:42 +0000
committerbors <bors@rust-lang.org>2022-11-27 12:18:42 +0000
commit6d61be8e65ac0fd45eaf178e1f7a1ec6b582de1f (patch)
treebfb92430118ee60900e620f3f0742298e34e2d2a
parent34e2bc6a541cca670307cec94cfc5546016705d6 (diff)
parent8e03f18e37d2782189391955bc56d3aebead81f5 (diff)
downloadrust-6d61be8e65ac0fd45eaf178e1f7a1ec6b582de1f.tar.gz
rust-6d61be8e65ac0fd45eaf178e1f7a1ec6b582de1f.zip
Auto merge of #13681 - lowr:fix/extract-function-tail-expr, r=Veykril
fix: check tail expressions more precisely in `extract_function`

Fixes #13620

When extracting expressions with control flows into a function, we can avoid wrapping tail expressions in `Option` or `Result` when they are also tail expressions of the container we're extracting from (see #7840, #9773). This is controlled by `ContainerInfo::is_in_tail`, but we've been computing it by checking if the tail expression of the range to extract is contained in the container's syntactically last expression, which may be a block that contains both tail and non-tail expressions (e.g. in #13620, the range to be extracted is not a tail expression but we set the flag to true).

This PR tries to compute the flag as precise as possible by utilizing `for_each_tail_expr()` (and also moves the flag to `Function` struct as it's more of a property of the function to be extracted than of the container).
-rw-r--r--crates/ide-assists/src/handlers/extract_function.rs210
1 files changed, 186 insertions, 24 deletions
diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs
index c1e2f19ab18..0483cfdc646 100644
--- a/crates/ide-assists/src/handlers/extract_function.rs
+++ b/crates/ide-assists/src/handlers/extract_function.rs
@@ -11,7 +11,9 @@ use ide_db::{
     helpers::mod_path_to_ast,
     imports::insert_use::{insert_use, ImportScope},
     search::{FileReference, ReferenceCategory, SearchScope},
-    syntax_helpers::node_ext::{preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr},
+    syntax_helpers::node_ext::{
+        for_each_tail_expr, preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr,
+    },
     FxIndexSet, RootDatabase,
 };
 use itertools::Itertools;
@@ -78,7 +80,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
     };
 
     let body = extraction_target(&node, range)?;
-    let container_info = body.analyze_container(&ctx.sema)?;
+    let (container_info, contains_tail_expr) = body.analyze_container(&ctx.sema)?;
 
     let (locals_used, self_param) = body.analyze(&ctx.sema);
 
@@ -119,6 +121,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
                 ret_ty,
                 body,
                 outliving_locals,
+                contains_tail_expr,
                 mods: container_info,
             };
 
@@ -245,6 +248,8 @@ struct Function {
     ret_ty: RetType,
     body: FunctionBody,
     outliving_locals: Vec<OutlivedLocal>,
+    /// Whether at least one of the container's tail expr is contained in the range we're extracting.
+    contains_tail_expr: bool,
     mods: ContainerInfo,
 }
 
@@ -265,7 +270,7 @@ enum ParamKind {
     MutRef,
 }
 
-#[derive(Debug, Eq, PartialEq)]
+#[derive(Debug)]
 enum FunType {
     Unit,
     Single(hir::Type),
@@ -294,7 +299,6 @@ struct ControlFlow {
 #[derive(Clone, Debug)]
 struct ContainerInfo {
     is_const: bool,
-    is_in_tail: bool,
     parent_loop: Option<SyntaxNode>,
     /// The function's return type, const's type etc.
     ret_type: Option<hir::Type>,
@@ -743,7 +747,10 @@ impl FunctionBody {
         (res, self_param)
     }
 
-    fn analyze_container(&self, sema: &Semantics<'_, RootDatabase>) -> Option<ContainerInfo> {
+    fn analyze_container(
+        &self,
+        sema: &Semantics<'_, RootDatabase>,
+    ) -> Option<(ContainerInfo, bool)> {
         let mut ancestors = self.parent()?.ancestors();
         let infer_expr_opt = |expr| sema.type_of_expr(&expr?).map(TypeInfo::adjusted);
         let mut parent_loop = None;
@@ -815,28 +822,36 @@ impl FunctionBody {
                 }
             };
         };
-        let container_tail = match expr? {
-            ast::Expr::BlockExpr(block) => block.tail_expr(),
-            expr => Some(expr),
-        };
-        let is_in_tail =
-            container_tail.zip(self.tail_expr()).map_or(false, |(container_tail, body_tail)| {
-                container_tail.syntax().text_range().contains_range(body_tail.syntax().text_range())
+
+        let expr = expr?;
+        let contains_tail_expr = if let Some(body_tail) = self.tail_expr() {
+            let mut contains_tail_expr = false;
+            let tail_expr_range = body_tail.syntax().text_range();
+            for_each_tail_expr(&expr, &mut |e| {
+                if tail_expr_range.contains_range(e.syntax().text_range()) {
+                    contains_tail_expr = true;
+                }
             });
+            contains_tail_expr
+        } else {
+            false
+        };
 
         let parent = self.parent()?;
         let parents = generic_parents(&parent);
         let generic_param_lists = parents.iter().filter_map(|it| it.generic_param_list()).collect();
         let where_clauses = parents.iter().filter_map(|it| it.where_clause()).collect();
 
-        Some(ContainerInfo {
-            is_in_tail,
-            is_const,
-            parent_loop,
-            ret_type: ty,
-            generic_param_lists,
-            where_clauses,
-        })
+        Some((
+            ContainerInfo {
+                is_const,
+                parent_loop,
+                ret_type: ty,
+                generic_param_lists,
+                where_clauses,
+            },
+            contains_tail_expr,
+        ))
     }
 
     fn return_ty(&self, ctx: &AssistContext<'_>) -> Option<RetType> {
@@ -1368,7 +1383,7 @@ impl FlowHandler {
             None => FlowHandler::None,
             Some(flow_kind) => {
                 let action = flow_kind.clone();
-                if *ret_ty == FunType::Unit {
+                if let FunType::Unit = ret_ty {
                     match flow_kind {
                         FlowKind::Return(None)
                         | FlowKind::Break(_, None)
@@ -1633,7 +1648,7 @@ impl Function {
 
     fn make_ret_ty(&self, ctx: &AssistContext<'_>, module: hir::Module) -> Option<ast::RetType> {
         let fun_ty = self.return_type(ctx);
-        let handler = if self.mods.is_in_tail {
+        let handler = if self.contains_tail_expr {
             FlowHandler::None
         } else {
             FlowHandler::from_ret_ty(self, &fun_ty)
@@ -1707,7 +1722,7 @@ fn make_body(
     fun: &Function,
 ) -> ast::BlockExpr {
     let ret_ty = fun.return_type(ctx);
-    let handler = if fun.mods.is_in_tail {
+    let handler = if fun.contains_tail_expr {
         FlowHandler::None
     } else {
         FlowHandler::from_ret_ty(fun, &ret_ty)
@@ -1946,7 +1961,7 @@ fn update_external_control_flow(handler: &FlowHandler, syntax: &SyntaxNode) {
                 if nested_scope.is_none() {
                     if let Some(expr) = ast::Expr::cast(e.clone()) {
                         match expr {
-                            ast::Expr::ReturnExpr(return_expr) if nested_scope.is_none() => {
+                            ast::Expr::ReturnExpr(return_expr) => {
                                 let expr = return_expr.expr();
                                 if let Some(replacement) = make_rewritten_flow(handler, expr) {
                                     ted::replace(return_expr.syntax(), replacement.syntax())
@@ -5585,4 +5600,151 @@ fn $0fun_name<T, V>(t: T, v: V) -> i32 where T: Into<i32> + Copy, V: Into<i32> {
 "#,
         );
     }
+
+    #[test]
+    fn non_tail_expr_of_tail_expr_loop() {
+        check_assist(
+            extract_function,
+            r#"
+pub fn f() {
+    loop {
+        $0if true {
+            continue;
+        }$0
+
+        if false {
+            break;
+        }
+    }
+}
+"#,
+            r#"
+pub fn f() {
+    loop {
+        if let ControlFlow::Break(_) = fun_name() {
+            continue;
+        }
+
+        if false {
+            break;
+        }
+    }
+}
+
+fn $0fun_name() -> ControlFlow<()> {
+    if true {
+        return ControlFlow::Break(());
+    }
+    ControlFlow::Continue(())
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn non_tail_expr_of_tail_if_block() {
+        // FIXME: double semicolon
+        check_assist(
+            extract_function,
+            r#"
+//- minicore: option, try
+impl<T> core::ops::Try for Option<T> {
+    type Output = T;
+    type Residual = Option<!>;
+}
+impl<T> core::ops::FromResidual for Option<T> {}
+
+fn f() -> Option<()> {
+    if true {
+        let a = $0if true {
+            Some(())?
+        } else {
+            ()
+        }$0;
+        Some(a)
+    } else {
+        None
+    }
+}
+"#,
+            r#"
+impl<T> core::ops::Try for Option<T> {
+    type Output = T;
+    type Residual = Option<!>;
+}
+impl<T> core::ops::FromResidual for Option<T> {}
+
+fn f() -> Option<()> {
+    if true {
+        let a = fun_name()?;;
+        Some(a)
+    } else {
+        None
+    }
+}
+
+fn $0fun_name() -> Option<()> {
+    Some(if true {
+        Some(())?
+    } else {
+        ()
+    })
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn tail_expr_of_tail_block_nested() {
+        check_assist(
+            extract_function,
+            r#"
+//- minicore: option, try
+impl<T> core::ops::Try for Option<T> {
+    type Output = T;
+    type Residual = Option<!>;
+}
+impl<T> core::ops::FromResidual for Option<T> {}
+
+fn f() -> Option<()> {
+    if true {
+        $0{
+            let a = if true {
+                Some(())?
+            } else {
+                ()
+            };
+            Some(a)
+        }$0
+    } else {
+        None
+    }
+}
+"#,
+            r#"
+impl<T> core::ops::Try for Option<T> {
+    type Output = T;
+    type Residual = Option<!>;
+}
+impl<T> core::ops::FromResidual for Option<T> {}
+
+fn f() -> Option<()> {
+    if true {
+        fun_name()?
+    } else {
+        None
+    }
+}
+
+fn $0fun_name() -> Option<()> {
+    let a = if true {
+        Some(())?
+    } else {
+        ()
+    };
+    Some(a)
+}
+"#,
+        );
+    }
 }