about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_hir_analysis/src/autoderef.rs76
-rw-r--r--compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs29
-rw-r--r--compiler/rustc_hir_typeck/src/writeback.rs36
-rw-r--r--compiler/rustc_trait_selection/src/traits/mod.rs2
-rw-r--r--compiler/rustc_trait_selection/src/traits/structural_normalize.rs55
-rw-r--r--tests/ui/traits/new-solver/normalize-rcvr-for-inherent.rs25
-rw-r--r--tests/ui/traits/new-solver/structural-resolve-field.rs13
7 files changed, 196 insertions, 40 deletions
diff --git a/compiler/rustc_hir_analysis/src/autoderef.rs b/compiler/rustc_hir_analysis/src/autoderef.rs
index 1cf93c86f4f..d6d1498d708 100644
--- a/compiler/rustc_hir_analysis/src/autoderef.rs
+++ b/compiler/rustc_hir_analysis/src/autoderef.rs
@@ -1,6 +1,5 @@
 use crate::errors::AutoDerefReachedRecursionLimit;
 use crate::traits::query::evaluate_obligation::InferCtxtExt;
-use crate::traits::NormalizeExt;
 use crate::traits::{self, TraitEngine, TraitEngineExt};
 use rustc_infer::infer::InferCtxt;
 use rustc_middle::ty::TypeVisitableExt;
@@ -9,6 +8,7 @@ use rustc_session::Limit;
 use rustc_span::def_id::LocalDefId;
 use rustc_span::def_id::LOCAL_CRATE;
 use rustc_span::Span;
+use rustc_trait_selection::traits::StructurallyNormalizeExt;
 
 #[derive(Copy, Clone, Debug)]
 pub enum AutoderefKind {
@@ -66,14 +66,27 @@ impl<'a, 'tcx> Iterator for Autoderef<'a, 'tcx> {
         }
 
         // Otherwise, deref if type is derefable:
-        let (kind, new_ty) =
-            if let Some(mt) = self.state.cur_ty.builtin_deref(self.include_raw_pointers) {
-                (AutoderefKind::Builtin, mt.ty)
-            } else if let Some(ty) = self.overloaded_deref_ty(self.state.cur_ty) {
-                (AutoderefKind::Overloaded, ty)
+        let (kind, new_ty) = if let Some(ty::TypeAndMut { ty, .. }) =
+            self.state.cur_ty.builtin_deref(self.include_raw_pointers)
+        {
+            debug_assert_eq!(ty, self.infcx.resolve_vars_if_possible(ty));
+            // NOTE: we may still need to normalize the built-in deref in case
+            // we have some type like `&<Ty as Trait>::Assoc`, since users of
+            // autoderef expect this type to have been structurally normalized.
+            if self.infcx.tcx.trait_solver_next()
+                && let ty::Alias(ty::Projection, _) = ty.kind()
+            {
+                let (normalized_ty, obligations) = self.structurally_normalize(ty)?;
+                self.state.obligations.extend(obligations);
+                (AutoderefKind::Builtin, normalized_ty)
             } else {
-                return None;
-            };
+                (AutoderefKind::Builtin, ty)
+            }
+        } else if let Some(ty) = self.overloaded_deref_ty(self.state.cur_ty) {
+            (AutoderefKind::Overloaded, ty)
+        } else {
+            return None;
+        };
 
         if new_ty.references_error() {
             return None;
@@ -119,14 +132,11 @@ impl<'a, 'tcx> Autoderef<'a, 'tcx> {
 
     fn overloaded_deref_ty(&mut self, ty: Ty<'tcx>) -> Option<Ty<'tcx>> {
         debug!("overloaded_deref_ty({:?})", ty);
-
         let tcx = self.infcx.tcx;
 
         // <ty as Deref>
         let trait_ref = ty::TraitRef::new(tcx, tcx.lang_items().deref_trait()?, [ty]);
-
         let cause = traits::ObligationCause::misc(self.span, self.body_id);
-
         let obligation = traits::Obligation::new(
             tcx,
             cause.clone(),
@@ -138,26 +148,48 @@ impl<'a, 'tcx> Autoderef<'a, 'tcx> {
             return None;
         }
 
-        let normalized_ty = self
+        let (normalized_ty, obligations) =
+            self.structurally_normalize(tcx.mk_projection(tcx.lang_items().deref_target()?, [ty]))?;
+        debug!("overloaded_deref_ty({:?}) = ({:?}, {:?})", ty, normalized_ty, obligations);
+        self.state.obligations.extend(obligations);
+
+        Some(self.infcx.resolve_vars_if_possible(normalized_ty))
+    }
+
+    #[instrument(level = "debug", skip(self), ret)]
+    pub fn structurally_normalize(
+        &self,
+        ty: Ty<'tcx>,
+    ) -> Option<(Ty<'tcx>, Vec<traits::PredicateObligation<'tcx>>)> {
+        let tcx = self.infcx.tcx;
+        let mut fulfill_cx = <dyn TraitEngine<'tcx>>::new_in_snapshot(tcx);
+
+        let cause = traits::ObligationCause::misc(self.span, self.body_id);
+        let normalized_ty = match self
             .infcx
             .at(&cause, self.param_env)
-            .normalize(tcx.mk_projection(tcx.lang_items().deref_target()?, trait_ref.substs));
-        let mut fulfillcx = <dyn TraitEngine<'tcx>>::new_in_snapshot(tcx);
-        let normalized_ty =
-            normalized_ty.into_value_registering_obligations(self.infcx, &mut *fulfillcx);
-        let errors = fulfillcx.select_where_possible(&self.infcx);
+            .structurally_normalize(ty, &mut *fulfill_cx)
+        {
+            Ok(normalized_ty) => normalized_ty,
+            Err(errors) => {
+                // This shouldn't happen, except for evaluate/fulfill mismatches,
+                // but that's not a reason for an ICE (`predicate_may_hold` is conservative
+                // by design).
+                debug!(?errors, "encountered errors while fulfilling");
+                return None;
+            }
+        };
+
+        let errors = fulfill_cx.select_where_possible(&self.infcx);
         if !errors.is_empty() {
             // This shouldn't happen, except for evaluate/fulfill mismatches,
             // but that's not a reason for an ICE (`predicate_may_hold` is conservative
             // by design).
-            debug!("overloaded_deref_ty: encountered errors {:?} while fulfilling", errors);
+            debug!(?errors, "encountered errors while fulfilling");
             return None;
         }
-        let obligations = fulfillcx.pending_obligations();
-        debug!("overloaded_deref_ty({:?}) = ({:?}, {:?})", ty, normalized_ty, obligations);
-        self.state.obligations.extend(obligations);
 
-        Some(self.infcx.resolve_vars_if_possible(normalized_ty))
+        Some((normalized_ty, fulfill_cx.pending_obligations()))
     }
 
     /// Returns the final type we ended up with, which may be an inference
diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
index 039316c74dd..9721e3b427d 100644
--- a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
+++ b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
@@ -35,7 +35,9 @@ use rustc_span::symbol::{kw, sym, Ident};
 use rustc_span::Span;
 use rustc_target::abi::FieldIdx;
 use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt as _;
-use rustc_trait_selection::traits::{self, NormalizeExt, ObligationCauseCode, ObligationCtxt};
+use rustc_trait_selection::traits::{
+    self, NormalizeExt, ObligationCauseCode, ObligationCtxt, StructurallyNormalizeExt,
+};
 
 use std::collections::hash_map::Entry;
 use std::slice;
@@ -1460,10 +1462,33 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
     }
 
     /// Resolves `typ` by a single level if `typ` is a type variable.
+    ///
+    /// When the new solver is enabled, this will also attempt to normalize
+    /// the type if it's a projection (note that it will not deeply normalize
+    /// projections within the type, just the outermost layer of the type).
+    ///
     /// If no resolution is possible, then an error is reported.
     /// Numeric inference variables may be left unresolved.
     pub fn structurally_resolved_type(&self, sp: Span, ty: Ty<'tcx>) -> Ty<'tcx> {
-        let ty = self.resolve_vars_with_obligations(ty);
+        let mut ty = self.resolve_vars_with_obligations(ty);
+
+        if self.tcx.trait_solver_next()
+            && let ty::Alias(ty::Projection, _) = ty.kind()
+        {
+            match self
+                .at(&self.misc(sp), self.param_env)
+                .structurally_normalize(ty, &mut **self.fulfillment_cx.borrow_mut())
+            {
+                Ok(normalized_ty) => {
+                    ty = normalized_ty;
+                },
+                Err(errors) => {
+                    let guar = self.err_ctxt().report_fulfillment_errors(&errors);
+                    return self.tcx.ty_error(guar);
+                }
+            }
+        }
+
         if !ty.is_ty_var() {
             ty
         } else {
diff --git a/compiler/rustc_hir_typeck/src/writeback.rs b/compiler/rustc_hir_typeck/src/writeback.rs
index a4c6dd4332a..2daec205cfc 100644
--- a/compiler/rustc_hir_typeck/src/writeback.rs
+++ b/compiler/rustc_hir_typeck/src/writeback.rs
@@ -9,7 +9,6 @@ use rustc_errors::ErrorGuaranteed;
 use rustc_hir as hir;
 use rustc_hir::intravisit::{self, Visitor};
 use rustc_infer::infer::error_reporting::TypeAnnotationNeeded::E0282;
-use rustc_infer::infer::InferCtxt;
 use rustc_middle::hir::place::Place as HirPlace;
 use rustc_middle::mir::FakeReadCause;
 use rustc_middle::ty::adjustment::{Adjust, Adjustment, PointerCast};
@@ -737,8 +736,7 @@ impl Locatable for hir::HirId {
 /// The Resolver. This is the type folding engine that detects
 /// unresolved types and so forth.
 struct Resolver<'cx, 'tcx> {
-    tcx: TyCtxt<'tcx>,
-    infcx: &'cx InferCtxt<'tcx>,
+    fcx: &'cx FnCtxt<'cx, 'tcx>,
     span: &'cx dyn Locatable,
     body: &'tcx hir::Body<'tcx>,
 
@@ -752,18 +750,18 @@ impl<'cx, 'tcx> Resolver<'cx, 'tcx> {
         span: &'cx dyn Locatable,
         body: &'tcx hir::Body<'tcx>,
     ) -> Resolver<'cx, 'tcx> {
-        Resolver { tcx: fcx.tcx, infcx: fcx, span, body, replaced_with_error: None }
+        Resolver { fcx, span, body, replaced_with_error: None }
     }
 
     fn report_error(&self, p: impl Into<ty::GenericArg<'tcx>>) -> ErrorGuaranteed {
-        match self.tcx.sess.has_errors() {
+        match self.fcx.tcx.sess.has_errors() {
             Some(e) => e,
             None => self
-                .infcx
+                .fcx
                 .err_ctxt()
                 .emit_inference_failure_err(
-                    self.tcx.hir().body_owner_def_id(self.body.id()),
-                    self.span.to_span(self.tcx),
+                    self.fcx.tcx.hir().body_owner_def_id(self.body.id()),
+                    self.span.to_span(self.fcx.tcx),
                     p.into(),
                     E0282,
                     false,
@@ -795,40 +793,46 @@ impl<'tcx> TypeFolder<TyCtxt<'tcx>> for EraseEarlyRegions<'tcx> {
 
 impl<'cx, 'tcx> TypeFolder<TyCtxt<'tcx>> for Resolver<'cx, 'tcx> {
     fn interner(&self) -> TyCtxt<'tcx> {
-        self.tcx
+        self.fcx.tcx
     }
 
     fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> {
-        match self.infcx.fully_resolve(t) {
+        match self.fcx.fully_resolve(t) {
+            Ok(t) if self.fcx.tcx.trait_solver_next() => {
+                // We must normalize erasing regions here, since later lints
+                // expect that types that show up in the typeck are fully
+                // normalized.
+                self.fcx.tcx.try_normalize_erasing_regions(self.fcx.param_env, t).unwrap_or(t)
+            }
             Ok(t) => {
                 // Do not anonymize late-bound regions
                 // (e.g. keep `for<'a>` named `for<'a>`).
                 // This allows NLL to generate error messages that
                 // refer to the higher-ranked lifetime names written by the user.
-                EraseEarlyRegions { tcx: self.tcx }.fold_ty(t)
+                EraseEarlyRegions { tcx: self.fcx.tcx }.fold_ty(t)
             }
             Err(_) => {
                 debug!("Resolver::fold_ty: input type `{:?}` not fully resolvable", t);
                 let e = self.report_error(t);
                 self.replaced_with_error = Some(e);
-                self.interner().ty_error(e)
+                self.fcx.tcx.ty_error(e)
             }
         }
     }
 
     fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
         debug_assert!(!r.is_late_bound(), "Should not be resolving bound region.");
-        self.tcx.lifetimes.re_erased
+        self.fcx.tcx.lifetimes.re_erased
     }
 
     fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
-        match self.infcx.fully_resolve(ct) {
-            Ok(ct) => self.tcx.erase_regions(ct),
+        match self.fcx.fully_resolve(ct) {
+            Ok(ct) => self.fcx.tcx.erase_regions(ct),
             Err(_) => {
                 debug!("Resolver::fold_const: input const `{:?}` not fully resolvable", ct);
                 let e = self.report_error(ct);
                 self.replaced_with_error = Some(e);
-                self.interner().const_error(ct.ty(), e)
+                self.fcx.tcx.const_error(ct.ty(), e)
             }
         }
     }
diff --git a/compiler/rustc_trait_selection/src/traits/mod.rs b/compiler/rustc_trait_selection/src/traits/mod.rs
index 28dad8592a8..f265230ff77 100644
--- a/compiler/rustc_trait_selection/src/traits/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/mod.rs
@@ -17,6 +17,7 @@ pub mod query;
 mod select;
 mod specialize;
 mod structural_match;
+mod structural_normalize;
 mod util;
 mod vtable;
 pub mod wf;
@@ -62,6 +63,7 @@ pub use self::specialize::{
 pub use self::structural_match::{
     search_for_adt_const_param_violation, search_for_structural_match_violation,
 };
+pub use self::structural_normalize::StructurallyNormalizeExt;
 pub use self::util::elaborate;
 pub use self::util::{expand_trait_aliases, TraitAliasExpander};
 pub use self::util::{get_vtable_index_of_object_method, impl_item_is_final, upcast_choices};
diff --git a/compiler/rustc_trait_selection/src/traits/structural_normalize.rs b/compiler/rustc_trait_selection/src/traits/structural_normalize.rs
new file mode 100644
index 00000000000..af8dd0da579
--- /dev/null
+++ b/compiler/rustc_trait_selection/src/traits/structural_normalize.rs
@@ -0,0 +1,55 @@
+use rustc_infer::infer::at::At;
+use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
+use rustc_infer::traits::{FulfillmentError, TraitEngine};
+use rustc_middle::ty::{self, Ty};
+
+use crate::traits::{query::evaluate_obligation::InferCtxtExt, NormalizeExt, Obligation};
+
+pub trait StructurallyNormalizeExt<'tcx> {
+    fn structurally_normalize(
+        &self,
+        ty: Ty<'tcx>,
+        fulfill_cx: &mut dyn TraitEngine<'tcx>,
+    ) -> Result<Ty<'tcx>, Vec<FulfillmentError<'tcx>>>;
+}
+
+impl<'tcx> StructurallyNormalizeExt<'tcx> for At<'_, 'tcx> {
+    fn structurally_normalize(
+        &self,
+        mut ty: Ty<'tcx>,
+        fulfill_cx: &mut dyn TraitEngine<'tcx>,
+    ) -> Result<Ty<'tcx>, Vec<FulfillmentError<'tcx>>> {
+        assert!(!ty.is_ty_var(), "should have resolved vars before calling");
+
+        if self.infcx.tcx.trait_solver_next() {
+            while let ty::Alias(ty::Projection, projection_ty) = *ty.kind() {
+                let new_infer_ty = self.infcx.next_ty_var(TypeVariableOrigin {
+                    kind: TypeVariableOriginKind::NormalizeProjectionType,
+                    span: self.cause.span,
+                });
+                let obligation = Obligation::new(
+                    self.infcx.tcx,
+                    self.cause.clone(),
+                    self.param_env,
+                    ty::Binder::dummy(ty::ProjectionPredicate {
+                        projection_ty,
+                        term: new_infer_ty.into(),
+                    }),
+                );
+                if self.infcx.predicate_may_hold(&obligation) {
+                    fulfill_cx.register_predicate_obligation(self.infcx, obligation);
+                    let errors = fulfill_cx.select_where_possible(self.infcx);
+                    if !errors.is_empty() {
+                        return Err(errors);
+                    }
+                    ty = self.infcx.resolve_vars_if_possible(new_infer_ty);
+                } else {
+                    break;
+                }
+            }
+            Ok(ty)
+        } else {
+            Ok(self.normalize(ty).into_value_registering_obligations(self.infcx, fulfill_cx))
+        }
+    }
+}
diff --git a/tests/ui/traits/new-solver/normalize-rcvr-for-inherent.rs b/tests/ui/traits/new-solver/normalize-rcvr-for-inherent.rs
new file mode 100644
index 00000000000..d70534feb07
--- /dev/null
+++ b/tests/ui/traits/new-solver/normalize-rcvr-for-inherent.rs
@@ -0,0 +1,25 @@
+// compile-flags: -Ztrait-solver=next
+// check-pass
+
+// Verify that we can assemble inherent impl candidates on a possibly
+// unnormalized self type.
+
+trait Foo {
+    type Assoc;
+}
+impl Foo for i32 {
+    type Assoc = Bar;
+}
+
+struct Bar;
+impl Bar {
+    fn method(&self) {}
+}
+
+fn build<T: Foo>(_: T) -> T::Assoc {
+    todo!()
+}
+
+fn main() {
+    build(1i32).method();
+}
diff --git a/tests/ui/traits/new-solver/structural-resolve-field.rs b/tests/ui/traits/new-solver/structural-resolve-field.rs
new file mode 100644
index 00000000000..01899c9ad64
--- /dev/null
+++ b/tests/ui/traits/new-solver/structural-resolve-field.rs
@@ -0,0 +1,13 @@
+// compile-flags: -Ztrait-solver=next
+// check-pass
+
+#[derive(Default)]
+struct Foo {
+    x: i32,
+}
+
+fn main() {
+    let mut xs = <[Foo; 1]>::default();
+    xs[0].x = 1;
+    (&mut xs[0]).x = 2;
+}