about summary refs log tree commit diff
path: root/compiler/rustc_infer/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_infer/src')
-rw-r--r--compiler/rustc_infer/src/infer/error_reporting/mod.rs21
-rw-r--r--compiler/rustc_infer/src/infer/error_reporting/nice_region_error/find_anon_type.rs69
-rw-r--r--compiler/rustc_infer/src/infer/error_reporting/suggest.rs88
-rw-r--r--compiler/rustc_infer/src/lib.rs1
4 files changed, 85 insertions, 94 deletions
diff --git a/compiler/rustc_infer/src/infer/error_reporting/mod.rs b/compiler/rustc_infer/src/infer/error_reporting/mod.rs
index 362ca3b4833..ea5c6b8c057 100644
--- a/compiler/rustc_infer/src/infer/error_reporting/mod.rs
+++ b/compiler/rustc_infer/src/infer/error_reporting/mod.rs
@@ -79,7 +79,7 @@ use rustc_middle::ty::{
 use rustc_span::{sym, symbol::kw, BytePos, DesugaringKind, Pos, Span};
 use rustc_target::spec::abi;
 use std::borrow::Cow;
-use std::ops::Deref;
+use std::ops::{ControlFlow, Deref};
 use std::path::PathBuf;
 use std::{cmp, fmt, iter};
 
@@ -2129,15 +2129,12 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {
         let tykind = match self.tcx.opt_hir_node_by_def_id(trace.cause.body_id) {
             Some(hir::Node::Item(hir::Item { kind: hir::ItemKind::Fn(_, _, body_id), .. })) => {
                 let body = hir.body(*body_id);
-                struct LetVisitor<'v> {
+                struct LetVisitor {
                     span: Span,
-                    result: Option<&'v hir::Ty<'v>>,
                 }
-                impl<'v> Visitor<'v> for LetVisitor<'v> {
-                    fn visit_stmt(&mut self, s: &'v hir::Stmt<'v>) {
-                        if self.result.is_some() {
-                            return;
-                        }
+                impl<'v> Visitor<'v> for LetVisitor {
+                    type Result = ControlFlow<&'v hir::TyKind<'v>>;
+                    fn visit_stmt(&mut self, s: &'v hir::Stmt<'v>) -> Self::Result {
                         // Find a local statement where the initializer has
                         // the same span as the error and the type is specified.
                         if let hir::Stmt {
@@ -2151,13 +2148,13 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {
                         } = s
                             && init_span == &self.span
                         {
-                            self.result = Some(*array_ty);
+                            ControlFlow::Break(&array_ty.peel_refs().kind)
+                        } else {
+                            ControlFlow::Continue(())
                         }
                     }
                 }
-                let mut visitor = LetVisitor { span, result: None };
-                visitor.visit_body(body);
-                visitor.result.map(|r| &r.peel_refs().kind)
+                LetVisitor { span }.visit_body(body).break_value()
             }
             Some(hir::Node::Item(hir::Item { kind: hir::ItemKind::Const(ty, _, _), .. })) => {
                 Some(&ty.peel_refs().kind)
diff --git a/compiler/rustc_infer/src/infer/error_reporting/nice_region_error/find_anon_type.rs b/compiler/rustc_infer/src/infer/error_reporting/nice_region_error/find_anon_type.rs
index 4f74365d06c..265a315a559 100644
--- a/compiler/rustc_infer/src/infer/error_reporting/nice_region_error/find_anon_type.rs
+++ b/compiler/rustc_infer/src/infer/error_reporting/nice_region_error/find_anon_type.rs
@@ -1,3 +1,4 @@
+use core::ops::ControlFlow;
 use rustc_hir as hir;
 use rustc_hir::intravisit::{self, Visitor};
 use rustc_middle::hir::map::Map;
@@ -43,14 +44,9 @@ fn find_component_for_bound_region<'tcx>(
     arg: &'tcx hir::Ty<'tcx>,
     br: &ty::BoundRegionKind,
 ) -> Option<&'tcx hir::Ty<'tcx>> {
-    let mut nested_visitor = FindNestedTypeVisitor {
-        tcx,
-        bound_region: *br,
-        found_type: None,
-        current_index: ty::INNERMOST,
-    };
-    nested_visitor.visit_ty(arg);
-    nested_visitor.found_type
+    FindNestedTypeVisitor { tcx, bound_region: *br, current_index: ty::INNERMOST }
+        .visit_ty(arg)
+        .break_value()
 }
 
 // The FindNestedTypeVisitor captures the corresponding `hir::Ty` of the
@@ -65,26 +61,24 @@ struct FindNestedTypeVisitor<'tcx> {
     // The bound_region corresponding to the Refree(freeregion)
     // associated with the anonymous region we are looking for.
     bound_region: ty::BoundRegionKind,
-    // The type where the anonymous lifetime appears
-    // for e.g., Vec<`&u8`> and <`&u8`>
-    found_type: Option<&'tcx hir::Ty<'tcx>>,
     current_index: ty::DebruijnIndex,
 }
 
 impl<'tcx> Visitor<'tcx> for FindNestedTypeVisitor<'tcx> {
+    type Result = ControlFlow<&'tcx hir::Ty<'tcx>>;
     type NestedFilter = nested_filter::OnlyBodies;
 
     fn nested_visit_map(&mut self) -> Self::Map {
         self.tcx.hir()
     }
 
-    fn visit_ty(&mut self, arg: &'tcx hir::Ty<'tcx>) {
+    fn visit_ty(&mut self, arg: &'tcx hir::Ty<'tcx>) -> Self::Result {
         match arg.kind {
             hir::TyKind::BareFn(_) => {
                 self.current_index.shift_in(1);
                 intravisit::walk_ty(self, arg);
                 self.current_index.shift_out(1);
-                return;
+                return ControlFlow::Continue(());
             }
 
             hir::TyKind::TraitObject(bounds, ..) => {
@@ -105,8 +99,7 @@ impl<'tcx> Visitor<'tcx> for FindNestedTypeVisitor<'tcx> {
                     (Some(rbv::ResolvedArg::EarlyBound(id)), ty::BrNamed(def_id, _)) => {
                         debug!("EarlyBound id={:?} def_id={:?}", id, def_id);
                         if id == def_id {
-                            self.found_type = Some(arg);
-                            return; // we can stop visiting now
+                            return ControlFlow::Break(arg);
                         }
                     }
 
@@ -123,8 +116,7 @@ impl<'tcx> Visitor<'tcx> for FindNestedTypeVisitor<'tcx> {
                         );
                         debug!("LateBound id={:?} def_id={:?}", id, def_id);
                         if debruijn_index == self.current_index && id == def_id {
-                            self.found_type = Some(arg);
-                            return; // we can stop visiting now
+                            return ControlFlow::Break(arg);
                         }
                     }
 
@@ -145,23 +137,30 @@ impl<'tcx> Visitor<'tcx> for FindNestedTypeVisitor<'tcx> {
             }
             // Checks if it is of type `hir::TyKind::Path` which corresponds to a struct.
             hir::TyKind::Path(_) => {
-                let subvisitor = &mut TyPathVisitor {
-                    tcx: self.tcx,
-                    found_it: false,
-                    bound_region: self.bound_region,
-                    current_index: self.current_index,
+                // Prefer using the lifetime in type arguments rather than lifetime arguments.
+                intravisit::walk_ty(self, arg)?;
+
+                // Call `walk_ty` as `visit_ty` is empty.
+                return if intravisit::walk_ty(
+                    &mut TyPathVisitor {
+                        tcx: self.tcx,
+                        bound_region: self.bound_region,
+                        current_index: self.current_index,
+                    },
+                    arg,
+                )
+                .is_break()
+                {
+                    ControlFlow::Break(arg)
+                } else {
+                    ControlFlow::Continue(())
                 };
-                intravisit::walk_ty(subvisitor, arg); // call walk_ty; as visit_ty is empty,
-                // this will visit only outermost type
-                if subvisitor.found_it {
-                    self.found_type = Some(arg);
-                }
             }
             _ => {}
         }
         // walk the embedded contents: e.g., if we are visiting `Vec<&Foo>`,
         // go on to visit `&Foo`
-        intravisit::walk_ty(self, arg);
+        intravisit::walk_ty(self, arg)
     }
 }
 
@@ -173,26 +172,25 @@ impl<'tcx> Visitor<'tcx> for FindNestedTypeVisitor<'tcx> {
 // specific part of the type in the error message.
 struct TyPathVisitor<'tcx> {
     tcx: TyCtxt<'tcx>,
-    found_it: bool,
     bound_region: ty::BoundRegionKind,
     current_index: ty::DebruijnIndex,
 }
 
 impl<'tcx> Visitor<'tcx> for TyPathVisitor<'tcx> {
+    type Result = ControlFlow<()>;
     type NestedFilter = nested_filter::OnlyBodies;
 
     fn nested_visit_map(&mut self) -> Map<'tcx> {
         self.tcx.hir()
     }
 
-    fn visit_lifetime(&mut self, lifetime: &hir::Lifetime) {
+    fn visit_lifetime(&mut self, lifetime: &hir::Lifetime) -> Self::Result {
         match (self.tcx.named_bound_var(lifetime.hir_id), self.bound_region) {
             // the lifetime of the TyPath!
             (Some(rbv::ResolvedArg::EarlyBound(id)), ty::BrNamed(def_id, _)) => {
                 debug!("EarlyBound id={:?} def_id={:?}", id, def_id);
                 if id == def_id {
-                    self.found_it = true;
-                    return; // we can stop visiting now
+                    return ControlFlow::Break(());
                 }
             }
 
@@ -201,8 +199,7 @@ impl<'tcx> Visitor<'tcx> for TyPathVisitor<'tcx> {
                 debug!("id={:?}", id);
                 debug!("def_id={:?}", def_id);
                 if debruijn_index == self.current_index && id == def_id {
-                    self.found_it = true;
-                    return; // we can stop visiting now
+                    return ControlFlow::Break(());
                 }
             }
 
@@ -220,9 +217,10 @@ impl<'tcx> Visitor<'tcx> for TyPathVisitor<'tcx> {
                 debug!("no arg found");
             }
         }
+        ControlFlow::Continue(())
     }
 
-    fn visit_ty(&mut self, arg: &'tcx hir::Ty<'tcx>) {
+    fn visit_ty(&mut self, arg: &'tcx hir::Ty<'tcx>) -> Self::Result {
         // ignore nested types
         //
         // If you have a type like `Foo<'a, &Ty>` we
@@ -231,5 +229,6 @@ impl<'tcx> Visitor<'tcx> for TyPathVisitor<'tcx> {
         // Making `visit_ty` empty will ignore the `&Ty` embedded
         // inside, it will get reached by the outer visitor.
         debug!("`Ty` corresponding to a struct is {:?}", arg);
+        ControlFlow::Continue(())
     }
 }
diff --git a/compiler/rustc_infer/src/infer/error_reporting/suggest.rs b/compiler/rustc_infer/src/infer/error_reporting/suggest.rs
index 472dab639d5..8cdf39b1739 100644
--- a/compiler/rustc_infer/src/infer/error_reporting/suggest.rs
+++ b/compiler/rustc_infer/src/infer/error_reporting/suggest.rs
@@ -1,4 +1,5 @@
 use crate::infer::error_reporting::hir::Path;
+use core::ops::ControlFlow;
 use hir::def::CtorKind;
 use hir::intravisit::{walk_expr, walk_stmt, Visitor};
 use hir::{Local, QPath};
@@ -563,62 +564,55 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {
         cause: &ObligationCause<'_>,
         span: Span,
     ) -> Option<TypeErrorAdditionalDiags> {
-        let hir = self.tcx.hir();
-        if let Some(body_id) = self.tcx.hir().maybe_body_owned_by(cause.body_id) {
-            let body = hir.body(body_id);
-
-            /// Find the if expression with given span
-            struct IfVisitor {
-                pub result: bool,
-                pub found_if: bool,
-                pub err_span: Span,
-            }
-
-            impl<'v> Visitor<'v> for IfVisitor {
-                fn visit_expr(&mut self, ex: &'v hir::Expr<'v>) {
-                    if self.result {
-                        return;
-                    }
-                    match ex.kind {
-                        hir::ExprKind::If(cond, _, _) => {
-                            self.found_if = true;
-                            walk_expr(self, cond);
-                            self.found_if = false;
-                        }
-                        _ => walk_expr(self, ex),
-                    }
-                }
+        /// Find the if expression with given span
+        struct IfVisitor {
+            pub found_if: bool,
+            pub err_span: Span,
+        }
 
-                fn visit_stmt(&mut self, ex: &'v hir::Stmt<'v>) {
-                    if let hir::StmtKind::Local(hir::Local {
-                        span,
-                        pat: hir::Pat { .. },
-                        ty: None,
-                        init: Some(_),
-                        ..
-                    }) = &ex.kind
-                        && self.found_if
-                        && span.eq(&self.err_span)
-                    {
-                        self.result = true;
+        impl<'v> Visitor<'v> for IfVisitor {
+            type Result = ControlFlow<()>;
+            fn visit_expr(&mut self, ex: &'v hir::Expr<'v>) -> Self::Result {
+                match ex.kind {
+                    hir::ExprKind::If(cond, _, _) => {
+                        self.found_if = true;
+                        walk_expr(self, cond)?;
+                        self.found_if = false;
+                        ControlFlow::Continue(())
                     }
-                    walk_stmt(self, ex);
+                    _ => walk_expr(self, ex),
                 }
+            }
 
-                fn visit_body(&mut self, body: &'v hir::Body<'v>) {
-                    hir::intravisit::walk_body(self, body);
+            fn visit_stmt(&mut self, ex: &'v hir::Stmt<'v>) -> Self::Result {
+                if let hir::StmtKind::Local(hir::Local {
+                    span,
+                    pat: hir::Pat { .. },
+                    ty: None,
+                    init: Some(_),
+                    ..
+                }) = &ex.kind
+                    && self.found_if
+                    && span.eq(&self.err_span)
+                {
+                    ControlFlow::Break(())
+                } else {
+                    walk_stmt(self, ex)
                 }
             }
 
-            let mut visitor = IfVisitor { err_span: span, found_if: false, result: false };
-            visitor.visit_body(body);
-            if visitor.result {
-                return Some(TypeErrorAdditionalDiags::AddLetForLetChains {
-                    span: span.shrink_to_lo(),
-                });
+            fn visit_body(&mut self, body: &'v hir::Body<'v>) -> Self::Result {
+                hir::intravisit::walk_body(self, body)
             }
         }
-        None
+
+        self.tcx.hir().maybe_body_owned_by(cause.body_id).and_then(|body_id| {
+            let body = self.tcx.hir().body(body_id);
+            IfVisitor { err_span: span, found_if: false }
+                .visit_body(body)
+                .is_break()
+                .then(|| TypeErrorAdditionalDiags::AddLetForLetChains { span: span.shrink_to_lo() })
+        })
     }
 
     /// For "one type is more general than the other" errors on closures, suggest changing the lifetime
diff --git a/compiler/rustc_infer/src/lib.rs b/compiler/rustc_infer/src/lib.rs
index 97f9a4b291d..029bddda1e1 100644
--- a/compiler/rustc_infer/src/lib.rs
+++ b/compiler/rustc_infer/src/lib.rs
@@ -20,6 +20,7 @@
 #![allow(rustc::untranslatable_diagnostic)]
 #![feature(associated_type_bounds)]
 #![feature(box_patterns)]
+#![feature(control_flow_enum)]
 #![feature(extend_one)]
 #![feature(let_chains)]
 #![feature(if_let_guard)]