about summary refs log tree commit diff
diff options
context:
space:
mode:
authorCamille GILLOT <gillot.camille@gmail.com>2022-04-20 20:13:42 +0200
committerCamille GILLOT <gillot.camille@gmail.com>2022-04-27 22:03:12 +0200
commit6857a8d14e1eeddda1e575f0d124ba2e6dbae0eb (patch)
tree1a448ab443783b11f8513cb3ffe0eaffb2d3aa60
parenta621b8499f61edfda48a87b993dde07e34d67998 (diff)
downloadrust-6857a8d14e1eeddda1e575f0d124ba2e6dbae0eb.tar.gz
rust-6857a8d14e1eeddda1e575f0d124ba2e6dbae0eb.zip
Create a specific struct for lifetime capture.
-rw-r--r--compiler/rustc_ast_lowering/src/lib.rs259
-rw-r--r--compiler/rustc_ast_lowering/src/path.rs52
2 files changed, 172 insertions, 139 deletions
diff --git a/compiler/rustc_ast_lowering/src/lib.rs b/compiler/rustc_ast_lowering/src/lib.rs
index bf1843a00ec..096cacff910 100644
--- a/compiler/rustc_ast_lowering/src/lib.rs
+++ b/compiler/rustc_ast_lowering/src/lib.rs
@@ -32,6 +32,7 @@
 
 #![feature(crate_visibility_modifier)]
 #![feature(box_patterns)]
+#![feature(let_chains)]
 #![feature(let_else)]
 #![feature(never_type)]
 #![recursion_limit = "256"]
@@ -135,24 +136,8 @@ struct LoweringContext<'a, 'hir: 'a> {
     /// Currently in-scope lifetimes defined in impl headers, fn headers, or HRTB.
     in_scope_lifetimes: Vec<(ParamName, LocalDefId)>,
 
-    /// Used to handle lifetimes appearing in impl-traits.  When we lower a lifetime,
-    /// it is inserted in the `FxHashMap`, and the resolution is modified so to point
-    /// to the lifetime parameter impl-trait will generate.
-    /// When traversing `for<...>` binders, they are inserted in the `FxHashSet` so
-    /// we know *not* to rebind the introduced lifetimes.
-    captured_lifetimes: Option<(
-        LocalDefId, // parent def_id for new definitions
-        FxHashMap<
-            LocalDefId, // original parameter id
-            (
-                Span,        // Span
-                NodeId,      // synthetized parameter id
-                ParamName,   // parameter name
-                LifetimeRes, // original resolution
-            ),
-        >,
-        FxHashSet<NodeId>, // traversed binders, to ignore
-    )>,
+    /// Used to handle lifetimes appearing in impl-traits.
+    captured_lifetimes: Option<LifetimeCaptureContext>,
 
     current_hir_id_owner: LocalDefId,
     item_local_id_counter: hir::ItemLocalId,
@@ -179,6 +164,9 @@ pub enum LifetimeRes {
         /// - a TraitRef's ref_id, identifying the `for<...>` binder;
         /// - a BareFn type's id;
         /// - a Path's id when this path has parenthesized generic args.
+        ///
+        /// This information is used for impl-trait lifetime captures, to know when to or not to
+        /// capture any given lifetime.
         binder: NodeId,
     },
     /// Created a generic parameter for an anonymous lifetime.
@@ -206,6 +194,28 @@ pub enum LifetimeRes {
     ElidedAnchor { start: NodeId, end: NodeId },
 }
 
+/// When we lower a lifetime, it is inserted in `captures`, and the resolution is modified so
+/// to point to the lifetime parameter impl-trait will generate.
+/// When traversing `for<...>` binders, they are inserted in `binders_to_ignore` so we know *not*
+/// to rebind the introduced lifetimes.
+#[derive(Debug)]
+struct LifetimeCaptureContext {
+    /// parent def_id for new definitions
+    parent_def_id: LocalDefId,
+    /// Set of lifetimes to rebind.
+    captures: FxHashMap<
+        LocalDefId, // original parameter id
+        (
+            Span,        // Span
+            NodeId,      // synthetized parameter id
+            ParamName,   // parameter name
+            LifetimeRes, // original resolution
+        ),
+    >,
+    /// Traversed binders.  The ids in this set should *not* be rebound.
+    binders_to_ignore: FxHashSet<NodeId>,
+}
+
 pub trait ResolverAstLowering {
     fn def_key(&self, id: DefId) -> DefKey;
 
@@ -790,6 +800,45 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
         (lowered_generics, res)
     }
 
+    /// Setup lifetime capture for and impl-trait.
+    /// The captures will be added to `captures`.
+    fn while_capturing_lifetimes<T>(
+        &mut self,
+        parent_def_id: LocalDefId,
+        captures: &mut FxHashMap<LocalDefId, (Span, NodeId, ParamName, LifetimeRes)>,
+        f: impl FnOnce(&mut Self) -> T,
+    ) -> T {
+        let lifetime_stash = std::mem::replace(
+            &mut self.captured_lifetimes,
+            Some(LifetimeCaptureContext {
+                parent_def_id,
+                captures: std::mem::take(captures),
+                binders_to_ignore: Default::default(),
+            }),
+        );
+
+        let ret = f(self);
+
+        let ctxt = std::mem::replace(&mut self.captured_lifetimes, lifetime_stash).unwrap();
+        *captures = ctxt.captures;
+
+        ret
+    }
+
+    /// Register a binder to be ignored for lifetime capture.
+    #[tracing::instrument(level = "debug", skip(self, f))]
+    #[inline]
+    fn with_lifetime_binder<T>(&mut self, binder: NodeId, f: impl FnOnce(&mut Self) -> T) -> T {
+        if let Some(ctxt) = &mut self.captured_lifetimes {
+            ctxt.binders_to_ignore.insert(binder);
+        }
+        let ret = f(self);
+        if let Some(ctxt) = &mut self.captured_lifetimes {
+            ctxt.binders_to_ignore.remove(&binder);
+        }
+        ret
+    }
+
     fn with_dyn_type_scope<T>(&mut self, in_scope: bool, f: impl FnOnce(&mut Self) -> T) -> T {
         let was_in_dyn_type = self.is_in_dyn_type;
         self.is_in_dyn_type = in_scope;
@@ -1197,25 +1246,18 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
                 hir::TyKind::Rptr(lifetime, self.lower_mt(mt, itctx))
             }
             TyKind::BareFn(ref f) => self.with_in_scope_lifetime_defs(&f.generic_params, |this| {
-                if let Some((_, _, binders)) = &mut this.captured_lifetimes {
-                    binders.insert(t.id);
-                }
-
-                let ret = hir::TyKind::BareFn(this.arena.alloc(hir::BareFnTy {
-                    generic_params: this.lower_generic_params(
-                        &f.generic_params,
-                        ImplTraitContext::Disallowed(ImplTraitPosition::Generic),
-                    ),
-                    unsafety: this.lower_unsafety(f.unsafety),
-                    abi: this.lower_extern(f.ext),
-                    decl: this.lower_fn_decl(&f.decl, None, FnDeclKind::Pointer, None),
-                    param_names: this.lower_fn_params_to_names(&f.decl),
-                }));
-
-                if let Some((_, _, binders)) = &mut this.captured_lifetimes {
-                    binders.remove(&t.id);
-                }
-                ret
+                this.with_lifetime_binder(t.id, |this| {
+                    hir::TyKind::BareFn(this.arena.alloc(hir::BareFnTy {
+                        generic_params: this.lower_generic_params(
+                            &f.generic_params,
+                            ImplTraitContext::Disallowed(ImplTraitPosition::Generic),
+                        ),
+                        unsafety: this.lower_unsafety(f.unsafety),
+                        abi: this.lower_extern(f.ext),
+                        decl: this.lower_fn_decl(&f.decl, None, FnDeclKind::Pointer, None),
+                        param_names: this.lower_fn_params_to_names(&f.decl),
+                    }))
+                })
             }),
             TyKind::Never => hir::TyKind::Never,
             TyKind::Tup(ref tys) => {
@@ -1366,15 +1408,15 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
 
         let mut collected_lifetimes = FxHashMap::default();
         self.with_hir_id_owner(opaque_ty_node_id, |lctx| {
-            let capture_framework = if origin == hir::OpaqueTyOrigin::TyAlias {
-                None
+            let hir_bounds = if origin == hir::OpaqueTyOrigin::TyAlias {
+                lower_bounds(lctx)
             } else {
-                Some((opaque_ty_def_id, FxHashMap::default(), FxHashSet::default()))
+                lctx.while_capturing_lifetimes(
+                    opaque_ty_def_id,
+                    &mut collected_lifetimes,
+                    lower_bounds,
+                )
             };
-            let lifetime_stash = std::mem::replace(&mut lctx.captured_lifetimes, capture_framework);
-            let hir_bounds = lower_bounds(lctx);
-            collected_lifetimes = std::mem::replace(&mut lctx.captured_lifetimes, lifetime_stash)
-                .map_or_else(FxHashMap::default, |c| c.1);
             debug!(?collected_lifetimes);
 
             let lifetime_defs = lctx.arena.alloc_from_iter(collected_lifetimes.iter().map(
@@ -1716,24 +1758,20 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
         debug!(?captures);
 
         self.with_hir_id_owner(opaque_ty_node_id, |this| {
-            let lifetime_stash = std::mem::replace(
-                &mut this.captured_lifetimes,
-                Some((opaque_ty_def_id, std::mem::take(&mut captures), FxHashSet::default())),
-            );
             debug!("lower_async_fn_ret_ty: lifetimes_to_define={:#?}", this.lifetimes_to_define);
-
-            // We have to be careful to get elision right here. The
-            // idea is that we create a lifetime parameter for each
-            // lifetime in the return type.  So, given a return type
-            // like `async fn foo(..) -> &[&u32]`, we lower to `impl
-            // Future<Output = &'1 [ &'2 u32 ]>`.
-            //
-            // Then, we will create `fn foo(..) -> Foo<'_, '_>`, and
-            // hence the elision takes place at the fn site.
             let future_bound =
-                this.lower_async_fn_output_type_to_future_bound(output, fn_def_id, span);
+                this.while_capturing_lifetimes(opaque_ty_def_id, &mut captures, |this| {
+                    // We have to be careful to get elision right here. The
+                    // idea is that we create a lifetime parameter for each
+                    // lifetime in the return type.  So, given a return type
+                    // like `async fn foo(..) -> &[&u32]`, we lower to `impl
+                    // Future<Output = &'1 [ &'2 u32 ]>`.
+                    //
+                    // Then, we will create `fn foo(..) -> Foo<'_, '_>`, and
+                    // hence the elision takes place at the fn site.
+                    this.lower_async_fn_output_type_to_future_bound(output, fn_def_id, span)
+                });
             debug!("lower_async_fn_ret_ty: future_bound={:#?}", future_bound);
-            captures = std::mem::replace(&mut this.captured_lifetimes, lifetime_stash).unwrap().1;
             debug!("lower_async_fn_ret_ty: captures={:#?}", captures);
 
             let generic_params =
@@ -1882,22 +1920,23 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
             LifetimeRes::Param { param, binder } => {
                 debug_assert_ne!(ident.name, kw::UnderscoreLifetime);
                 let p_name = ParamName::Plain(ident);
-                if let Some((parent_def_id, captures, binders)) = &mut self.captured_lifetimes {
+                if let Some(LifetimeCaptureContext { parent_def_id, captures, binders_to_ignore }) =
+                    &mut self.captured_lifetimes
+                    && !binders_to_ignore.contains(&binder)
+                {
                     match captures.entry(param) {
                         Entry::Occupied(_) => {}
                         Entry::Vacant(v) => {
-                            if !binders.contains(&binder) {
-                                let p_id = self.resolver.next_node_id();
-                                self.resolver.create_def(
-                                    *parent_def_id,
-                                    p_id,
-                                    DefPathData::LifetimeNs(p_name.ident().name),
-                                    ExpnId::root(),
-                                    span.with_parent(None),
-                                );
-
-                                v.insert((span, p_id, p_name, res));
-                            }
+                            let p_id = self.resolver.next_node_id();
+                            self.resolver.create_def(
+                                *parent_def_id,
+                                p_id,
+                                DefPathData::LifetimeNs(p_name.ident().name),
+                                ExpnId::root(),
+                                span.with_parent(None),
+                            );
+
+                            v.insert((span, p_id, p_name, res));
                         }
                     }
                 }
@@ -1908,24 +1947,25 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
                 // Only items are allowed to introduce fresh lifetimes,
                 // so we know `binder` has a `LocalDefId`.
                 let binder_def_id = self.resolver.local_def_id(binder);
-                if let Some((parent_def_id, captures, binders)) = &mut self.captured_lifetimes {
+                if let Some(LifetimeCaptureContext { parent_def_id, captures, binders_to_ignore }) =
+                    &mut self.captured_lifetimes
+                    && !binders_to_ignore.contains(&binder)
+                {
                     match captures.entry(param) {
                         Entry::Occupied(o) => param = self.resolver.local_def_id(o.get().1),
                         Entry::Vacant(v) => {
-                            if !binders.contains(&binder) {
-                                let p_id = self.resolver.next_node_id();
-                                let p_def_id = self.resolver.create_def(
-                                    *parent_def_id,
-                                    p_id,
-                                    DefPathData::LifetimeNs(kw::UnderscoreLifetime),
-                                    ExpnId::root(),
-                                    span.with_parent(None),
-                                );
-
-                                let p_name = ParamName::Fresh(param);
-                                v.insert((span, p_id, p_name, res));
-                                param = p_def_id;
-                            }
+                            let p_id = self.resolver.next_node_id();
+                            let p_def_id = self.resolver.create_def(
+                                *parent_def_id,
+                                p_id,
+                                DefPathData::LifetimeNs(kw::UnderscoreLifetime),
+                                ExpnId::root(),
+                                span.with_parent(None),
+                            );
+
+                            let p_name = ParamName::Fresh(param);
+                            v.insert((span, p_id, p_name, res));
+                            param = p_def_id;
                         }
                     }
                 } else if let Some(introducer) = introducer {
@@ -1948,21 +1988,23 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
                 } else {
                     hir::LifetimeName::Underscore
                 };
-                match &mut self.captured_lifetimes {
-                    Some((parent_def_id, captures, binders)) if !binders.contains(&binder) => {
-                        let p_id = self.resolver.next_node_id();
-                        let p_def_id = self.resolver.create_def(
-                            *parent_def_id,
-                            p_id,
-                            DefPathData::LifetimeNs(kw::UnderscoreLifetime),
-                            ExpnId::root(),
-                            span.with_parent(None),
-                        );
-                        let p_name = ParamName::Fresh(p_def_id);
-                        captures.insert(p_def_id, (span, p_id, p_name, res));
-                        hir::LifetimeName::Param(p_name)
-                    }
-                    _ => l_name,
+                if let Some(LifetimeCaptureContext { parent_def_id, captures, binders_to_ignore }) =
+                    &mut self.captured_lifetimes
+                    && !binders_to_ignore.contains(&binder)
+                {
+                    let p_id = self.resolver.next_node_id();
+                    let p_def_id = self.resolver.create_def(
+                        *parent_def_id,
+                        p_id,
+                        DefPathData::LifetimeNs(kw::UnderscoreLifetime),
+                        ExpnId::root(),
+                        span.with_parent(None),
+                    );
+                    let p_name = ParamName::Fresh(p_def_id);
+                    captures.insert(p_def_id, (span, p_id, p_name, res));
+                    hir::LifetimeName::Param(p_name)
+                } else {
+                    l_name
                 }
             }
             LifetimeRes::Static => hir::LifetimeName::Static,
@@ -2069,16 +2111,9 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
             self.lower_generic_params(&p.bound_generic_params, itctx.reborrow());
 
         let trait_ref = self.with_in_scope_lifetime_defs(&p.bound_generic_params, |this| {
-            if let Some((_, _, binders)) = &mut this.captured_lifetimes {
-                binders.insert(p.trait_ref.ref_id);
-            }
-
-            let trait_ref = this.lower_trait_ref(&p.trait_ref, itctx.reborrow());
-
-            if let Some((_, _, binders)) = &mut this.captured_lifetimes {
-                binders.remove(&p.trait_ref.ref_id);
-            }
-            trait_ref
+            this.with_lifetime_binder(p.trait_ref.ref_id, |this| {
+                this.lower_trait_ref(&p.trait_ref, itctx.reborrow())
+            })
         });
 
         hir::PolyTraitRef { bound_generic_params, trait_ref, span: self.lower_span(p.span) }
diff --git a/compiler/rustc_ast_lowering/src/path.rs b/compiler/rustc_ast_lowering/src/path.rs
index 315e2289b7a..3c9399c1fdf 100644
--- a/compiler/rustc_ast_lowering/src/path.rs
+++ b/compiler/rustc_ast_lowering/src/path.rs
@@ -353,33 +353,31 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
         // a hidden lifetime parameter. This is needed for backwards
         // compatibility, even in contexts like an impl header where
         // we generally don't permit such things (see #51008).
-        if let Some((_, _, binders)) = &mut self.captured_lifetimes {
-            binders.insert(id);
-        }
-        let ParenthesizedArgs { span, inputs, inputs_span, output } = data;
-        let inputs = self.arena.alloc_from_iter(inputs.iter().map(|ty| {
-            self.lower_ty_direct(ty, ImplTraitContext::Disallowed(ImplTraitPosition::FnTraitParam))
-        }));
-        let output_ty = match output {
-            FnRetTy::Ty(ty) => {
-                self.lower_ty(&ty, ImplTraitContext::Disallowed(ImplTraitPosition::FnTraitReturn))
-            }
-            FnRetTy::Default(_) => self.arena.alloc(self.ty_tup(*span, &[])),
-        };
-        let args = smallvec![GenericArg::Type(self.ty_tup(*inputs_span, inputs))];
-        let binding = self.output_ty_binding(output_ty.span, output_ty);
-        if let Some((_, _, binders)) = &mut self.captured_lifetimes {
-            binders.remove(&id);
-        }
-        (
-            GenericArgsCtor {
-                args,
-                bindings: arena_vec![self; binding],
-                parenthesized: true,
-                span: data.inputs_span,
-            },
-            false,
-        )
+        self.with_lifetime_binder(id, |this| {
+            let ParenthesizedArgs { span, inputs, inputs_span, output } = data;
+            let inputs = this.arena.alloc_from_iter(inputs.iter().map(|ty| {
+                this.lower_ty_direct(
+                    ty,
+                    ImplTraitContext::Disallowed(ImplTraitPosition::FnTraitParam),
+                )
+            }));
+            let output_ty = match output {
+                FnRetTy::Ty(ty) => this
+                    .lower_ty(&ty, ImplTraitContext::Disallowed(ImplTraitPosition::FnTraitReturn)),
+                FnRetTy::Default(_) => this.arena.alloc(this.ty_tup(*span, &[])),
+            };
+            let args = smallvec![GenericArg::Type(this.ty_tup(*inputs_span, inputs))];
+            let binding = this.output_ty_binding(output_ty.span, output_ty);
+            (
+                GenericArgsCtor {
+                    args,
+                    bindings: arena_vec![this; binding],
+                    parenthesized: true,
+                    span: data.inputs_span,
+                },
+                false,
+            )
+        })
     }
 
     /// An associated type binding `Output = $ty`.