about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-08-12 17:56:49 -0400
committerMichael Goulet <michael@errs.io>2024-09-05 06:34:42 -0400
commitf8f4d50aa34640906e0315adbf4c487712fab0cd (patch)
tree24f6a8dd215b860f8686eb46b6cb6d7266639477
parenteb33b43bab08223fa6b46abacc1e95e859fe375d (diff)
downloadrust-f8f4d50aa34640906e0315adbf4c487712fab0cd.tar.gz
rust-f8f4d50aa34640906e0315adbf4c487712fab0cd.zip
Don't worry about uncaptured contravariant lifetimes if they outlive a captured lifetime
-rw-r--r--compiler/rustc_lint/src/impl_trait_overcaptures.rs247
-rw-r--r--tests/ui/impl-trait/precise-capturing/overcaptures-2024-but-fine.rs15
2 files changed, 241 insertions, 21 deletions
diff --git a/compiler/rustc_lint/src/impl_trait_overcaptures.rs b/compiler/rustc_lint/src/impl_trait_overcaptures.rs
index 8824e1dfe50..42c800a81af 100644
--- a/compiler/rustc_lint/src/impl_trait_overcaptures.rs
+++ b/compiler/rustc_lint/src/impl_trait_overcaptures.rs
@@ -1,19 +1,28 @@
-use rustc_data_structures::fx::FxIndexSet;
+use std::cell::LazyCell;
+
+use rustc_data_structures::fx::{FxHashMap, FxIndexMap, FxIndexSet};
 use rustc_data_structures::unord::UnordSet;
 use rustc_errors::{Applicability, LintDiagnostic};
 use rustc_hir as hir;
 use rustc_hir::def::DefKind;
 use rustc_hir::def_id::{DefId, LocalDefId};
+use rustc_infer::infer::outlives::env::OutlivesEnvironment;
+use rustc_infer::infer::TyCtxtInferExt;
 use rustc_macros::LintDiagnostic;
-use rustc_middle::bug;
 use rustc_middle::middle::resolve_bound_vars::ResolvedArg;
+use rustc_middle::ty::relate::{
+    structurally_relate_consts, structurally_relate_tys, Relate, RelateResult, TypeRelation,
+};
 use rustc_middle::ty::{
     self, Ty, TyCtxt, TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor,
 };
+use rustc_middle::{bug, span_bug};
 use rustc_session::lint::FutureIncompatibilityReason;
 use rustc_session::{declare_lint, declare_lint_pass};
 use rustc_span::edition::Edition;
-use rustc_span::Span;
+use rustc_span::{Span, Symbol};
+use rustc_trait_selection::traits::outlives_bounds::InferCtxtExt;
+use rustc_trait_selection::traits::ObligationCtxt;
 
 use crate::{fluent_generated as fluent, LateContext, LateLintPass};
 
@@ -119,20 +128,41 @@ impl<'tcx> LateLintPass<'tcx> for ImplTraitOvercaptures {
     }
 }
 
+#[derive(PartialEq, Eq, Hash, Debug, Copy, Clone)]
+enum ParamKind {
+    // Early-bound var.
+    Early(Symbol, u32),
+    // Late-bound var on function, not within a binder. We can capture these.
+    Free(DefId, Symbol),
+    // Late-bound var in a binder. We can't capture these yet.
+    Late,
+}
+
 fn check_fn(tcx: TyCtxt<'_>, parent_def_id: LocalDefId) {
     let sig = tcx.fn_sig(parent_def_id).instantiate_identity();
 
-    let mut in_scope_parameters = FxIndexSet::default();
+    let mut in_scope_parameters = FxIndexMap::default();
     // Populate the in_scope_parameters list first with all of the generics in scope
     let mut current_def_id = Some(parent_def_id.to_def_id());
     while let Some(def_id) = current_def_id {
         let generics = tcx.generics_of(def_id);
         for param in &generics.own_params {
-            in_scope_parameters.insert(param.def_id);
+            in_scope_parameters.insert(param.def_id, ParamKind::Early(param.name, param.index));
         }
         current_def_id = generics.parent;
     }
 
+    for bound_var in sig.bound_vars() {
+        let ty::BoundVariableKind::Region(ty::BoundRegionKind::BrNamed(def_id, name)) = bound_var
+        else {
+            span_bug!(tcx.def_span(parent_def_id), "unexpected non-lifetime binder on fn sig");
+        };
+
+        in_scope_parameters.insert(def_id, ParamKind::Free(def_id, name));
+    }
+
+    let sig = tcx.liberate_late_bound_regions(parent_def_id.to_def_id(), sig);
+
     // Then visit the signature to walk through all the binders (incl. the late-bound
     // vars on the function itself, which we need to count too).
     sig.visit_with(&mut VisitOpaqueTypes {
@@ -140,17 +170,44 @@ fn check_fn(tcx: TyCtxt<'_>, parent_def_id: LocalDefId) {
         parent_def_id,
         in_scope_parameters,
         seen: Default::default(),
+        // Lazily compute these two, since they're likely a bit expensive.
+        variances: LazyCell::new(|| {
+            let mut functional_variances = FunctionalVariances {
+                tcx: tcx,
+                variances: FxHashMap::default(),
+                ambient_variance: ty::Covariant,
+                generics: tcx.generics_of(parent_def_id),
+            };
+            let _ = functional_variances.relate(sig, sig);
+            functional_variances.variances
+        }),
+        outlives_env: LazyCell::new(|| {
+            let param_env = tcx.param_env(parent_def_id);
+            let infcx = tcx.infer_ctxt().build();
+            let ocx = ObligationCtxt::new(&infcx);
+            let assumed_wf_tys = ocx.assumed_wf_types(param_env, parent_def_id).unwrap_or_default();
+            let implied_bounds =
+                infcx.implied_bounds_tys_compat(param_env, parent_def_id, &assumed_wf_tys, false);
+            OutlivesEnvironment::with_bounds(param_env, implied_bounds)
+        }),
     });
 }
 
-struct VisitOpaqueTypes<'tcx> {
+struct VisitOpaqueTypes<'tcx, VarFn, OutlivesFn> {
     tcx: TyCtxt<'tcx>,
     parent_def_id: LocalDefId,
-    in_scope_parameters: FxIndexSet<DefId>,
+    in_scope_parameters: FxIndexMap<DefId, ParamKind>,
+    variances: LazyCell<FxHashMap<DefId, ty::Variance>, VarFn>,
+    outlives_env: LazyCell<OutlivesEnvironment<'tcx>, OutlivesFn>,
     seen: FxIndexSet<LocalDefId>,
 }
 
-impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
+impl<'tcx, VarFn, OutlivesFn> TypeVisitor<TyCtxt<'tcx>>
+    for VisitOpaqueTypes<'tcx, VarFn, OutlivesFn>
+where
+    VarFn: FnOnce() -> FxHashMap<DefId, ty::Variance>,
+    OutlivesFn: FnOnce() -> OutlivesEnvironment<'tcx>,
+{
     fn visit_binder<T: TypeVisitable<TyCtxt<'tcx>>>(
         &mut self,
         t: &ty::Binder<'tcx, T>,
@@ -163,8 +220,8 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
                 ty::BoundVariableKind::Region(ty::BoundRegionKind::BrNamed(def_id, ..))
                 | ty::BoundVariableKind::Ty(ty::BoundTyKind::Param(def_id, _)) => {
                     added.push(def_id);
-                    let unique = self.in_scope_parameters.insert(def_id);
-                    assert!(unique);
+                    let unique = self.in_scope_parameters.insert(def_id, ParamKind::Late);
+                    assert_eq!(unique, None);
                 }
                 _ => {
                     self.tcx.dcx().span_delayed_bug(
@@ -209,6 +266,7 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
         {
             // Compute the set of args that are captured by the opaque...
             let mut captured = FxIndexSet::default();
+            let mut captured_regions = FxIndexSet::default();
             let variances = self.tcx.variances_of(opaque_def_id);
             let mut current_def_id = Some(opaque_def_id.to_def_id());
             while let Some(def_id) = current_def_id {
@@ -218,25 +276,60 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
                     if variances[param.index as usize] != ty::Invariant {
                         continue;
                     }
+
+                    let arg = opaque_ty.args[param.index as usize];
                     // We need to turn all `ty::Param`/`ConstKind::Param` and
                     // `ReEarlyParam`/`ReBound` into def ids.
-                    captured.insert(extract_def_id_from_arg(
-                        self.tcx,
-                        generics,
-                        opaque_ty.args[param.index as usize],
-                    ));
+                    captured.insert(extract_def_id_from_arg(self.tcx, generics, arg));
+
+                    captured_regions.extend(arg.as_region());
                 }
                 current_def_id = generics.parent;
             }
 
             // Compute the set of in scope params that are not captured. Get their spans,
             // since that's all we really care about them for emitting the diagnostic.
-            let uncaptured_spans: Vec<_> = self
+            let mut uncaptured_args: FxIndexSet<_> = self
                 .in_scope_parameters
                 .iter()
-                .filter(|def_id| !captured.contains(*def_id))
-                .map(|def_id| self.tcx.def_span(def_id))
+                .filter(|&(def_id, _)| !captured.contains(def_id))
+                .collect();
+
+            // These are args that we know are likely fine to "overcapture", since they can be
+            // contravariantly shortened to one of the already-captured lifetimes that they
+            // outlive.
+            let covariant_long_args: FxIndexSet<_> = uncaptured_args
+                .iter()
+                .copied()
+                .filter(|&(def_id, kind)| {
+                    let Some(ty::Bivariant | ty::Contravariant) = self.variances.get(def_id) else {
+                        return false;
+                    };
+                    let DefKind::LifetimeParam = self.tcx.def_kind(def_id) else {
+                        return false;
+                    };
+                    let uncaptured = match *kind {
+                        ParamKind::Early(name, index) => ty::Region::new_early_param(
+                            self.tcx,
+                            ty::EarlyParamRegion { name, index },
+                        ),
+                        ParamKind::Free(def_id, name) => ty::Region::new_late_param(
+                            self.tcx,
+                            self.parent_def_id.to_def_id(),
+                            ty::BoundRegionKind::BrNamed(def_id, name),
+                        ),
+                        ParamKind::Late => return false,
+                    };
+                    // Does this region outlive any captured region?
+                    captured_regions.iter().any(|r| {
+                        self.outlives_env
+                            .free_region_map()
+                            .sub_free_regions(self.tcx, *r, uncaptured)
+                    })
+                })
                 .collect();
+            // We don't care to warn on these args.
+            uncaptured_args.retain(|arg| !covariant_long_args.contains(arg));
 
             let opaque_span = self.tcx.def_span(opaque_def_id);
             let new_capture_rules =
@@ -246,7 +339,7 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
             // `use<>` syntax on it, and we're < edition 2024, then warn the user.
             if !new_capture_rules
                 && !opaque.bounds.iter().any(|bound| matches!(bound, hir::GenericBound::Use(..)))
-                && !uncaptured_spans.is_empty()
+                && !uncaptured_args.is_empty()
             {
                 let suggestion = if let Ok(snippet) =
                     self.tcx.sess.source_map().span_to_snippet(opaque_span)
@@ -274,6 +367,11 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
                     None
                 };
 
+                let uncaptured_spans: Vec<_> = uncaptured_args
+                    .into_iter()
+                    .map(|(def_id, _)| self.tcx.def_span(def_id))
+                    .collect();
+
                 self.tcx.emit_node_span_lint(
                     IMPL_TRAIT_OVERCAPTURES,
                     self.tcx.local_def_id_to_hir_id(opaque_def_id),
@@ -327,7 +425,7 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
                 if self
                     .in_scope_parameters
                     .iter()
-                    .all(|def_id| explicitly_captured.contains(def_id))
+                    .all(|(def_id, _)| explicitly_captured.contains(def_id))
                 {
                     self.tcx.emit_node_span_lint(
                         IMPL_TRAIT_REDUNDANT_CAPTURES,
@@ -396,7 +494,11 @@ fn extract_def_id_from_arg<'tcx>(
             ty::ReBound(
                 _,
                 ty::BoundRegion { kind: ty::BoundRegionKind::BrNamed(def_id, ..), .. },
-            ) => def_id,
+            )
+            | ty::ReLateParam(ty::LateParamRegion {
+                scope: _,
+                bound_region: ty::BoundRegionKind::BrNamed(def_id, ..),
+            }) => def_id,
             _ => unreachable!(),
         },
         ty::GenericArgKind::Type(ty) => {
@@ -413,3 +515,106 @@ fn extract_def_id_from_arg<'tcx>(
         }
     }
 }
+
+/// Computes the variances of regions that appear in the type, but considering
+/// late-bound regions too, which don't have their variance computed usually.
+///
+/// Like generalization, this is a unary operation implemented on top of the binary
+/// relation infrastructure, mostly because it's much easier to have the relation
+/// track the variance for you, rather than having to do it yourself.
+struct FunctionalVariances<'tcx> {
+    tcx: TyCtxt<'tcx>,
+    variances: FxHashMap<DefId, ty::Variance>,
+    ambient_variance: ty::Variance,
+    generics: &'tcx ty::Generics,
+}
+
+impl<'tcx> TypeRelation<TyCtxt<'tcx>> for FunctionalVariances<'tcx> {
+    fn cx(&self) -> TyCtxt<'tcx> {
+        self.tcx
+    }
+
+    fn relate_with_variance<T: ty::relate::Relate<TyCtxt<'tcx>>>(
+        &mut self,
+        variance: rustc_type_ir::Variance,
+        _: ty::VarianceDiagInfo<TyCtxt<'tcx>>,
+        a: T,
+        b: T,
+    ) -> RelateResult<'tcx, T> {
+        let old_variance = self.ambient_variance;
+        self.ambient_variance = self.ambient_variance.xform(variance);
+        self.relate(a, b)?;
+        self.ambient_variance = old_variance;
+        Ok(a)
+    }
+
+    fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
+        structurally_relate_tys(self, a, b)?;
+        Ok(a)
+    }
+
+    fn regions(
+        &mut self,
+        a: ty::Region<'tcx>,
+        _: ty::Region<'tcx>,
+    ) -> RelateResult<'tcx, ty::Region<'tcx>> {
+        let def_id = match *a {
+            ty::ReEarlyParam(ebr) => self.generics.region_param(ebr, self.tcx).def_id,
+            ty::ReBound(
+                _,
+                ty::BoundRegion { kind: ty::BoundRegionKind::BrNamed(def_id, ..), .. },
+            )
+            | ty::ReLateParam(ty::LateParamRegion {
+                scope: _,
+                bound_region: ty::BoundRegionKind::BrNamed(def_id, ..),
+            }) => def_id,
+            _ => {
+                return Ok(a);
+            }
+        };
+
+        if let Some(variance) = self.variances.get_mut(&def_id) {
+            *variance = unify(*variance, self.ambient_variance);
+        } else {
+            self.variances.insert(def_id, self.ambient_variance);
+        }
+
+        Ok(a)
+    }
+
+    fn consts(
+        &mut self,
+        a: ty::Const<'tcx>,
+        b: ty::Const<'tcx>,
+    ) -> RelateResult<'tcx, ty::Const<'tcx>> {
+        structurally_relate_consts(self, a, b)?;
+        Ok(a)
+    }
+
+    fn binders<T>(
+        &mut self,
+        a: ty::Binder<'tcx, T>,
+        b: ty::Binder<'tcx, T>,
+    ) -> RelateResult<'tcx, ty::Binder<'tcx, T>>
+    where
+        T: Relate<TyCtxt<'tcx>>,
+    {
+        self.relate(a.skip_binder(), b.skip_binder())?;
+        Ok(a)
+    }
+}
+
+/// What is the variance that satisfies the two variances?
+fn unify(a: ty::Variance, b: ty::Variance) -> ty::Variance {
+    match (a, b) {
+        // Bivariance is lattice bottom.
+        (ty::Bivariant, other) | (other, ty::Bivariant) => other,
+        // Invariant is lattice top.
+        (ty::Invariant, _) | (_, ty::Invariant) => ty::Invariant,
+        // If type is required to be covariant and contravariant, then it's invariant.
+        (ty::Contravariant, ty::Covariant) | (ty::Covariant, ty::Contravariant) => ty::Invariant,
+        // Otherwise, co + co = co, contra + contra = contra.
+        (ty::Contravariant, ty::Contravariant) => ty::Contravariant,
+        (ty::Covariant, ty::Covariant) => ty::Covariant,
+    }
+}
diff --git a/tests/ui/impl-trait/precise-capturing/overcaptures-2024-but-fine.rs b/tests/ui/impl-trait/precise-capturing/overcaptures-2024-but-fine.rs
new file mode 100644
index 00000000000..e30f785b0ae
--- /dev/null
+++ b/tests/ui/impl-trait/precise-capturing/overcaptures-2024-but-fine.rs
@@ -0,0 +1,15 @@
+//@ check-pass
+
+#![deny(impl_trait_overcaptures)]
+
+struct Ctxt<'tcx>(&'tcx ());
+
+// In `compute`, we don't care that we're "overcapturing" `'tcx`
+// in edition 2024, because it can be shortened at the call site
+// and we know it outlives `'_`.
+
+impl<'tcx> Ctxt<'tcx> {
+    fn compute(&self) -> impl Sized + '_ {}
+}
+
+fn main() {}