diff options
Diffstat (limited to 'compiler/rustc_ast_lowering/src')
| -rw-r--r-- | compiler/rustc_ast_lowering/src/errors.rs | 7 | ||||
| -rw-r--r-- | compiler/rustc_ast_lowering/src/index.rs | 17 | ||||
| -rw-r--r-- | compiler/rustc_ast_lowering/src/lib.rs | 136 | ||||
| -rw-r--r-- | compiler/rustc_ast_lowering/src/lifetime_collector.rs | 15 |
4 files changed, 127 insertions, 48 deletions
diff --git a/compiler/rustc_ast_lowering/src/errors.rs b/compiler/rustc_ast_lowering/src/errors.rs index 6fd980ed3ca..ca0821e2c9e 100644 --- a/compiler/rustc_ast_lowering/src/errors.rs +++ b/compiler/rustc_ast_lowering/src/errors.rs @@ -414,3 +414,10 @@ pub(crate) struct AsyncBoundOnlyForFnTraits { #[primary_span] pub span: Span, } + +#[derive(Diagnostic)] +#[diag(ast_lowering_no_precise_captures_on_apit)] +pub(crate) struct NoPreciseCapturesOnApit { + #[primary_span] + pub span: Span, +} diff --git a/compiler/rustc_ast_lowering/src/index.rs b/compiler/rustc_ast_lowering/src/index.rs index 4c552289a81..93be9b9b8cf 100644 --- a/compiler/rustc_ast_lowering/src/index.rs +++ b/compiler/rustc_ast_lowering/src/index.rs @@ -385,4 +385,21 @@ impl<'a, 'hir> Visitor<'hir> for NodeCollector<'a, 'hir> { fn visit_pattern_type_pattern(&mut self, p: &'hir hir::Pat<'hir>) { self.visit_pat(p) } + + fn visit_precise_capturing_arg( + &mut self, + arg: &'hir PreciseCapturingArg<'hir>, + ) -> Self::Result { + match arg { + PreciseCapturingArg::Lifetime(_) => { + // This is represented as a `Node::Lifetime`, intravisit will get to it below. + } + PreciseCapturingArg::Param(param) => self.insert( + param.ident.span, + param.hir_id, + Node::PreciseCapturingNonLifetimeArg(param), + ), + } + intravisit::walk_precise_capturing_arg(self, arg); + } } diff --git a/compiler/rustc_ast_lowering/src/lib.rs b/compiler/rustc_ast_lowering/src/lib.rs index 5005c22d4cc..a21d6019cf1 100644 --- a/compiler/rustc_ast_lowering/src/lib.rs +++ b/compiler/rustc_ast_lowering/src/lib.rs @@ -48,6 +48,7 @@ use rustc_ast::{self as ast, *}; use rustc_ast_pretty::pprust; use rustc_data_structures::captures::Captures; use rustc_data_structures::fingerprint::Fingerprint; +use rustc_data_structures::fx::FxIndexSet; use rustc_data_structures::sorted_map::SortedMap; use rustc_data_structures::stable_hasher::{HashStable, StableHasher}; use rustc_data_structures::sync::Lrc; @@ -1398,7 +1399,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { }); hir::TyKind::TraitObject(bounds, lifetime_bound, *kind) } - TyKind::ImplTrait(def_node_id, bounds) => { + TyKind::ImplTrait(def_node_id, bounds, precise_capturing) => { let span = t.span; match itctx { ImplTraitContext::OpaqueTy { origin, fn_kind } => self.lower_opaque_impl_trait( @@ -1408,8 +1409,12 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { bounds, fn_kind, itctx, + precise_capturing.as_deref().map(|(args, _)| args.as_slice()), ), ImplTraitContext::Universal => { + if let Some(&(_, span)) = precise_capturing.as_deref() { + self.tcx.dcx().emit_err(errors::NoPreciseCapturesOnApit { span }); + }; let span = t.span; // HACK: pprust breaks strings with newlines when the type @@ -1520,6 +1525,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { bounds: &GenericBounds, fn_kind: Option<FnDeclKind>, itctx: ImplTraitContext, + precise_capturing_args: Option<&[PreciseCapturingArg]>, ) -> hir::TyKind<'hir> { // Make sure we know that some funky desugaring has been going on here. // This is a first: there is code in other places like for loop @@ -1528,42 +1534,59 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { // frequently opened issues show. let opaque_ty_span = self.mark_span_with_reason(DesugaringKind::OpaqueTy, span, None); - let captured_lifetimes_to_duplicate = match origin { - hir::OpaqueTyOrigin::TyAlias { .. } => { - // type alias impl trait and associated type position impl trait were - // decided to capture all in-scope lifetimes, which we collect for - // all opaques during resolution. - self.resolver - .take_extra_lifetime_params(opaque_ty_node_id) - .into_iter() - .map(|(ident, id, _)| Lifetime { id, ident }) + let captured_lifetimes_to_duplicate = + if let Some(precise_capturing) = precise_capturing_args { + // We'll actually validate these later on; all we need is the list of + // lifetimes to duplicate during this portion of lowering. + precise_capturing + .iter() + .filter_map(|arg| match arg { + PreciseCapturingArg::Lifetime(lt) => Some(*lt), + PreciseCapturingArg::Arg(..) => None, + }) + // Add in all the lifetimes mentioned in the bounds. We will error + // them out later, but capturing them here is important to make sure + // they actually get resolved in resolve_bound_vars. + .chain(lifetime_collector::lifetimes_in_bounds(self.resolver, bounds)) .collect() - } - hir::OpaqueTyOrigin::FnReturn(..) => { - if matches!( - fn_kind.expect("expected RPITs to be lowered with a FnKind"), - FnDeclKind::Impl | FnDeclKind::Trait - ) || self.tcx.features().lifetime_capture_rules_2024 - || span.at_least_rust_2024() - { - // return-position impl trait in trait was decided to capture all - // in-scope lifetimes, which we collect for all opaques during resolution. - self.resolver - .take_extra_lifetime_params(opaque_ty_node_id) - .into_iter() - .map(|(ident, id, _)| Lifetime { id, ident }) - .collect() - } else { - // in fn return position, like the `fn test<'a>() -> impl Debug + 'a` - // example, we only need to duplicate lifetimes that appear in the - // bounds, since those are the only ones that are captured by the opaque. - lifetime_collector::lifetimes_in_bounds(self.resolver, bounds) + } else { + match origin { + hir::OpaqueTyOrigin::TyAlias { .. } => { + // type alias impl trait and associated type position impl trait were + // decided to capture all in-scope lifetimes, which we collect for + // all opaques during resolution. + self.resolver + .take_extra_lifetime_params(opaque_ty_node_id) + .into_iter() + .map(|(ident, id, _)| Lifetime { id, ident }) + .collect() + } + hir::OpaqueTyOrigin::FnReturn(..) => { + if matches!( + fn_kind.expect("expected RPITs to be lowered with a FnKind"), + FnDeclKind::Impl | FnDeclKind::Trait + ) || self.tcx.features().lifetime_capture_rules_2024 + || span.at_least_rust_2024() + { + // return-position impl trait in trait was decided to capture all + // in-scope lifetimes, which we collect for all opaques during resolution. + self.resolver + .take_extra_lifetime_params(opaque_ty_node_id) + .into_iter() + .map(|(ident, id, _)| Lifetime { id, ident }) + .collect() + } else { + // in fn return position, like the `fn test<'a>() -> impl Debug + 'a` + // example, we only need to duplicate lifetimes that appear in the + // bounds, since those are the only ones that are captured by the opaque. + lifetime_collector::lifetimes_in_bounds(self.resolver, bounds) + } + } + hir::OpaqueTyOrigin::AsyncFn(..) => { + unreachable!("should be using `lower_async_fn_ret_ty`") + } } - } - hir::OpaqueTyOrigin::AsyncFn(..) => { - unreachable!("should be using `lower_async_fn_ret_ty`") - } - }; + }; debug!(?captured_lifetimes_to_duplicate); self.lower_opaque_inner( @@ -1573,6 +1596,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { captured_lifetimes_to_duplicate, span, opaque_ty_span, + precise_capturing_args, |this| this.lower_param_bounds(bounds, itctx), ) } @@ -1582,9 +1606,10 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { opaque_ty_node_id: NodeId, origin: hir::OpaqueTyOrigin, in_trait: bool, - captured_lifetimes_to_duplicate: Vec<Lifetime>, + captured_lifetimes_to_duplicate: FxIndexSet<Lifetime>, span: Span, opaque_ty_span: Span, + precise_capturing_args: Option<&[PreciseCapturingArg]>, lower_item_bounds: impl FnOnce(&mut Self) -> &'hir [hir::GenericBound<'hir>], ) -> hir::TyKind<'hir> { let opaque_ty_def_id = self.create_def( @@ -1671,8 +1696,15 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { // Install the remapping from old to new (if any). This makes sure that // any lifetimes that would have resolved to the def-id of captured // lifetimes are remapped to the new *synthetic* lifetimes of the opaque. - let bounds = this - .with_remapping(captured_to_synthesized_mapping, |this| lower_item_bounds(this)); + let (bounds, precise_capturing_args) = + this.with_remapping(captured_to_synthesized_mapping, |this| { + ( + lower_item_bounds(this), + precise_capturing_args.map(|precise_capturing| { + this.lower_precise_capturing_args(precise_capturing) + }), + ) + }); let generic_params = this.arena.alloc_from_iter(synthesized_lifetime_definitions.iter().map( @@ -1717,6 +1749,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { origin, lifetime_mapping, in_trait, + precise_capturing_args, }; // Generate an `type Foo = impl Trait;` declaration. @@ -1749,6 +1782,30 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { ) } + fn lower_precise_capturing_args( + &mut self, + precise_capturing_args: &[PreciseCapturingArg], + ) -> &'hir [hir::PreciseCapturingArg<'hir>] { + self.arena.alloc_from_iter(precise_capturing_args.iter().map(|arg| match arg { + PreciseCapturingArg::Lifetime(lt) => { + hir::PreciseCapturingArg::Lifetime(self.lower_lifetime(lt)) + } + PreciseCapturingArg::Arg(path, id) => { + let [segment] = path.segments.as_slice() else { + panic!(); + }; + let res = self.resolver.get_partial_res(*id).map_or(Res::Err, |partial_res| { + partial_res.full_res().expect("no partial res expected for precise capture arg") + }); + hir::PreciseCapturingArg::Param(hir::PreciseCapturingNonLifetimeArg { + hir_id: self.lower_node_id(*id), + ident: self.lower_ident(segment.ident), + res: self.lower_res(res), + }) + } + })) + } + fn lower_fn_params_to_names(&mut self, decl: &FnDecl) -> &'hir [Ident] { self.arena.alloc_from_iter(decl.inputs.iter().map(|param| match param.pat.kind { PatKind::Ident(_, ident, _) => self.lower_ident(ident), @@ -1889,7 +1946,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { let opaque_ty_span = self.mark_span_with_reason(DesugaringKind::Async, span, allowed_features); - let captured_lifetimes: Vec<_> = self + let captured_lifetimes = self .resolver .take_extra_lifetime_params(opaque_ty_node_id) .into_iter() @@ -1903,6 +1960,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { captured_lifetimes, span, opaque_ty_span, + None, |this| { let bound = this.lower_coroutine_fn_output_type_to_bound( output, diff --git a/compiler/rustc_ast_lowering/src/lifetime_collector.rs b/compiler/rustc_ast_lowering/src/lifetime_collector.rs index 4b1c057cdbf..5456abd489b 100644 --- a/compiler/rustc_ast_lowering/src/lifetime_collector.rs +++ b/compiler/rustc_ast_lowering/src/lifetime_collector.rs @@ -1,6 +1,7 @@ use super::ResolverAstLoweringExt; use rustc_ast::visit::{self, BoundKind, LifetimeCtxt, Visitor}; use rustc_ast::{GenericBounds, Lifetime, NodeId, PathSegment, PolyTraitRef, Ty, TyKind}; +use rustc_data_structures::fx::FxIndexSet; use rustc_hir::def::{DefKind, LifetimeRes, Res}; use rustc_middle::span_bug; use rustc_middle::ty::ResolverAstLowering; @@ -10,27 +11,23 @@ use rustc_span::Span; struct LifetimeCollectVisitor<'ast> { resolver: &'ast ResolverAstLowering, current_binders: Vec<NodeId>, - collected_lifetimes: Vec<Lifetime>, + collected_lifetimes: FxIndexSet<Lifetime>, } impl<'ast> LifetimeCollectVisitor<'ast> { fn new(resolver: &'ast ResolverAstLowering) -> Self { - Self { resolver, current_binders: Vec::new(), collected_lifetimes: Vec::new() } + Self { resolver, current_binders: Vec::new(), collected_lifetimes: FxIndexSet::default() } } fn record_lifetime_use(&mut self, lifetime: Lifetime) { match self.resolver.get_lifetime_res(lifetime.id).unwrap_or(LifetimeRes::Error) { LifetimeRes::Param { binder, .. } | LifetimeRes::Fresh { binder, .. } => { if !self.current_binders.contains(&binder) { - if !self.collected_lifetimes.contains(&lifetime) { - self.collected_lifetimes.push(lifetime); - } + self.collected_lifetimes.insert(lifetime); } } LifetimeRes::Static | LifetimeRes::Error => { - if !self.collected_lifetimes.contains(&lifetime) { - self.collected_lifetimes.push(lifetime); - } + self.collected_lifetimes.insert(lifetime); } LifetimeRes::Infer => {} res => { @@ -111,7 +108,7 @@ impl<'ast> Visitor<'ast> for LifetimeCollectVisitor<'ast> { pub(crate) fn lifetimes_in_bounds( resolver: &ResolverAstLowering, bounds: &GenericBounds, -) -> Vec<Lifetime> { +) -> FxIndexSet<Lifetime> { let mut visitor = LifetimeCollectVisitor::new(resolver); for bound in bounds { visitor.visit_param_bound(bound, BoundKind::Bound); |
