about summary refs log tree commit diff
diff options
context:
space:
mode:
authorCamille GILLOT <gillot.camille@gmail.com>2022-08-15 11:58:27 +0200
committerCamille GILLOT <gillot.camille@gmail.com>2022-08-21 12:54:26 +0200
commit10e71dfdb819ce9c2acee0b49c452b805c141041 (patch)
treeaaa9309762b5e1115a2fae5e21bd960b7b0d747d
parente4fd2bd7214dc2d3f8ec6a876b1687c9f3cf7beb (diff)
downloadrust-10e71dfdb819ce9c2acee0b49c452b805c141041.tar.gz
rust-10e71dfdb819ce9c2acee0b49c452b805c141041.zip
Also validate types before inlining.
-rw-r--r--compiler/rustc_const_eval/src/transform/validate.rs30
-rw-r--r--compiler/rustc_mir_transform/src/inline.rs100
2 files changed, 114 insertions, 16 deletions
diff --git a/compiler/rustc_const_eval/src/transform/validate.rs b/compiler/rustc_const_eval/src/transform/validate.rs
index 15e820f2d19..1a14cd79fa0 100644
--- a/compiler/rustc_const_eval/src/transform/validate.rs
+++ b/compiler/rustc_const_eval/src/transform/validate.rs
@@ -89,22 +89,20 @@ pub fn equal_up_to_regions<'tcx>(
 
     // Normalize lifetimes away on both sides, then compare.
     let normalize = |ty: Ty<'tcx>| {
-        tcx.normalize_erasing_regions(
-            param_env,
-            ty.fold_with(&mut BottomUpFolder {
-                tcx,
-                // FIXME: We erase all late-bound lifetimes, but this is not fully correct.
-                // If you have a type like `<for<'a> fn(&'a u32) as SomeTrait>::Assoc`,
-                // this is not necessarily equivalent to `<fn(&'static u32) as SomeTrait>::Assoc`,
-                // since one may have an `impl SomeTrait for fn(&32)` and
-                // `impl SomeTrait for fn(&'static u32)` at the same time which
-                // specify distinct values for Assoc. (See also #56105)
-                lt_op: |_| tcx.lifetimes.re_erased,
-                // Leave consts and types unchanged.
-                ct_op: |ct| ct,
-                ty_op: |ty| ty,
-            }),
-        )
+        let ty = ty.fold_with(&mut BottomUpFolder {
+            tcx,
+            // FIXME: We erase all late-bound lifetimes, but this is not fully correct.
+            // If you have a type like `<for<'a> fn(&'a u32) as SomeTrait>::Assoc`,
+            // this is not necessarily equivalent to `<fn(&'static u32) as SomeTrait>::Assoc`,
+            // since one may have an `impl SomeTrait for fn(&32)` and
+            // `impl SomeTrait for fn(&'static u32)` at the same time which
+            // specify distinct values for Assoc. (See also #56105)
+            lt_op: |_| tcx.lifetimes.re_erased,
+            // Leave consts and types unchanged.
+            ct_op: |ct| ct,
+            ty_op: |ty| ty,
+        });
+        tcx.try_normalize_erasing_regions(param_env, ty).unwrap_or(ty)
     };
     tcx.infer_ctxt().enter(|infcx| infcx.can_eq(param_env, normalize(src), normalize(dest)).is_ok())
 }
diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs
index 481c4c56304..b454d70d05d 100644
--- a/compiler/rustc_mir_transform/src/inline.rs
+++ b/compiler/rustc_mir_transform/src/inline.rs
@@ -11,6 +11,7 @@ use rustc_middle::ty::subst::Subst;
 use rustc_middle::ty::{self, ConstKind, Instance, InstanceDef, ParamEnv, Ty, TyCtxt};
 use rustc_session::config::OptLevel;
 use rustc_span::{hygiene::ExpnKind, ExpnData, LocalExpnId, Span};
+use rustc_target::abi::VariantIdx;
 use rustc_target::spec::abi::Abi;
 
 use super::simplify::{remove_dead_blocks, CfgSimplifier};
@@ -423,6 +424,7 @@ impl<'tcx> Inliner<'tcx> {
             instance: callsite.callee,
             callee_body,
             cost: 0,
+            validation: Ok(()),
         };
 
         // Traverse the MIR manually so we can account for the effects of inlining on the CFG.
@@ -458,6 +460,9 @@ impl<'tcx> Inliner<'tcx> {
             checker.visit_local_decl(v, &callee_body.local_decls[v]);
         }
 
+        // Abort if type validation found anything fishy.
+        checker.validation?;
+
         let cost = checker.cost;
         if let InlineAttr::Always = callee_attrs.inline {
             debug!("INLINING {:?} because inline(always) [cost={}]", callsite, cost);
@@ -738,6 +743,7 @@ struct CostChecker<'b, 'tcx> {
     cost: usize,
     callee_body: &'b Body<'tcx>,
     instance: ty::Instance<'tcx>,
+    validation: Result<(), &'static str>,
 }
 
 impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
@@ -818,6 +824,100 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
 
         self.super_local_decl(local, local_decl)
     }
+
+    /// This method duplicates code from MIR validation in an attempt to detect type mismatches due
+    /// to normalization failure.
+    fn visit_projection_elem(
+        &mut self,
+        local: Local,
+        proj_base: &[PlaceElem<'tcx>],
+        elem: PlaceElem<'tcx>,
+        context: PlaceContext,
+        location: Location,
+    ) {
+        if let ProjectionElem::Field(f, ty) = elem {
+            let parent = Place { local, projection: self.tcx.intern_place_elems(proj_base) };
+            let parent_ty = parent.ty(&self.callee_body.local_decls, self.tcx);
+            let check_equal = |this: &mut Self, f_ty| {
+                if !equal_up_to_regions(this.tcx, this.param_env, ty, f_ty) {
+                    trace!(?ty, ?f_ty);
+                    this.validation = Err("failed to normalize projection type");
+                    return;
+                }
+            };
+
+            let kind = match parent_ty.ty.kind() {
+                &ty::Opaque(def_id, substs) => {
+                    self.tcx.bound_type_of(def_id).subst(self.tcx, substs).kind()
+                }
+                kind => kind,
+            };
+
+            match kind {
+                ty::Tuple(fields) => {
+                    let Some(f_ty) = fields.get(f.as_usize()) else {
+                        self.validation = Err("malformed MIR");
+                        return;
+                    };
+                    check_equal(self, *f_ty);
+                }
+                ty::Adt(adt_def, substs) => {
+                    let var = parent_ty.variant_index.unwrap_or(VariantIdx::from_u32(0));
+                    let Some(field) = adt_def.variant(var).fields.get(f.as_usize()) else {
+                        self.validation = Err("malformed MIR");
+                        return;
+                    };
+                    check_equal(self, field.ty(self.tcx, substs));
+                }
+                ty::Closure(_, substs) => {
+                    let substs = substs.as_closure();
+                    let Some(f_ty) = substs.upvar_tys().nth(f.as_usize()) else {
+                        self.validation = Err("malformed MIR");
+                        return;
+                    };
+                    check_equal(self, f_ty);
+                }
+                &ty::Generator(def_id, substs, _) => {
+                    let f_ty = if let Some(var) = parent_ty.variant_index {
+                        let gen_body = if def_id == self.callee_body.source.def_id() {
+                            self.callee_body
+                        } else {
+                            self.tcx.optimized_mir(def_id)
+                        };
+
+                        let Some(layout) = gen_body.generator_layout() else {
+                            self.validation = Err("malformed MIR");
+                            return;
+                        };
+
+                        let Some(&local) = layout.variant_fields[var].get(f) else {
+                            self.validation = Err("malformed MIR");
+                            return;
+                        };
+
+                        let Some(&f_ty) = layout.field_tys.get(local) else {
+                            self.validation = Err("malformed MIR");
+                            return;
+                        };
+
+                        f_ty
+                    } else {
+                        let Some(f_ty) = substs.as_generator().prefix_tys().nth(f.index()) else {
+                            self.validation = Err("malformed MIR");
+                            return;
+                        };
+
+                        f_ty
+                    };
+
+                    check_equal(self, f_ty);
+                }
+                _ => self.validation = Err("malformed MIR"),
+            }
+        }
+
+        self.super_projection_elem(local, proj_base, elem, context, location);
+    }
 }
 
 /**