about summary refs log tree commit diff
path: root/compiler/rustc_middle/src
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-11-16 20:18:13 +0000
committerMichael Goulet <michael@errs.io>2024-12-10 16:52:20 +0000
commita7fa4cbcb498b80b126a954b5944f19a11e28dec (patch)
tree462d7ced92d08e35a6cae86caeb804c8dbefb635 /compiler/rustc_middle/src
parent3b057796264e637b31c28809322028a32d22ae6f (diff)
downloadrust-a7fa4cbcb498b80b126a954b5944f19a11e28dec.tar.gz
rust-a7fa4cbcb498b80b126a954b5944f19a11e28dec.zip
Implement projection and shim for AFIDT
Diffstat (limited to 'compiler/rustc_middle/src')
-rw-r--r--compiler/rustc_middle/src/ty/instance.rs37
-rw-r--r--compiler/rustc_middle/src/ty/mod.rs1
-rw-r--r--compiler/rustc_middle/src/ty/return_position_impl_trait_in_trait.rs95
3 files changed, 116 insertions, 17 deletions
diff --git a/compiler/rustc_middle/src/ty/instance.rs b/compiler/rustc_middle/src/ty/instance.rs
index 65c909e70f6..1dd564d9798 100644
--- a/compiler/rustc_middle/src/ty/instance.rs
+++ b/compiler/rustc_middle/src/ty/instance.rs
@@ -677,23 +677,26 @@ impl<'tcx> Instance<'tcx> {
                 //
                 // 1) The underlying method expects a caller location parameter
                 // in the ABI
-                if resolved.def.requires_caller_location(tcx)
-                        // 2) The caller location parameter comes from having `#[track_caller]`
-                        // on the implementation, and *not* on the trait method.
-                        && !tcx.should_inherit_track_caller(def)
-                        // If the method implementation comes from the trait definition itself
-                        // (e.g. `trait Foo { #[track_caller] my_fn() { /* impl */ } }`),
-                        // then we don't need to generate a shim. This check is needed because
-                        // `should_inherit_track_caller` returns `false` if our method
-                        // implementation comes from the trait block, and not an impl block
-                        && !matches!(
-                            tcx.opt_associated_item(def),
-                            Some(ty::AssocItem {
-                                container: ty::AssocItemContainer::Trait,
-                                ..
-                            })
-                        )
-                {
+                let needs_track_caller_shim = resolved.def.requires_caller_location(tcx)
+                    // 2) The caller location parameter comes from having `#[track_caller]`
+                    // on the implementation, and *not* on the trait method.
+                    && !tcx.should_inherit_track_caller(def)
+                    // If the method implementation comes from the trait definition itself
+                    // (e.g. `trait Foo { #[track_caller] my_fn() { /* impl */ } }`),
+                    // then we don't need to generate a shim. This check is needed because
+                    // `should_inherit_track_caller` returns `false` if our method
+                    // implementation comes from the trait block, and not an impl block
+                    && !matches!(
+                        tcx.opt_associated_item(def),
+                        Some(ty::AssocItem {
+                            container: ty::AssocItemContainer::Trait,
+                            ..
+                        })
+                    );
+                // We also need to generate a shim if this is an AFIT.
+                let needs_rpitit_shim =
+                    tcx.return_position_impl_trait_in_trait_shim_data(def).is_some();
+                if needs_track_caller_shim || needs_rpitit_shim {
                     if tcx.is_closure_like(def) {
                         debug!(
                             " => vtable fn pointer created for closure with #[track_caller]: {:?} for method {:?} {:?}",
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index 70e0568b202..80b11892a42 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -146,6 +146,7 @@ mod opaque_types;
 mod parameterized;
 mod predicate;
 mod region;
+mod return_position_impl_trait_in_trait;
 mod rvalue_scopes;
 mod structural_impls;
 #[allow(hidden_glob_reexports)]
diff --git a/compiler/rustc_middle/src/ty/return_position_impl_trait_in_trait.rs b/compiler/rustc_middle/src/ty/return_position_impl_trait_in_trait.rs
new file mode 100644
index 00000000000..21c605f8296
--- /dev/null
+++ b/compiler/rustc_middle/src/ty/return_position_impl_trait_in_trait.rs
@@ -0,0 +1,95 @@
+use rustc_hir::def_id::DefId;
+
+use crate::ty::{self, ExistentialPredicateStableCmpExt, TyCtxt};
+
+impl<'tcx> TyCtxt<'tcx> {
+    /// Given a `def_id` of a trait or impl method, compute whether that method needs to
+    /// have an RPITIT shim applied to it for it to be object safe. If so, return the
+    /// `def_id` of the RPITIT, and also the args of trait method that returns the RPITIT.
+    ///
+    /// NOTE that these args are not, in general, the same as than the RPITIT's args. They
+    /// are a subset of those args,  since they do not include the late-bound lifetimes of
+    /// the RPITIT. Depending on the context, these will need to be dealt with in different
+    /// ways -- in codegen, it's okay to fill them with ReErased.
+    pub fn return_position_impl_trait_in_trait_shim_data(
+        self,
+        def_id: DefId,
+    ) -> Option<(DefId, ty::EarlyBinder<'tcx, ty::GenericArgsRef<'tcx>>)> {
+        let assoc_item = self.opt_associated_item(def_id)?;
+
+        let (trait_item_def_id, opt_impl_def_id) = match assoc_item.container {
+            ty::AssocItemContainer::Impl => {
+                (assoc_item.trait_item_def_id?, Some(self.parent(def_id)))
+            }
+            ty::AssocItemContainer::Trait => (def_id, None),
+        };
+
+        let sig = self.fn_sig(trait_item_def_id);
+
+        // Check if the trait returns an RPITIT.
+        let ty::Alias(ty::Projection, ty::AliasTy { def_id, .. }) =
+            *sig.skip_binder().skip_binder().output().kind()
+        else {
+            return None;
+        };
+        if !self.is_impl_trait_in_trait(def_id) {
+            return None;
+        }
+
+        let args = if let Some(impl_def_id) = opt_impl_def_id {
+            // Rebase the args from the RPITIT onto the impl trait ref, so we can later
+            // substitute them with the method args of the *impl* method, since that's
+            // the instance we're building a vtable shim for.
+            ty::GenericArgs::identity_for_item(self, trait_item_def_id).rebase_onto(
+                self,
+                self.parent(trait_item_def_id),
+                self.impl_trait_ref(impl_def_id)
+                    .expect("expected impl trait ref from parent of impl item")
+                    .instantiate_identity()
+                    .args,
+            )
+        } else {
+            // This is when we have a default trait implementation.
+            ty::GenericArgs::identity_for_item(self, trait_item_def_id)
+        };
+
+        Some((def_id, ty::EarlyBinder::bind(args)))
+    }
+
+    /// Given a `DefId` of an RPITIT and its args, return the existential predicates
+    /// that corresponds to the RPITIT's bounds with the self type erased.
+    pub fn item_bounds_to_existential_predicates(
+        self,
+        def_id: DefId,
+        args: ty::GenericArgsRef<'tcx>,
+    ) -> &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>> {
+        let mut bounds: Vec<_> = self
+            .item_super_predicates(def_id)
+            .iter_instantiated(self, args)
+            .filter_map(|clause| {
+                clause
+                    .kind()
+                    .map_bound(|clause| match clause {
+                        ty::ClauseKind::Trait(trait_pred) => Some(ty::ExistentialPredicate::Trait(
+                            ty::ExistentialTraitRef::erase_self_ty(self, trait_pred.trait_ref),
+                        )),
+                        ty::ClauseKind::Projection(projection_pred) => {
+                            Some(ty::ExistentialPredicate::Projection(
+                                ty::ExistentialProjection::erase_self_ty(self, projection_pred),
+                            ))
+                        }
+                        ty::ClauseKind::TypeOutlives(_) => {
+                            // Type outlives bounds don't really turn into anything,
+                            // since we must use an intersection region for the `dyn*`'s
+                            // region anyways.
+                            None
+                        }
+                        _ => unreachable!("unexpected clause in item bounds: {clause:?}"),
+                    })
+                    .transpose()
+            })
+            .collect();
+        bounds.sort_by(|a, b| a.skip_binder().stable_cmp(self, &b.skip_binder()));
+        self.mk_poly_existential_predicates(&bounds)
+    }
+}