about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2024-02-14 07:27:11 +0000
committerbors <bors@rust-lang.org>2024-02-14 07:27:11 +0000
commitbb89df6903539e7014b8db29bccd6a9ee9553122 (patch)
tree1898591699565938209b492bc21d9151654bf31d
parentcc1c0990ab6f7136d1d54fd008598095b2c53c66 (diff)
parent7320623f3a646bcd6eba661d14cd6f3891ede495 (diff)
downloadrust-bb89df6903539e7014b8db29bccd6a9ee9553122.tar.gz
rust-bb89df6903539e7014b8db29bccd6a9ee9553122.zip
Auto merge of #121018 - oli-obk:impl_unsafety, r=TaKO8Ki
Fully stop using the HIR in trait impl checks

At least I hope I found all happy path usages. I'll need to check if I can figure out a way to make queries declare that they don't access the HIR except in error paths
-rw-r--r--compiler/rustc_hir_analysis/src/coherence/builtin.rs78
-rw-r--r--compiler/rustc_hir_analysis/src/coherence/mod.rs25
-rw-r--r--compiler/rustc_hir_analysis/src/coherence/unsafety.rs44
-rw-r--r--compiler/rustc_hir_analysis/src/collect.rs1
-rw-r--r--compiler/rustc_middle/src/ty/mod.rs1
5 files changed, 80 insertions, 69 deletions
diff --git a/compiler/rustc_hir_analysis/src/coherence/builtin.rs b/compiler/rustc_hir_analysis/src/coherence/builtin.rs
index 370c6c607d7..6c3a9b747ef 100644
--- a/compiler/rustc_hir_analysis/src/coherence/builtin.rs
+++ b/compiler/rustc_hir_analysis/src/coherence/builtin.rs
@@ -15,7 +15,7 @@ use rustc_infer::infer::{DefineOpaqueTypes, TyCtxtInferExt};
 use rustc_infer::traits::Obligation;
 use rustc_middle::ty::adjustment::CoerceUnsizedInfo;
 use rustc_middle::ty::{self, suggest_constraining_type_params, Ty, TyCtxt, TypeVisitableExt};
-use rustc_span::Span;
+use rustc_span::{Span, DUMMY_SP};
 use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt;
 use rustc_trait_selection::traits::misc::{
     type_allowed_to_implement_const_param_ty, type_allowed_to_implement_copy,
@@ -25,13 +25,14 @@ use rustc_trait_selection::traits::ObligationCtxt;
 use rustc_trait_selection::traits::{self, ObligationCause};
 use std::collections::BTreeMap;
 
-pub fn check_trait(
-    tcx: TyCtxt<'_>,
+pub fn check_trait<'tcx>(
+    tcx: TyCtxt<'tcx>,
     trait_def_id: DefId,
     impl_def_id: LocalDefId,
+    impl_header: ty::ImplTraitHeader<'tcx>,
 ) -> Result<(), ErrorGuaranteed> {
     let lang_items = tcx.lang_items();
-    let checker = Checker { tcx, trait_def_id, impl_def_id };
+    let checker = Checker { tcx, trait_def_id, impl_def_id, impl_header };
     let mut res = checker.check(lang_items.drop_trait(), visit_implementation_of_drop);
     res = res.and(checker.check(lang_items.copy_trait(), visit_implementation_of_copy));
     res = res.and(
@@ -50,24 +51,25 @@ struct Checker<'tcx> {
     tcx: TyCtxt<'tcx>,
     trait_def_id: DefId,
     impl_def_id: LocalDefId,
+    impl_header: ty::ImplTraitHeader<'tcx>,
 }
 
 impl<'tcx> Checker<'tcx> {
     fn check(
         &self,
         trait_def_id: Option<DefId>,
-        f: impl FnOnce(TyCtxt<'tcx>, LocalDefId) -> Result<(), ErrorGuaranteed>,
+        f: impl FnOnce(&Self) -> Result<(), ErrorGuaranteed>,
     ) -> Result<(), ErrorGuaranteed> {
-        if Some(self.trait_def_id) == trait_def_id { f(self.tcx, self.impl_def_id) } else { Ok(()) }
+        if Some(self.trait_def_id) == trait_def_id { f(self) } else { Ok(()) }
     }
 }
 
-fn visit_implementation_of_drop(
-    tcx: TyCtxt<'_>,
-    impl_did: LocalDefId,
-) -> Result<(), ErrorGuaranteed> {
+fn visit_implementation_of_drop(checker: &Checker<'_>) -> Result<(), ErrorGuaranteed> {
+    let tcx = checker.tcx;
+    let header = checker.impl_header;
+    let impl_did = checker.impl_def_id;
     // Destructors only work on local ADT types.
-    match tcx.type_of(impl_did).instantiate_identity().kind() {
+    match header.trait_ref.self_ty().kind() {
         ty::Adt(def, _) if def.did().is_local() => return Ok(()),
         ty::Error(_) => return Ok(()),
         _ => {}
@@ -78,13 +80,13 @@ fn visit_implementation_of_drop(
     Err(tcx.dcx().emit_err(errors::DropImplOnWrongItem { span: impl_.self_ty.span }))
 }
 
-fn visit_implementation_of_copy(
-    tcx: TyCtxt<'_>,
-    impl_did: LocalDefId,
-) -> Result<(), ErrorGuaranteed> {
+fn visit_implementation_of_copy(checker: &Checker<'_>) -> Result<(), ErrorGuaranteed> {
+    let tcx = checker.tcx;
+    let impl_header = checker.impl_header;
+    let impl_did = checker.impl_def_id;
     debug!("visit_implementation_of_copy: impl_did={:?}", impl_did);
 
-    let self_type = tcx.type_of(impl_did).instantiate_identity();
+    let self_type = impl_header.trait_ref.self_ty();
     debug!("visit_implementation_of_copy: self_type={:?} (bound)", self_type);
 
     let param_env = tcx.param_env(impl_did);
@@ -92,56 +94,58 @@ fn visit_implementation_of_copy(
 
     debug!("visit_implementation_of_copy: self_type={:?} (free)", self_type);
 
-    if let ty::ImplPolarity::Negative = tcx.impl_polarity(impl_did) {
+    if let ty::ImplPolarity::Negative = impl_header.polarity {
         return Ok(());
     }
-    let span = tcx.hir().expect_item(impl_did).expect_impl().self_ty.span;
 
-    let cause = traits::ObligationCause::misc(span, impl_did);
+    let cause = traits::ObligationCause::misc(DUMMY_SP, impl_did);
     match type_allowed_to_implement_copy(tcx, param_env, self_type, cause) {
         Ok(()) => Ok(()),
         Err(CopyImplementationError::InfringingFields(fields)) => {
+            let span = tcx.hir().expect_item(impl_did).expect_impl().self_ty.span;
             Err(infringing_fields_error(tcx, fields, LangItem::Copy, impl_did, span))
         }
         Err(CopyImplementationError::NotAnAdt) => {
+            let span = tcx.hir().expect_item(impl_did).expect_impl().self_ty.span;
             Err(tcx.dcx().emit_err(errors::CopyImplOnNonAdt { span }))
         }
         Err(CopyImplementationError::HasDestructor) => {
+            let span = tcx.hir().expect_item(impl_did).expect_impl().self_ty.span;
             Err(tcx.dcx().emit_err(errors::CopyImplOnTypeWithDtor { span }))
         }
     }
 }
 
-fn visit_implementation_of_const_param_ty(
-    tcx: TyCtxt<'_>,
-    impl_did: LocalDefId,
-) -> Result<(), ErrorGuaranteed> {
-    let self_type = tcx.type_of(impl_did).instantiate_identity();
+fn visit_implementation_of_const_param_ty(checker: &Checker<'_>) -> Result<(), ErrorGuaranteed> {
+    let tcx = checker.tcx;
+    let header = checker.impl_header;
+    let impl_did = checker.impl_def_id;
+    let self_type = header.trait_ref.self_ty();
     assert!(!self_type.has_escaping_bound_vars());
 
     let param_env = tcx.param_env(impl_did);
 
-    if let ty::ImplPolarity::Negative = tcx.impl_polarity(impl_did) {
+    if let ty::ImplPolarity::Negative = header.polarity {
         return Ok(());
     }
-    let span = tcx.hir().expect_item(impl_did).expect_impl().self_ty.span;
 
-    let cause = traits::ObligationCause::misc(span, impl_did);
+    let cause = traits::ObligationCause::misc(DUMMY_SP, impl_did);
     match type_allowed_to_implement_const_param_ty(tcx, param_env, self_type, cause) {
         Ok(()) => Ok(()),
         Err(ConstParamTyImplementationError::InfrigingFields(fields)) => {
+            let span = tcx.hir().expect_item(impl_did).expect_impl().self_ty.span;
             Err(infringing_fields_error(tcx, fields, LangItem::ConstParamTy, impl_did, span))
         }
         Err(ConstParamTyImplementationError::NotAnAdtOrBuiltinAllowed) => {
+            let span = tcx.hir().expect_item(impl_did).expect_impl().self_ty.span;
             Err(tcx.dcx().emit_err(errors::ConstParamTyImplOnNonAdt { span }))
         }
     }
 }
 
-fn visit_implementation_of_coerce_unsized(
-    tcx: TyCtxt<'_>,
-    impl_did: LocalDefId,
-) -> Result<(), ErrorGuaranteed> {
+fn visit_implementation_of_coerce_unsized(checker: &Checker<'_>) -> Result<(), ErrorGuaranteed> {
+    let tcx = checker.tcx;
+    let impl_did = checker.impl_def_id;
     debug!("visit_implementation_of_coerce_unsized: impl_did={:?}", impl_did);
 
     // Just compute this for the side-effects, in particular reporting
@@ -151,20 +155,20 @@ fn visit_implementation_of_coerce_unsized(
     tcx.at(span).ensure().coerce_unsized_info(impl_did)
 }
 
-fn visit_implementation_of_dispatch_from_dyn(
-    tcx: TyCtxt<'_>,
-    impl_did: LocalDefId,
-) -> Result<(), ErrorGuaranteed> {
+fn visit_implementation_of_dispatch_from_dyn(checker: &Checker<'_>) -> Result<(), ErrorGuaranteed> {
+    let tcx = checker.tcx;
+    let header = checker.impl_header;
+    let impl_did = checker.impl_def_id;
+    let trait_ref = header.trait_ref;
     debug!("visit_implementation_of_dispatch_from_dyn: impl_did={:?}", impl_did);
 
     let span = tcx.def_span(impl_did);
 
     let dispatch_from_dyn_trait = tcx.require_lang_item(LangItem::DispatchFromDyn, Some(span));
 
-    let source = tcx.type_of(impl_did).instantiate_identity();
+    let source = trait_ref.self_ty();
     assert!(!source.has_escaping_bound_vars());
     let target = {
-        let trait_ref = tcx.impl_trait_ref(impl_did).unwrap().instantiate_identity();
         assert_eq!(trait_ref.def_id, dispatch_from_dyn_trait);
 
         trait_ref.args.type_at(1)
diff --git a/compiler/rustc_hir_analysis/src/coherence/mod.rs b/compiler/rustc_hir_analysis/src/coherence/mod.rs
index 7f59763f2a0..d6281fa08f7 100644
--- a/compiler/rustc_hir_analysis/src/coherence/mod.rs
+++ b/compiler/rustc_hir_analysis/src/coherence/mod.rs
@@ -23,6 +23,7 @@ fn check_impl(
     tcx: TyCtxt<'_>,
     impl_def_id: LocalDefId,
     trait_ref: ty::TraitRef<'_>,
+    trait_def: &ty::TraitDef,
 ) -> Result<(), ErrorGuaranteed> {
     debug!(
         "(checking implementation) adding impl for trait '{:?}', item '{}'",
@@ -36,19 +37,20 @@ fn check_impl(
         return Ok(());
     }
 
-    enforce_trait_manually_implementable(tcx, impl_def_id, trait_ref.def_id)
-        .and(enforce_empty_impls_for_marker_traits(tcx, impl_def_id, trait_ref.def_id))
+    enforce_trait_manually_implementable(tcx, impl_def_id, trait_ref.def_id, trait_def)
+        .and(enforce_empty_impls_for_marker_traits(tcx, impl_def_id, trait_ref.def_id, trait_def))
 }
 
 fn enforce_trait_manually_implementable(
     tcx: TyCtxt<'_>,
     impl_def_id: LocalDefId,
     trait_def_id: DefId,
+    trait_def: &ty::TraitDef,
 ) -> Result<(), ErrorGuaranteed> {
     let impl_header_span = tcx.def_span(impl_def_id);
 
     // Disallow *all* explicit impls of traits marked `#[rustc_deny_explicit_impl]`
-    if tcx.trait_def(trait_def_id).deny_explicit_impl {
+    if trait_def.deny_explicit_impl {
         let trait_name = tcx.item_name(trait_def_id);
         let mut err = struct_span_code_err!(
             tcx.dcx(),
@@ -67,8 +69,7 @@ fn enforce_trait_manually_implementable(
         return Err(err.emit());
     }
 
-    if let ty::trait_def::TraitSpecializationKind::AlwaysApplicable =
-        tcx.trait_def(trait_def_id).specialization_kind
+    if let ty::trait_def::TraitSpecializationKind::AlwaysApplicable = trait_def.specialization_kind
     {
         if !tcx.features().specialization
             && !tcx.features().min_specialization
@@ -87,8 +88,9 @@ fn enforce_empty_impls_for_marker_traits(
     tcx: TyCtxt<'_>,
     impl_def_id: LocalDefId,
     trait_def_id: DefId,
+    trait_def: &ty::TraitDef,
 ) -> Result<(), ErrorGuaranteed> {
-    if !tcx.trait_def(trait_def_id).is_marker {
+    if !trait_def.is_marker {
         return Ok(());
     }
 
@@ -132,14 +134,15 @@ fn coherent_trait(tcx: TyCtxt<'_>, def_id: DefId) -> Result<(), ErrorGuaranteed>
     let mut res = tcx.ensure().specialization_graph_of(def_id);
 
     for &impl_def_id in impls {
-        let trait_ref = tcx.impl_trait_ref(impl_def_id).unwrap().instantiate_identity();
+        let trait_header = tcx.impl_trait_header(impl_def_id).unwrap().instantiate_identity();
+        let trait_def = tcx.trait_def(trait_header.trait_ref.def_id);
 
-        res = res.and(check_impl(tcx, impl_def_id, trait_ref));
-        res = res.and(check_object_overlap(tcx, impl_def_id, trait_ref));
+        res = res.and(check_impl(tcx, impl_def_id, trait_header.trait_ref, trait_def));
+        res = res.and(check_object_overlap(tcx, impl_def_id, trait_header.trait_ref));
 
-        res = res.and(unsafety::check_item(tcx, impl_def_id, trait_ref));
+        res = res.and(unsafety::check_item(tcx, impl_def_id, trait_header, trait_def));
         res = res.and(tcx.ensure().orphan_check_impl(impl_def_id));
-        res = res.and(builtin::check_trait(tcx, def_id, impl_def_id));
+        res = res.and(builtin::check_trait(tcx, def_id, impl_def_id, trait_header));
     }
 
     res
diff --git a/compiler/rustc_hir_analysis/src/coherence/unsafety.rs b/compiler/rustc_hir_analysis/src/coherence/unsafety.rs
index d217d53587d..688760a3912 100644
--- a/compiler/rustc_hir_analysis/src/coherence/unsafety.rs
+++ b/compiler/rustc_hir_analysis/src/coherence/unsafety.rs
@@ -2,23 +2,23 @@
 //! crate or pertains to a type defined in this crate.
 
 use rustc_errors::{codes::*, struct_span_code_err};
-use rustc_hir as hir;
 use rustc_hir::Unsafety;
-use rustc_middle::ty::{TraitRef, TyCtxt};
+use rustc_middle::ty::{ImplPolarity::*, ImplTraitHeader, TraitDef, TyCtxt};
 use rustc_span::def_id::LocalDefId;
 use rustc_span::ErrorGuaranteed;
 
 pub(super) fn check_item(
     tcx: TyCtxt<'_>,
     def_id: LocalDefId,
-    trait_ref: TraitRef<'_>,
+    trait_header: ImplTraitHeader<'_>,
+    trait_def: &TraitDef,
 ) -> Result<(), ErrorGuaranteed> {
-    let item = tcx.hir().expect_item(def_id);
-    let impl_ = item.expect_impl();
-    let trait_def = tcx.trait_def(trait_ref.def_id);
-    let unsafe_attr = impl_.generics.params.iter().find(|p| p.pure_wrt_drop).map(|_| "may_dangle");
-    match (trait_def.unsafety, unsafe_attr, impl_.unsafety, impl_.polarity) {
-        (Unsafety::Normal, None, Unsafety::Unsafe, hir::ImplPolarity::Positive) => {
+    let trait_ref = trait_header.trait_ref;
+    let unsafe_attr =
+        tcx.generics_of(def_id).params.iter().find(|p| p.pure_wrt_drop).map(|_| "may_dangle");
+    match (trait_def.unsafety, unsafe_attr, trait_header.unsafety, trait_header.polarity) {
+        (Unsafety::Normal, None, Unsafety::Unsafe, Positive | Reservation) => {
+            let span = tcx.def_span(def_id);
             return Err(struct_span_code_err!(
                 tcx.dcx(),
                 tcx.def_span(def_id),
@@ -27,7 +27,7 @@ pub(super) fn check_item(
                 trait_ref.print_trait_sugared()
             )
             .with_span_suggestion_verbose(
-                item.span.with_hi(item.span.lo() + rustc_span::BytePos(7)),
+                span.with_hi(span.lo() + rustc_span::BytePos(7)),
                 "remove `unsafe` from this trait implementation",
                 "",
                 rustc_errors::Applicability::MachineApplicable,
@@ -35,10 +35,11 @@ pub(super) fn check_item(
             .emit());
         }
 
-        (Unsafety::Unsafe, _, Unsafety::Normal, hir::ImplPolarity::Positive) => {
+        (Unsafety::Unsafe, _, Unsafety::Normal, Positive | Reservation) => {
+            let span = tcx.def_span(def_id);
             return Err(struct_span_code_err!(
                 tcx.dcx(),
-                tcx.def_span(def_id),
+                span,
                 E0200,
                 "the trait `{}` requires an `unsafe impl` declaration",
                 trait_ref.print_trait_sugared()
@@ -50,7 +51,7 @@ pub(super) fn check_item(
                 trait_ref.print_trait_sugared()
             ))
             .with_span_suggestion_verbose(
-                item.span.shrink_to_lo(),
+                span.shrink_to_lo(),
                 "add `unsafe` to this trait implementation",
                 "unsafe ",
                 rustc_errors::Applicability::MaybeIncorrect,
@@ -58,10 +59,11 @@ pub(super) fn check_item(
             .emit());
         }
 
-        (Unsafety::Normal, Some(attr_name), Unsafety::Normal, hir::ImplPolarity::Positive) => {
+        (Unsafety::Normal, Some(attr_name), Unsafety::Normal, Positive | Reservation) => {
+            let span = tcx.def_span(def_id);
             return Err(struct_span_code_err!(
                 tcx.dcx(),
-                tcx.def_span(def_id),
+                span,
                 E0569,
                 "requires an `unsafe impl` declaration due to `#[{}]` attribute",
                 attr_name
@@ -73,7 +75,7 @@ pub(super) fn check_item(
                 trait_ref.print_trait_sugared()
             ))
             .with_span_suggestion_verbose(
-                item.span.shrink_to_lo(),
+                span.shrink_to_lo(),
                 "add `unsafe` to this trait implementation",
                 "unsafe ",
                 rustc_errors::Applicability::MaybeIncorrect,
@@ -81,14 +83,14 @@ pub(super) fn check_item(
             .emit());
         }
 
-        (_, _, Unsafety::Unsafe, hir::ImplPolarity::Negative(_)) => {
+        (_, _, Unsafety::Unsafe, Negative) => {
             // Reported in AST validation
-            tcx.dcx().span_delayed_bug(item.span, "unsafe negative impl");
+            tcx.dcx().span_delayed_bug(tcx.def_span(def_id), "unsafe negative impl");
             Ok(())
         }
-        (_, _, Unsafety::Normal, hir::ImplPolarity::Negative(_))
-        | (Unsafety::Unsafe, _, Unsafety::Unsafe, hir::ImplPolarity::Positive)
-        | (Unsafety::Normal, Some(_), Unsafety::Unsafe, hir::ImplPolarity::Positive)
+        (_, _, Unsafety::Normal, Negative)
+        | (Unsafety::Unsafe, _, Unsafety::Unsafe, Positive | Reservation)
+        | (Unsafety::Normal, Some(_), Unsafety::Unsafe, Positive | Reservation)
         | (Unsafety::Normal, None, Unsafety::Normal, _) => Ok(()),
     }
 }
diff --git a/compiler/rustc_hir_analysis/src/collect.rs b/compiler/rustc_hir_analysis/src/collect.rs
index e8787d159ae..4891dae47c6 100644
--- a/compiler/rustc_hir_analysis/src/collect.rs
+++ b/compiler/rustc_hir_analysis/src/collect.rs
@@ -1539,6 +1539,7 @@ fn impl_trait_header(
             };
             ty::EarlyBinder::bind(ty::ImplTraitHeader {
                 trait_ref,
+                unsafety: impl_.unsafety,
                 polarity: polarity_of_impl(tcx, def_id,  impl_, item.span)
             })
         })
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index 6ee74ef2fb6..2d6c6cfbcd1 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -252,6 +252,7 @@ pub struct ImplHeader<'tcx> {
 pub struct ImplTraitHeader<'tcx> {
     pub trait_ref: ty::TraitRef<'tcx>,
     pub polarity: ImplPolarity,
+    pub unsafety: hir::Unsafety,
 }
 
 #[derive(Copy, Clone, PartialEq, Eq, Debug, TypeFoldable, TypeVisitable)]