about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2025-05-28 19:28:11 +0000
committerMichael Goulet <michael@errs.io>2025-05-29 11:16:44 +0000
commit2a339ce492349a27c80a0bb9e63510f1798beae1 (patch)
tree4821f7df61442284e36ff38c413d97217585127e
parent38081f22c2d7380f272aa1d7fa9b935637701c2d (diff)
downloadrust-2a339ce492349a27c80a0bb9e63510f1798beae1.tar.gz
rust-2a339ce492349a27c80a0bb9e63510f1798beae1.zip
Structurally normalize types as needed in projection_ty_core
-rw-r--r--compiler/rustc_borrowck/src/type_check/mod.rs12
-rw-r--r--compiler/rustc_middle/src/mir/statement.rs78
-rw-r--r--compiler/rustc_mir_build/src/builder/custom/parse/instruction.rs4
-rw-r--r--tests/ui/traits/next-solver/normalize/normalize-place-elem.rs32
4 files changed, 81 insertions, 45 deletions
diff --git a/compiler/rustc_borrowck/src/type_check/mod.rs b/compiler/rustc_borrowck/src/type_check/mod.rs
index 8c512257120..4d57e7aac58 100644
--- a/compiler/rustc_borrowck/src/type_check/mod.rs
+++ b/compiler/rustc_borrowck/src/type_check/mod.rs
@@ -474,17 +474,15 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
             let projected_ty = curr_projected_ty.projection_ty_core(
                 tcx,
                 proj,
-                |this, field, ()| {
-                    let ty = this.field_ty(tcx, field);
-                    self.structurally_resolve(ty, locations)
-                },
-                |_, _| unreachable!(),
+                |ty| self.structurally_resolve(ty, locations),
+                |ty, variant_index, field, ()| PlaceTy::field_ty(tcx, ty, variant_index, field),
+                |_| unreachable!(),
             );
             curr_projected_ty = projected_ty;
         }
         trace!(?curr_projected_ty);
 
-        let ty = curr_projected_ty.ty;
+        let ty = self.normalize(curr_projected_ty.ty, locations);
         self.relate_types(ty, v.xform(ty::Contravariant), a, locations, category)?;
 
         Ok(())
@@ -1852,7 +1850,7 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
             | ProjectionElem::Downcast(..) => {}
             ProjectionElem::Field(field, fty) => {
                 let fty = self.normalize(fty, location);
-                let ty = base_ty.field_ty(tcx, field);
+                let ty = PlaceTy::field_ty(tcx, base_ty.ty, base_ty.variant_index, field);
                 let ty = self.normalize(ty, location);
                 debug!(?fty, ?ty);
 
diff --git a/compiler/rustc_middle/src/mir/statement.rs b/compiler/rustc_middle/src/mir/statement.rs
index d59b6df44ed..06e41e64fdc 100644
--- a/compiler/rustc_middle/src/mir/statement.rs
+++ b/compiler/rustc_middle/src/mir/statement.rs
@@ -88,26 +88,31 @@ impl<'tcx> PlaceTy<'tcx> {
     ///
     /// Note that the resulting type has not been normalized.
     #[instrument(level = "debug", skip(tcx), ret)]
-    pub fn field_ty(self, tcx: TyCtxt<'tcx>, f: FieldIdx) -> Ty<'tcx> {
-        if let Some(variant_index) = self.variant_index {
-            match *self.ty.kind() {
+    pub fn field_ty(
+        tcx: TyCtxt<'tcx>,
+        self_ty: Ty<'tcx>,
+        variant_idx: Option<VariantIdx>,
+        f: FieldIdx,
+    ) -> Ty<'tcx> {
+        if let Some(variant_index) = variant_idx {
+            match *self_ty.kind() {
                 ty::Adt(adt_def, args) if adt_def.is_enum() => {
                     adt_def.variant(variant_index).fields[f].ty(tcx, args)
                 }
                 ty::Coroutine(def_id, args) => {
                     let mut variants = args.as_coroutine().state_tys(def_id, tcx);
                     let Some(mut variant) = variants.nth(variant_index.into()) else {
-                        bug!("variant {variant_index:?} of coroutine out of range: {self:?}");
+                        bug!("variant {variant_index:?} of coroutine out of range: {self_ty:?}");
                     };
 
-                    variant
-                        .nth(f.index())
-                        .unwrap_or_else(|| bug!("field {f:?} out of range: {self:?}"))
+                    variant.nth(f.index()).unwrap_or_else(|| {
+                        bug!("field {f:?} out of range of variant: {self_ty:?} {variant_idx:?}")
+                    })
                 }
-                _ => bug!("can't downcast non-adt non-coroutine type: {self:?}"),
+                _ => bug!("can't downcast non-adt non-coroutine type: {self_ty:?}"),
             }
         } else {
-            match self.ty.kind() {
+            match self_ty.kind() {
                 ty::Adt(adt_def, args) if !adt_def.is_enum() => {
                     adt_def.non_enum_variant().fields[f].ty(tcx, args)
                 }
@@ -116,26 +121,25 @@ impl<'tcx> PlaceTy<'tcx> {
                     .upvar_tys()
                     .get(f.index())
                     .copied()
-                    .unwrap_or_else(|| bug!("field {f:?} out of range: {self:?}")),
+                    .unwrap_or_else(|| bug!("field {f:?} out of range: {self_ty:?}")),
                 ty::CoroutineClosure(_, args) => args
                     .as_coroutine_closure()
                     .upvar_tys()
                     .get(f.index())
                     .copied()
-                    .unwrap_or_else(|| bug!("field {f:?} out of range: {self:?}")),
+                    .unwrap_or_else(|| bug!("field {f:?} out of range: {self_ty:?}")),
                 // Only prefix fields (upvars and current state) are
                 // accessible without a variant index.
-                ty::Coroutine(_, args) => args
-                    .as_coroutine()
-                    .prefix_tys()
-                    .get(f.index())
-                    .copied()
-                    .unwrap_or_else(|| bug!("field {f:?} out of range: {self:?}")),
+                ty::Coroutine(_, args) => {
+                    args.as_coroutine().prefix_tys().get(f.index()).copied().unwrap_or_else(|| {
+                        bug!("field {f:?} out of range of prefixes for {self_ty}")
+                    })
+                }
                 ty::Tuple(tys) => tys
                     .get(f.index())
                     .copied()
-                    .unwrap_or_else(|| bug!("field {f:?} out of range: {self:?}")),
-                _ => bug!("can't project out of {self:?}"),
+                    .unwrap_or_else(|| bug!("field {f:?} out of range: {self_ty:?}")),
+                _ => bug!("can't project out of {self_ty:?}"),
             }
         }
     }
@@ -148,11 +152,11 @@ impl<'tcx> PlaceTy<'tcx> {
         elems.iter().fold(self, |place_ty, &elem| place_ty.projection_ty(tcx, elem))
     }
 
-    /// Convenience wrapper around `projection_ty_core` for
-    /// `PlaceElem`, where we can just use the `Ty` that is already
-    /// stored inline on field projection elems.
+    /// Convenience wrapper around `projection_ty_core` for `PlaceElem`,
+    /// where we can just use the `Ty` that is already stored inline on
+    /// field projection elems.
     pub fn projection_ty(self, tcx: TyCtxt<'tcx>, elem: PlaceElem<'tcx>) -> PlaceTy<'tcx> {
-        self.projection_ty_core(tcx, &elem, |_, _, ty| ty, |_, ty| ty)
+        self.projection_ty_core(tcx, &elem, |ty| ty, |_, _, _, ty| ty, |ty| ty)
     }
 
     /// `place_ty.projection_ty_core(tcx, elem, |...| { ... })`
@@ -164,8 +168,9 @@ impl<'tcx> PlaceTy<'tcx> {
         self,
         tcx: TyCtxt<'tcx>,
         elem: &ProjectionElem<V, T>,
-        mut handle_field: impl FnMut(&Self, FieldIdx, T) -> Ty<'tcx>,
-        mut handle_opaque_cast_and_subtype: impl FnMut(&Self, T) -> Ty<'tcx>,
+        mut structurally_normalize: impl FnMut(Ty<'tcx>) -> Ty<'tcx>,
+        mut handle_field: impl FnMut(Ty<'tcx>, Option<VariantIdx>, FieldIdx, T) -> Ty<'tcx>,
+        mut handle_opaque_cast_and_subtype: impl FnMut(T) -> Ty<'tcx>,
     ) -> PlaceTy<'tcx>
     where
         V: ::std::fmt::Debug,
@@ -176,16 +181,16 @@ impl<'tcx> PlaceTy<'tcx> {
         }
         let answer = match *elem {
             ProjectionElem::Deref => {
-                let ty = self.ty.builtin_deref(true).unwrap_or_else(|| {
+                let ty = structurally_normalize(self.ty).builtin_deref(true).unwrap_or_else(|| {
                     bug!("deref projection of non-dereferenceable ty {:?}", self)
                 });
                 PlaceTy::from_ty(ty)
             }
             ProjectionElem::Index(_) | ProjectionElem::ConstantIndex { .. } => {
-                PlaceTy::from_ty(self.ty.builtin_index().unwrap())
+                PlaceTy::from_ty(structurally_normalize(self.ty).builtin_index().unwrap())
             }
             ProjectionElem::Subslice { from, to, from_end } => {
-                PlaceTy::from_ty(match self.ty.kind() {
+                PlaceTy::from_ty(match structurally_normalize(self.ty).kind() {
                     ty::Slice(..) => self.ty,
                     ty::Array(inner, _) if !from_end => Ty::new_array(tcx, *inner, to - from),
                     ty::Array(inner, size) if from_end => {
@@ -201,17 +206,18 @@ impl<'tcx> PlaceTy<'tcx> {
             ProjectionElem::Downcast(_name, index) => {
                 PlaceTy { ty: self.ty, variant_index: Some(index) }
             }
-            ProjectionElem::Field(f, fty) => PlaceTy::from_ty(handle_field(&self, f, fty)),
-            ProjectionElem::OpaqueCast(ty) => {
-                PlaceTy::from_ty(handle_opaque_cast_and_subtype(&self, ty))
-            }
-            ProjectionElem::Subtype(ty) => {
-                PlaceTy::from_ty(handle_opaque_cast_and_subtype(&self, ty))
-            }
+            ProjectionElem::Field(f, fty) => PlaceTy::from_ty(handle_field(
+                structurally_normalize(self.ty),
+                self.variant_index,
+                f,
+                fty,
+            )),
+            ProjectionElem::OpaqueCast(ty) => PlaceTy::from_ty(handle_opaque_cast_and_subtype(ty)),
+            ProjectionElem::Subtype(ty) => PlaceTy::from_ty(handle_opaque_cast_and_subtype(ty)),
 
             // FIXME(unsafe_binders): Rename `handle_opaque_cast_and_subtype` to be more general.
             ProjectionElem::UnwrapUnsafeBinder(ty) => {
-                PlaceTy::from_ty(handle_opaque_cast_and_subtype(&self, ty))
+                PlaceTy::from_ty(handle_opaque_cast_and_subtype(ty))
             }
         };
         debug!("projection_ty self: {:?} elem: {:?} yields: {:?}", self, elem, answer);
diff --git a/compiler/rustc_mir_build/src/builder/custom/parse/instruction.rs b/compiler/rustc_mir_build/src/builder/custom/parse/instruction.rs
index 494ee33fd8b..9825b947fe0 100644
--- a/compiler/rustc_mir_build/src/builder/custom/parse/instruction.rs
+++ b/compiler/rustc_mir_build/src/builder/custom/parse/instruction.rs
@@ -323,9 +323,9 @@ impl<'a, 'tcx> ParseCtxt<'a, 'tcx> {
     fn parse_place_inner(&self, expr_id: ExprId) -> PResult<(Place<'tcx>, PlaceTy<'tcx>)> {
         let (parent, proj) = parse_by_kind!(self, expr_id, expr, "place",
             @call(mir_field, args) => {
-                let (parent, ty) = self.parse_place_inner(args[0])?;
+                let (parent, place_ty) = self.parse_place_inner(args[0])?;
                 let field = FieldIdx::from_u32(self.parse_integer_literal(args[1])? as u32);
-                let field_ty = ty.field_ty(self.tcx, field);
+                let field_ty = PlaceTy::field_ty(self.tcx, place_ty.ty, place_ty.variant_index, field);
                 let proj = PlaceElem::Field(field, field_ty);
                 let place = parent.project_deeper(&[proj], self.tcx);
                 return Ok((place, PlaceTy::from_ty(field_ty)));
diff --git a/tests/ui/traits/next-solver/normalize/normalize-place-elem.rs b/tests/ui/traits/next-solver/normalize/normalize-place-elem.rs
new file mode 100644
index 00000000000..9a600875bb9
--- /dev/null
+++ b/tests/ui/traits/next-solver/normalize/normalize-place-elem.rs
@@ -0,0 +1,32 @@
+//@ check-pass
+//@ compile-flags: -Znext-solver
+
+// Regression test for <https://github.com/rust-lang/trait-system-refactor-initiative/issues/221>.
+// Ensure that we normalize after applying projection elems in MIR typeck.
+
+use std::marker::PhantomData;
+
+#[derive(Copy, Clone)]
+struct Span;
+
+trait AstKind {
+    type Inner;
+}
+
+struct WithSpan;
+impl AstKind for WithSpan {
+    type Inner
+        = (i32,);
+}
+
+struct Expr<'a> { f: &'a <WithSpan as AstKind>::Inner }
+
+impl Expr<'_> {
+    fn span(self) {
+        match self {
+            Self { f: (n,) } => {},
+        }
+    }
+}
+
+fn main() {}