about summary refs log tree commit diff
path: root/compiler/rustc_trait_selection/src/traits/specialize/mod.rs
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-11-21 20:22:08 +0000
committerMichael Goulet <michael@errs.io>2024-12-02 22:12:08 +0000
commite91fc1bc0c05da68d218a01d550c6d12297f5703 (patch)
tree767916cd19d3874f62efced53a22aafe2fe7afe0 /compiler/rustc_trait_selection/src/traits/specialize/mod.rs
parentd49be02cf6d2e2a01264fcdef1e20c826710c0f5 (diff)
downloadrust-e91fc1bc0c05da68d218a01d550c6d12297f5703.tar.gz
rust-e91fc1bc0c05da68d218a01d550c6d12297f5703.zip
Reimplement specialization for const traits
Diffstat (limited to 'compiler/rustc_trait_selection/src/traits/specialize/mod.rs')
-rw-r--r--compiler/rustc_trait_selection/src/traits/specialize/mod.rs130
1 files changed, 112 insertions, 18 deletions
diff --git a/compiler/rustc_trait_selection/src/traits/specialize/mod.rs b/compiler/rustc_trait_selection/src/traits/specialize/mod.rs
index a9cd705465e..91a0599a3be 100644
--- a/compiler/rustc_trait_selection/src/traits/specialize/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/specialize/mod.rs
@@ -15,6 +15,7 @@ use rustc_data_structures::fx::FxIndexSet;
 use rustc_errors::codes::*;
 use rustc_errors::{Diag, EmissionGuarantee};
 use rustc_hir::def_id::{DefId, LocalDefId};
+use rustc_infer::traits::Obligation;
 use rustc_middle::bug;
 use rustc_middle::query::LocalCrate;
 use rustc_middle::ty::print::PrintTraitRefExt as _;
@@ -230,15 +231,18 @@ pub(super) fn specialization_enabled_in(tcx: TyCtxt<'_>, _: LocalCrate) -> bool
 /// `impl1` specializes `impl2` if it applies to a subset of the types `impl2` applies
 /// to.
 #[instrument(skip(tcx), level = "debug")]
-pub(super) fn specializes(tcx: TyCtxt<'_>, (impl1_def_id, impl2_def_id): (DefId, DefId)) -> bool {
+pub(super) fn specializes(
+    tcx: TyCtxt<'_>,
+    (specializing_impl_def_id, parent_impl_def_id): (DefId, DefId),
+) -> bool {
     // We check that the specializing impl comes from a crate that has specialization enabled,
     // or if the specializing impl is marked with `allow_internal_unstable`.
     //
     // We don't really care if the specialized impl (the parent) is in a crate that has
     // specialization enabled, since it's not being specialized, and it's already been checked
     // for coherence.
-    if !tcx.specialization_enabled_in(impl1_def_id.krate) {
-        let span = tcx.def_span(impl1_def_id);
+    if !tcx.specialization_enabled_in(specializing_impl_def_id.krate) {
+        let span = tcx.def_span(specializing_impl_def_id);
         if !span.allows_unstable(sym::specialization)
             && !span.allows_unstable(sym::min_specialization)
         {
@@ -246,7 +250,7 @@ pub(super) fn specializes(tcx: TyCtxt<'_>, (impl1_def_id, impl2_def_id): (DefId,
         }
     }
 
-    let impl1_trait_header = tcx.impl_trait_header(impl1_def_id).unwrap();
+    let specializing_impl_trait_header = tcx.impl_trait_header(specializing_impl_def_id).unwrap();
 
     // We determine whether there's a subset relationship by:
     //
@@ -261,27 +265,117 @@ pub(super) fn specializes(tcx: TyCtxt<'_>, (impl1_def_id, impl2_def_id): (DefId,
     // See RFC 1210 for more details and justification.
 
     // Currently we do not allow e.g., a negative impl to specialize a positive one
-    if impl1_trait_header.polarity != tcx.impl_polarity(impl2_def_id) {
+    if specializing_impl_trait_header.polarity != tcx.impl_polarity(parent_impl_def_id) {
         return false;
     }
 
-    // create a parameter environment corresponding to an identity instantiation of impl1,
-    // i.e. the most generic instantiation of impl1.
-    let param_env = tcx.param_env(impl1_def_id);
+    // create a parameter environment corresponding to an identity instantiation of the specializing impl,
+    // i.e. the most generic instantiation of the specializing impl.
+    let param_env = tcx.param_env(specializing_impl_def_id);
 
-    // Create an infcx, taking the predicates of impl1 as assumptions:
+    // Create an infcx, taking the predicates of the specializing impl as assumptions:
     let infcx = tcx.infer_ctxt().build(TypingMode::non_body_analysis());
 
-    // Attempt to prove that impl2 applies, given all of the above.
-    fulfill_implication(
-        &infcx,
+    let specializing_impl_trait_ref =
+        specializing_impl_trait_header.trait_ref.instantiate_identity();
+    let cause = &ObligationCause::dummy();
+    debug!(
+        "fulfill_implication({:?}, trait_ref={:?} |- {:?} applies)",
+        param_env, specializing_impl_trait_ref, parent_impl_def_id
+    );
+
+    // Attempt to prove that the parent impl applies, given all of the above.
+
+    let ocx = ObligationCtxt::new(&infcx);
+    let specializing_impl_trait_ref = ocx.normalize(cause, param_env, specializing_impl_trait_ref);
+
+    if !ocx.select_all_or_error().is_empty() {
+        infcx.dcx().span_delayed_bug(
+            infcx.tcx.def_span(specializing_impl_def_id),
+            format!("failed to fully normalize {specializing_impl_trait_ref}"),
+        );
+        return false;
+    }
+
+    let parent_args = infcx.fresh_args_for_item(DUMMY_SP, parent_impl_def_id);
+    let parent_impl_trait_ref = ocx.normalize(
+        cause,
         param_env,
-        impl1_trait_header.trait_ref.instantiate_identity(),
-        impl1_def_id,
-        impl2_def_id,
-        &ObligationCause::dummy(),
-    )
-    .is_ok()
+        infcx
+            .tcx
+            .impl_trait_ref(parent_impl_def_id)
+            .expect("expected source impl to be a trait impl")
+            .instantiate(infcx.tcx, parent_args),
+    );
+
+    // do the impls unify? If not, no specialization.
+    let Ok(()) = ocx.eq(cause, param_env, specializing_impl_trait_ref, parent_impl_trait_ref)
+    else {
+        return false;
+    };
+
+    // Now check that the source trait ref satisfies all the where clauses of the target impl.
+    // This is not just for correctness; we also need this to constrain any params that may
+    // only be referenced via projection predicates.
+    let predicates = ocx.normalize(
+        cause,
+        param_env,
+        infcx.tcx.predicates_of(parent_impl_def_id).instantiate(infcx.tcx, parent_args),
+    );
+    let obligations = predicates_for_generics(|_, _| cause.clone(), param_env, predicates);
+    ocx.register_obligations(obligations);
+
+    let errors = ocx.select_all_or_error();
+    if !errors.is_empty() {
+        // no dice!
+        debug!(
+            "fulfill_implication: for impls on {:?} and {:?}, \
+                 could not fulfill: {:?} given {:?}",
+            specializing_impl_trait_ref,
+            parent_impl_trait_ref,
+            errors,
+            param_env.caller_bounds()
+        );
+        return false;
+    }
+
+    // If the parent impl is const, then the specializing impl must be const.
+    if tcx.is_conditionally_const(parent_impl_def_id) {
+        let const_conditions = ocx.normalize(
+            cause,
+            param_env,
+            infcx.tcx.const_conditions(parent_impl_def_id).instantiate(infcx.tcx, parent_args),
+        );
+        ocx.register_obligations(const_conditions.into_iter().map(|(trait_ref, _)| {
+            Obligation::new(
+                infcx.tcx,
+                cause.clone(),
+                param_env,
+                trait_ref.to_host_effect_clause(infcx.tcx, ty::BoundConstness::Maybe),
+            )
+        }));
+
+        let errors = ocx.select_all_or_error();
+        if !errors.is_empty() {
+            // no dice!
+            debug!(
+                "fulfill_implication: for impls on {:?} and {:?}, \
+                 could not fulfill: {:?} given {:?}",
+                specializing_impl_trait_ref,
+                parent_impl_trait_ref,
+                errors,
+                param_env.caller_bounds()
+            );
+            return false;
+        }
+    }
+
+    debug!(
+        "fulfill_implication: an impl for {:?} specializes {:?}",
+        specializing_impl_trait_ref, parent_impl_trait_ref
+    );
+
+    true
 }
 
 /// Query provider for `specialization_graph_of`.