about summary refs log tree commit diff
diff options
context:
space:
mode:
authorRyo Yoshida <low.ryoshida@gmail.com>2022-11-26 23:51:22 +0900
committerRyo Yoshida <low.ryoshida@gmail.com>2022-11-27 00:31:02 +0900
commit8e03f18e37d2782189391955bc56d3aebead81f5 (patch)
treed9e4bed16297ab5040cfd0366b032d8c3d2b4b39
parent822c61f559dc522dbd28f2886d20989a55613fc0 (diff)
downloadrust-8e03f18e37d2782189391955bc56d3aebead81f5.tar.gz
rust-8e03f18e37d2782189391955bc56d3aebead81f5.zip
fix: check if range contains tail expression
-rw-r--r--crates/ide-assists/src/handlers/extract_function.rs204
1 files changed, 183 insertions, 21 deletions
diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs
index 10a3a33226b..0483cfdc646 100644
--- a/crates/ide-assists/src/handlers/extract_function.rs
+++ b/crates/ide-assists/src/handlers/extract_function.rs
@@ -11,7 +11,9 @@ use ide_db::{
     helpers::mod_path_to_ast,
     imports::insert_use::{insert_use, ImportScope},
     search::{FileReference, ReferenceCategory, SearchScope},
-    syntax_helpers::node_ext::{preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr},
+    syntax_helpers::node_ext::{
+        for_each_tail_expr, preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr,
+    },
     FxIndexSet, RootDatabase,
 };
 use itertools::Itertools;
@@ -78,7 +80,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
     };
 
     let body = extraction_target(&node, range)?;
-    let container_info = body.analyze_container(&ctx.sema)?;
+    let (container_info, contains_tail_expr) = body.analyze_container(&ctx.sema)?;
 
     let (locals_used, self_param) = body.analyze(&ctx.sema);
 
@@ -119,6 +121,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
                 ret_ty,
                 body,
                 outliving_locals,
+                contains_tail_expr,
                 mods: container_info,
             };
 
@@ -245,6 +248,8 @@ struct Function {
     ret_ty: RetType,
     body: FunctionBody,
     outliving_locals: Vec<OutlivedLocal>,
+    /// Whether at least one of the container's tail expr is contained in the range we're extracting.
+    contains_tail_expr: bool,
     mods: ContainerInfo,
 }
 
@@ -294,7 +299,6 @@ struct ControlFlow {
 #[derive(Clone, Debug)]
 struct ContainerInfo {
     is_const: bool,
-    is_in_tail: bool,
     parent_loop: Option<SyntaxNode>,
     /// The function's return type, const's type etc.
     ret_type: Option<hir::Type>,
@@ -743,7 +747,10 @@ impl FunctionBody {
         (res, self_param)
     }
 
-    fn analyze_container(&self, sema: &Semantics<'_, RootDatabase>) -> Option<ContainerInfo> {
+    fn analyze_container(
+        &self,
+        sema: &Semantics<'_, RootDatabase>,
+    ) -> Option<(ContainerInfo, bool)> {
         let mut ancestors = self.parent()?.ancestors();
         let infer_expr_opt = |expr| sema.type_of_expr(&expr?).map(TypeInfo::adjusted);
         let mut parent_loop = None;
@@ -815,28 +822,36 @@ impl FunctionBody {
                 }
             };
         };
-        let container_tail = match expr? {
-            ast::Expr::BlockExpr(block) => block.tail_expr(),
-            expr => Some(expr),
-        };
-        let is_in_tail =
-            container_tail.zip(self.tail_expr()).map_or(false, |(container_tail, body_tail)| {
-                container_tail.syntax().text_range().contains_range(body_tail.syntax().text_range())
+
+        let expr = expr?;
+        let contains_tail_expr = if let Some(body_tail) = self.tail_expr() {
+            let mut contains_tail_expr = false;
+            let tail_expr_range = body_tail.syntax().text_range();
+            for_each_tail_expr(&expr, &mut |e| {
+                if tail_expr_range.contains_range(e.syntax().text_range()) {
+                    contains_tail_expr = true;
+                }
             });
+            contains_tail_expr
+        } else {
+            false
+        };
 
         let parent = self.parent()?;
         let parents = generic_parents(&parent);
         let generic_param_lists = parents.iter().filter_map(|it| it.generic_param_list()).collect();
         let where_clauses = parents.iter().filter_map(|it| it.where_clause()).collect();
 
-        Some(ContainerInfo {
-            is_in_tail,
-            is_const,
-            parent_loop,
-            ret_type: ty,
-            generic_param_lists,
-            where_clauses,
-        })
+        Some((
+            ContainerInfo {
+                is_const,
+                parent_loop,
+                ret_type: ty,
+                generic_param_lists,
+                where_clauses,
+            },
+            contains_tail_expr,
+        ))
     }
 
     fn return_ty(&self, ctx: &AssistContext<'_>) -> Option<RetType> {
@@ -1633,7 +1648,7 @@ impl Function {
 
     fn make_ret_ty(&self, ctx: &AssistContext<'_>, module: hir::Module) -> Option<ast::RetType> {
         let fun_ty = self.return_type(ctx);
-        let handler = if self.mods.is_in_tail {
+        let handler = if self.contains_tail_expr {
             FlowHandler::None
         } else {
             FlowHandler::from_ret_ty(self, &fun_ty)
@@ -1707,7 +1722,7 @@ fn make_body(
     fun: &Function,
 ) -> ast::BlockExpr {
     let ret_ty = fun.return_type(ctx);
-    let handler = if fun.mods.is_in_tail {
+    let handler = if fun.contains_tail_expr {
         FlowHandler::None
     } else {
         FlowHandler::from_ret_ty(fun, &ret_ty)
@@ -5585,4 +5600,151 @@ fn $0fun_name<T, V>(t: T, v: V) -> i32 where T: Into<i32> + Copy, V: Into<i32> {
 "#,
         );
     }
+
+    #[test]
+    fn non_tail_expr_of_tail_expr_loop() {
+        check_assist(
+            extract_function,
+            r#"
+pub fn f() {
+    loop {
+        $0if true {
+            continue;
+        }$0
+
+        if false {
+            break;
+        }
+    }
+}
+"#,
+            r#"
+pub fn f() {
+    loop {
+        if let ControlFlow::Break(_) = fun_name() {
+            continue;
+        }
+
+        if false {
+            break;
+        }
+    }
+}
+
+fn $0fun_name() -> ControlFlow<()> {
+    if true {
+        return ControlFlow::Break(());
+    }
+    ControlFlow::Continue(())
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn non_tail_expr_of_tail_if_block() {
+        // FIXME: double semicolon
+        check_assist(
+            extract_function,
+            r#"
+//- minicore: option, try
+impl<T> core::ops::Try for Option<T> {
+    type Output = T;
+    type Residual = Option<!>;
+}
+impl<T> core::ops::FromResidual for Option<T> {}
+
+fn f() -> Option<()> {
+    if true {
+        let a = $0if true {
+            Some(())?
+        } else {
+            ()
+        }$0;
+        Some(a)
+    } else {
+        None
+    }
+}
+"#,
+            r#"
+impl<T> core::ops::Try for Option<T> {
+    type Output = T;
+    type Residual = Option<!>;
+}
+impl<T> core::ops::FromResidual for Option<T> {}
+
+fn f() -> Option<()> {
+    if true {
+        let a = fun_name()?;;
+        Some(a)
+    } else {
+        None
+    }
+}
+
+fn $0fun_name() -> Option<()> {
+    Some(if true {
+        Some(())?
+    } else {
+        ()
+    })
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn tail_expr_of_tail_block_nested() {
+        check_assist(
+            extract_function,
+            r#"
+//- minicore: option, try
+impl<T> core::ops::Try for Option<T> {
+    type Output = T;
+    type Residual = Option<!>;
+}
+impl<T> core::ops::FromResidual for Option<T> {}
+
+fn f() -> Option<()> {
+    if true {
+        $0{
+            let a = if true {
+                Some(())?
+            } else {
+                ()
+            };
+            Some(a)
+        }$0
+    } else {
+        None
+    }
+}
+"#,
+            r#"
+impl<T> core::ops::Try for Option<T> {
+    type Output = T;
+    type Residual = Option<!>;
+}
+impl<T> core::ops::FromResidual for Option<T> {}
+
+fn f() -> Option<()> {
+    if true {
+        fun_name()?
+    } else {
+        None
+    }
+}
+
+fn $0fun_name() -> Option<()> {
+    let a = if true {
+        Some(())?
+    } else {
+        ()
+    };
+    Some(a)
+}
+"#,
+        );
+    }
 }