about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMatthias Krüger <matthias.krueger@famsik.de>2024-01-22 16:54:59 +0100
committerGitHub <noreply@github.com>2024-01-22 16:54:59 +0100
commita12e2ff7b2e5b3adbaa010c507039ebcb41405c3 (patch)
tree05afcbc018d162416d21e9861755d0c551d41e0d
parentf194a84ce2e08c96e71b21ca2726b544704e263f (diff)
parentf26f52c42bdce2bf6db4a28ff33608ba3890add2 (diff)
downloadrust-a12e2ff7b2e5b3adbaa010c507039ebcb41405c3.tar.gz
rust-a12e2ff7b2e5b3adbaa010c507039ebcb41405c3.zip
Rollup merge of #120137 - compiler-errors:validate-aggregates, r=nnethercote
Validate AggregateKind types in MIR

Would have helped me catch some bugs when writing shims for async closures
-rw-r--r--compiler/rustc_const_eval/src/transform/validate.rs62
1 files changed, 61 insertions, 1 deletions
diff --git a/compiler/rustc_const_eval/src/transform/validate.rs b/compiler/rustc_const_eval/src/transform/validate.rs
index 9c2f336e912..5564902396e 100644
--- a/compiler/rustc_const_eval/src/transform/validate.rs
+++ b/compiler/rustc_const_eval/src/transform/validate.rs
@@ -796,7 +796,67 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
             };
         }
         match rvalue {
-            Rvalue::Use(_) | Rvalue::CopyForDeref(_) | Rvalue::Aggregate(..) => {}
+            Rvalue::Use(_) | Rvalue::CopyForDeref(_) => {}
+            Rvalue::Aggregate(kind, fields) => match **kind {
+                AggregateKind::Tuple => {}
+                AggregateKind::Array(dest) => {
+                    for src in fields {
+                        if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest) {
+                            self.fail(location, "array field has the wrong type");
+                        }
+                    }
+                }
+                AggregateKind::Adt(def_id, idx, args, _, Some(field)) => {
+                    let adt_def = self.tcx.adt_def(def_id);
+                    assert!(adt_def.is_union());
+                    assert_eq!(idx, FIRST_VARIANT);
+                    let dest = adt_def.non_enum_variant().fields[field].ty(self.tcx, args);
+                    if fields.len() != 1 {
+                        self.fail(location, "unions should have one initialized field");
+                    }
+                    if !self.mir_assign_valid_types(fields.raw[0].ty(self.body, self.tcx), dest) {
+                        self.fail(location, "union field has the wrong type");
+                    }
+                }
+                AggregateKind::Adt(def_id, idx, args, _, None) => {
+                    let adt_def = self.tcx.adt_def(def_id);
+                    assert!(!adt_def.is_union());
+                    let variant = &adt_def.variants()[idx];
+                    if variant.fields.len() != fields.len() {
+                        self.fail(location, "adt has the wrong number of initialized fields");
+                    }
+                    for (src, dest) in std::iter::zip(fields, &variant.fields) {
+                        if !self.mir_assign_valid_types(
+                            src.ty(self.body, self.tcx),
+                            dest.ty(self.tcx, args),
+                        ) {
+                            self.fail(location, "adt field has the wrong type");
+                        }
+                    }
+                }
+                AggregateKind::Closure(_, args) => {
+                    let upvars = args.as_closure().upvar_tys();
+                    if upvars.len() != fields.len() {
+                        self.fail(location, "closure has the wrong number of initialized fields");
+                    }
+                    for (src, dest) in std::iter::zip(fields, upvars) {
+                        if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest) {
+                            self.fail(location, "closure field has the wrong type");
+                        }
+                    }
+                }
+                AggregateKind::Coroutine(_, args) => {
+                    let upvars = args.as_coroutine().upvar_tys();
+                    if upvars.len() != fields.len() {
+                        self.fail(location, "coroutine has the wrong number of initialized fields");
+                    }
+                    for (src, dest) in std::iter::zip(fields, upvars) {
+                        if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest) {
+                            self.fail(location, "coroutine field has the wrong type");
+                        }
+                    }
+                }
+            },
             Rvalue::Ref(_, BorrowKind::Fake, _) => {
                 if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) {
                     self.fail(