about summary refs log tree commit diff
diff options
context:
space:
mode:
authorSantiago Pastorino <spastorino@gmail.com>2022-07-20 16:30:37 -0300
committerSantiago Pastorino <spastorino@gmail.com>2022-08-04 11:26:51 -0300
commitcda2c04592850eb8248210beb8efd9c35de0dc19 (patch)
treec38443a35a424091b7091b90ec656ab278d46441
parent84a24a1b3c183066644be2db03c48545d796f01f (diff)
downloadrust-cda2c04592850eb8248210beb8efd9c35de0dc19.tar.gz
rust-cda2c04592850eb8248210beb8efd9c35de0dc19.zip
Explicitly gather lifetimes and definitions in RPIT
-rw-r--r--compiler/rustc_ast/src/ast.rs61
-rw-r--r--compiler/rustc_ast_lowering/src/lib.rs188
2 files changed, 207 insertions, 42 deletions
diff --git a/compiler/rustc_ast/src/ast.rs b/compiler/rustc_ast/src/ast.rs
index 116497109f1..3503e5de8cc 100644
--- a/compiler/rustc_ast/src/ast.rs
+++ b/compiler/rustc_ast/src/ast.rs
@@ -25,7 +25,9 @@ pub use UnsafeSource::*;
 use crate::ptr::P;
 use crate::token::{self, CommentKind, Delimiter};
 use crate::tokenstream::{DelimSpan, LazyTokenStream, TokenStream};
+use crate::visit::{self, BoundKind, LifetimeCtxt, Visitor};
 
+use rustc_data_structures::fx::FxHashMap;
 use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
 use rustc_data_structures::stack::ensure_sufficient_stack;
 use rustc_data_structures::sync::Lrc;
@@ -64,7 +66,7 @@ impl fmt::Debug for Label {
 
 /// A "Lifetime" is an annotation of the scope in which variable
 /// can be used, e.g. `'a` in `&'a i32`.
-#[derive(Clone, Encodable, Decodable, Copy)]
+#[derive(Clone, Encodable, Decodable, Copy, PartialEq, Eq)]
 pub struct Lifetime {
     pub id: NodeId,
     pub ident: Ident,
@@ -323,6 +325,63 @@ impl GenericBound {
 
 pub type GenericBounds = Vec<GenericBound>;
 
+struct LifetimeCollectVisitor<'ast> {
+    current_binders: Vec<NodeId>,
+    binders_to_ignore: FxHashMap<NodeId, Vec<NodeId>>,
+    collected_lifetimes: Vec<&'ast Lifetime>,
+}
+
+impl<'ast> Visitor<'ast> for LifetimeCollectVisitor<'ast> {
+    fn visit_lifetime(&mut self, lifetime: &'ast Lifetime, _: LifetimeCtxt) {
+        if !self.collected_lifetimes.contains(&lifetime) {
+            self.collected_lifetimes.push(lifetime);
+        }
+        self.binders_to_ignore.insert(lifetime.id, self.current_binders.clone());
+    }
+
+    fn visit_poly_trait_ref(&mut self, t: &'ast PolyTraitRef, m: &'ast TraitBoundModifier) {
+        self.current_binders.push(t.trait_ref.ref_id);
+
+        visit::walk_poly_trait_ref(self, t, m);
+
+        self.current_binders.pop();
+    }
+
+    fn visit_ty(&mut self, t: &'ast Ty) {
+        if let TyKind::BareFn(_) = t.kind {
+            self.current_binders.push(t.id);
+        }
+        visit::walk_ty(self, t);
+        if let TyKind::BareFn(_) = t.kind {
+            self.current_binders.pop();
+        }
+    }
+}
+
+pub fn lifetimes_in_ret_ty(ret_ty: &FnRetTy) -> (Vec<&Lifetime>, FxHashMap<NodeId, Vec<NodeId>>) {
+    let mut visitor = LifetimeCollectVisitor {
+        current_binders: Vec::new(),
+        binders_to_ignore: FxHashMap::default(),
+        collected_lifetimes: Vec::new(),
+    };
+    visitor.visit_fn_ret_ty(ret_ty);
+    (visitor.collected_lifetimes, visitor.binders_to_ignore)
+}
+
+pub fn lifetimes_in_bounds(
+    bounds: &GenericBounds,
+) -> (Vec<&Lifetime>, FxHashMap<NodeId, Vec<NodeId>>) {
+    let mut visitor = LifetimeCollectVisitor {
+        current_binders: Vec::new(),
+        binders_to_ignore: FxHashMap::default(),
+        collected_lifetimes: Vec::new(),
+    };
+    for bound in bounds {
+        visitor.visit_param_bound(bound, BoundKind::Bound);
+    }
+    (visitor.collected_lifetimes, visitor.binders_to_ignore)
+}
+
 /// Specifies the enforced ordering for generic parameters. In the future,
 /// if we wanted to relax this order, we could override `PartialEq` and
 /// `PartialOrd`, to allow the kinds to be unordered.
diff --git a/compiler/rustc_ast_lowering/src/lib.rs b/compiler/rustc_ast_lowering/src/lib.rs
index 72e0466761a..dfd23dbd4d7 100644
--- a/compiler/rustc_ast_lowering/src/lib.rs
+++ b/compiler/rustc_ast_lowering/src/lib.rs
@@ -1304,17 +1304,17 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
             TyKind::ImplTrait(def_node_id, ref bounds) => {
                 let span = t.span;
                 match itctx {
-                    ImplTraitContext::ReturnPositionOpaqueTy { origin } => self
-                        .lower_opaque_impl_trait(span, origin, def_node_id, |this| {
-                            this.lower_param_bounds(bounds, itctx, true)
-                        }),
+                    ImplTraitContext::ReturnPositionOpaqueTy { origin } => {
+                        self.lower_opaque_impl_trait(span, origin, def_node_id, bounds, itctx)
+                    }
                     ImplTraitContext::TypeAliasesOpaqueTy => {
                         let nested_itctx = ImplTraitContext::TypeAliasesOpaqueTy;
                         self.lower_opaque_impl_trait(
                             span,
                             hir::OpaqueTyOrigin::TyAlias,
                             def_node_id,
-                            |this| this.lower_param_bounds(bounds, nested_itctx, true),
+                            bounds,
+                            nested_itctx,
                         )
                     }
                     ImplTraitContext::Universal => {
@@ -1354,13 +1354,14 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
         hir::Ty { kind, span: self.lower_span(t.span), hir_id: self.lower_node_id(t.id) }
     }
 
-    #[tracing::instrument(level = "debug", skip(self, lower_bounds))]
+    #[tracing::instrument(level = "debug", skip(self))]
     fn lower_opaque_impl_trait(
         &mut self,
         span: Span,
         origin: hir::OpaqueTyOrigin,
         opaque_ty_node_id: NodeId,
-        lower_bounds: impl FnOnce(&mut Self) -> hir::GenericBounds<'hir>,
+        bounds: &GenericBounds,
+        itctx: ImplTraitContext,
     ) -> hir::TyKind<'hir> {
         // Make sure we know that some funky desugaring has been going on here.
         // This is a first: there is code in other places like for loop
@@ -1374,23 +1375,122 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
         let mut collected_lifetimes = FxHashMap::default();
         self.with_hir_id_owner(opaque_ty_node_id, |lctx| {
             let hir_bounds = if origin == hir::OpaqueTyOrigin::TyAlias {
-                lower_bounds(lctx)
+                lctx.lower_param_bounds(bounds, itctx, true)
             } else {
-                let lifetime_stash = std::mem::replace(
-                    &mut lctx.captured_lifetimes,
-                    Some(LifetimeCaptureContext {
-                        parent_def_id: opaque_ty_def_id,
-                        captures: std::mem::take(&mut collected_lifetimes),
-                        binders_to_ignore: Default::default(),
-                    }),
-                );
+                if std::env::var("NEW_COLLECT_LIFETIMES").is_ok() {
+                    let lifetime_stash = std::mem::replace(
+                        &mut lctx.captured_lifetimes,
+                        Some(LifetimeCaptureContext {
+                            parent_def_id: opaque_ty_def_id,
+                            captures: std::mem::take(&mut collected_lifetimes),
+                            binders_to_ignore: Default::default(),
+                        }),
+                    );
+
+                    let (lifetimes_in_bounds, binders_to_ignore) = ast::lifetimes_in_bounds(bounds);
+
+                    for lifetime in &lifetimes_in_bounds {
+                        let ident = lifetime.ident;
+                        let span = ident.span;
+
+                        let res = lctx
+                            .resolver
+                            .get_lifetime_res(lifetime.id)
+                            .unwrap_or(LifetimeRes::Error);
+
+                        if let Some(mut captured_lifetimes) = lctx.captured_lifetimes.take() {
+                            match res {
+                                LifetimeRes::Param { param, binder } => {
+                                    if !captured_lifetimes.binders_to_ignore.contains(&binder)
+                                        && !binders_to_ignore
+                                            .get(&lifetime.id)
+                                            .unwrap_or(&Vec::new())
+                                            .contains(&binder)
+                                    {
+                                        match captured_lifetimes.captures.entry(param) {
+                                            Entry::Occupied(_) => {}
+                                            Entry::Vacant(v) => {
+                                                let node_id = lctx.next_node_id();
+                                                let name = ParamName::Plain(ident);
+
+                                                lctx.create_def(
+                                                    captured_lifetimes.parent_def_id,
+                                                    node_id,
+                                                    DefPathData::LifetimeNs(name.ident().name),
+                                                );
+
+                                                v.insert((span, node_id, name, res));
+                                            }
+                                        }
+                                    }
+                                }
+
+                                LifetimeRes::Fresh { param, binder } => {
+                                    debug_assert_eq!(ident.name, kw::UnderscoreLifetime);
+                                    if !captured_lifetimes.binders_to_ignore.contains(&binder)
+                                        && !binders_to_ignore
+                                            .get(&lifetime.id)
+                                            .unwrap_or(&Vec::new())
+                                            .contains(&binder)
+                                    {
+                                        let param = lctx.local_def_id(param);
+                                        match captured_lifetimes.captures.entry(param) {
+                                            Entry::Occupied(_) => {}
+                                            Entry::Vacant(v) => {
+                                                let node_id = lctx.next_node_id();
+
+                                                let name = ParamName::Fresh;
+
+                                                lctx.create_def(
+                                                    captured_lifetimes.parent_def_id,
+                                                    node_id,
+                                                    DefPathData::LifetimeNs(kw::UnderscoreLifetime),
+                                                );
+
+                                                v.insert((span, node_id, name, res));
+                                            }
+                                        }
+                                    }
+                                }
+
+                                LifetimeRes::Infer | LifetimeRes::Static | LifetimeRes::Error => {}
+
+                                res => panic!(
+                                    "Unexpected lifetime resolution {:?} for {:?} at {:?}",
+                                    res, lifetime.ident, lifetime.ident.span
+                                ),
+                            }
+
+                            lctx.captured_lifetimes = Some(captured_lifetimes);
+                        }
+                    }
+
+                    let ret = lctx.lower_param_bounds(bounds, itctx, false);
+
+                    let ctxt =
+                        std::mem::replace(&mut lctx.captured_lifetimes, lifetime_stash).unwrap();
+
+                    collected_lifetimes = ctxt.captures;
+
+                    ret
+                } else {
+                    let lifetime_stash = std::mem::replace(
+                        &mut lctx.captured_lifetimes,
+                        Some(LifetimeCaptureContext {
+                            parent_def_id: opaque_ty_def_id,
+                            captures: std::mem::take(&mut collected_lifetimes),
+                            binders_to_ignore: Default::default(),
+                        }),
+                    );
 
-                let ret = lower_bounds(lctx);
+                    let ret = lctx.lower_param_bounds(bounds, itctx, true);
 
-                let ctxt = std::mem::replace(&mut lctx.captured_lifetimes, lifetime_stash).unwrap();
-                collected_lifetimes = ctxt.captures;
+                    let ctxt =
+                        std::mem::replace(&mut lctx.captured_lifetimes, lifetime_stash).unwrap();
+                    collected_lifetimes = ctxt.captures;
 
-                ret
+                    ret
+                }
             };
             debug!(?collected_lifetimes);
 
@@ -1855,16 +1955,18 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
         captures: bool,
     ) -> hir::Lifetime {
         debug!(?self.captured_lifetimes);
+
         let name = match res {
             LifetimeRes::Param { mut param, binder } => {
                 let p_name = ParamName::Plain(ident);
-                if captures {
-                    if let Some(mut captured_lifetimes) = self.captured_lifetimes.take() {
+                if let Some(mut captured_lifetimes) = self.captured_lifetimes.take() {
+                    if captures {
                         if !captured_lifetimes.binders_to_ignore.contains(&binder) {
                             match captured_lifetimes.captures.entry(param) {
                                 Entry::Occupied(o) => param = self.local_def_id(o.get().1),
                                 Entry::Vacant(v) => {
                                     let p_id = self.next_node_id();
+
                                     let p_def_id = self.create_def(
                                         captured_lifetimes.parent_def_id,
                                         p_id,
@@ -1876,36 +1978,40 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
                                 }
                             }
                         }
-
-                        self.captured_lifetimes = Some(captured_lifetimes);
+                    } else {
+                        if let Entry::Occupied(o) = captured_lifetimes.captures.entry(param) {
+                            param = self.local_def_id(o.get().1);
+                        }
                     }
+                    self.captured_lifetimes = Some(captured_lifetimes);
                 }
+
                 hir::LifetimeName::Param(param, p_name)
             }
             LifetimeRes::Fresh { param, binder } => {
                 debug_assert_eq!(ident.name, kw::UnderscoreLifetime);
+
                 let mut param = self.local_def_id(param);
-                if captures {
-                    if let Some(mut captured_lifetimes) = self.captured_lifetimes.take() {
-                        if !captured_lifetimes.binders_to_ignore.contains(&binder) {
-                            match captured_lifetimes.captures.entry(param) {
-                                Entry::Occupied(o) => param = self.local_def_id(o.get().1),
-                                Entry::Vacant(v) => {
-                                    let p_id = self.next_node_id();
-                                    let p_def_id = self.create_def(
-                                        captured_lifetimes.parent_def_id,
-                                        p_id,
-                                        DefPathData::LifetimeNs(kw::UnderscoreLifetime),
-                                    );
+                if let Some(mut captured_lifetimes) = self.captured_lifetimes.take() {
+                    if !captured_lifetimes.binders_to_ignore.contains(&binder) {
+                        match captured_lifetimes.captures.entry(param) {
+                            Entry::Occupied(o) => param = self.local_def_id(o.get().1),
+                            Entry::Vacant(v) => {
+                                let p_id = self.next_node_id();
+
+                                let p_def_id = self.create_def(
+                                    captured_lifetimes.parent_def_id,
+                                    p_id,
+                                    DefPathData::LifetimeNs(kw::UnderscoreLifetime),
+                                );
 
-                                    v.insert((span, p_id, ParamName::Fresh, res));
-                                    param = p_def_id;
-                                }
+                                v.insert((span, p_id, ParamName::Fresh, res));
+                                param = p_def_id;
                             }
                         }
-
-                        self.captured_lifetimes = Some(captured_lifetimes);
                     }
+
+                    self.captured_lifetimes = Some(captured_lifetimes);
                 }
                 hir::LifetimeName::Param(param, ParamName::Fresh)
             }