about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2022-08-17 19:02:55 +0000
committerMichael Goulet <michael@errs.io>2022-08-17 19:02:55 +0000
commit64bd8c1dc427c8bcd69bed7c240b1f1f38bedbd3 (patch)
tree2167162a6159200487f5376637abffdb7b6b9adc
parent86c6ebee8fa0a5ad1e18e375113b06bd2849b634 (diff)
downloadrust-64bd8c1dc427c8bcd69bed7c240b1f1f38bedbd3.tar.gz
rust-64bd8c1dc427c8bcd69bed7c240b1f1f38bedbd3.zip
Make same_type_modulo_infer a proper TypeRelation
-rw-r--r--compiler/rustc_infer/src/infer/error_reporting/mod.rs132
-rw-r--r--compiler/rustc_middle/src/ty/sty.rs4
-rw-r--r--src/test/ui/implied-bounds/issue-100690.rs45
-rw-r--r--src/test/ui/implied-bounds/issue-100690.stderr22
4 files changed, 150 insertions, 53 deletions
diff --git a/compiler/rustc_infer/src/infer/error_reporting/mod.rs b/compiler/rustc_infer/src/infer/error_reporting/mod.rs
index 3c7dac2bfd8..d55708acc9a 100644
--- a/compiler/rustc_infer/src/infer/error_reporting/mod.rs
+++ b/compiler/rustc_infer/src/infer/error_reporting/mod.rs
@@ -66,6 +66,7 @@ use rustc_hir::lang_items::LangItem;
 use rustc_hir::Node;
 use rustc_middle::dep_graph::DepContext;
 use rustc_middle::ty::print::with_no_trimmed_paths;
+use rustc_middle::ty::relate::{self, RelateResult, TypeRelation};
 use rustc_middle::ty::{
     self, error::TypeError, Binder, List, Region, Subst, Ty, TyCtxt, TypeFoldable,
     TypeSuperVisitable, TypeVisitable,
@@ -2660,67 +2661,92 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
     /// Float types, respectively). When comparing two ADTs, these rules apply recursively.
     pub fn same_type_modulo_infer(&self, a: Ty<'tcx>, b: Ty<'tcx>) -> bool {
         let (a, b) = self.resolve_vars_if_possible((a, b));
-        match (a.kind(), b.kind()) {
-            (&ty::Adt(def_a, substs_a), &ty::Adt(def_b, substs_b)) => {
-                if def_a != def_b {
-                    return false;
-                }
+        SameTypeModuloInfer(self).relate(a, b).is_ok()
+    }
+}
 
-                substs_a
-                    .types()
-                    .zip(substs_b.types())
-                    .all(|(a, b)| self.same_type_modulo_infer(a, b))
-            }
-            (&ty::FnDef(did_a, substs_a), &ty::FnDef(did_b, substs_b)) => {
-                if did_a != did_b {
-                    return false;
-                }
+struct SameTypeModuloInfer<'a, 'tcx>(&'a InferCtxt<'a, 'tcx>);
 
-                substs_a
-                    .types()
-                    .zip(substs_b.types())
-                    .all(|(a, b)| self.same_type_modulo_infer(a, b))
-            }
-            (&ty::Int(_) | &ty::Uint(_), &ty::Infer(ty::InferTy::IntVar(_)))
+impl<'tcx> TypeRelation<'tcx> for SameTypeModuloInfer<'_, 'tcx> {
+    fn tcx(&self) -> TyCtxt<'tcx> {
+        self.0.tcx
+    }
+
+    fn param_env(&self) -> ty::ParamEnv<'tcx> {
+        // Unused, only for consts which we treat as always equal
+        ty::ParamEnv::empty()
+    }
+
+    fn tag(&self) -> &'static str {
+        "SameTypeModuloInfer"
+    }
+
+    fn a_is_expected(&self) -> bool {
+        true
+    }
+
+    fn relate_with_variance<T: relate::Relate<'tcx>>(
+        &mut self,
+        _variance: ty::Variance,
+        _info: ty::VarianceDiagInfo<'tcx>,
+        a: T,
+        b: T,
+    ) -> relate::RelateResult<'tcx, T> {
+        self.relate(a, b)
+    }
+
+    fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
+        match (a.kind(), b.kind()) {
+            (ty::Int(_) | ty::Uint(_), ty::Infer(ty::InferTy::IntVar(_)))
             | (
-                &ty::Infer(ty::InferTy::IntVar(_)),
-                &ty::Int(_) | &ty::Uint(_) | &ty::Infer(ty::InferTy::IntVar(_)),
+                ty::Infer(ty::InferTy::IntVar(_)),
+                ty::Int(_) | ty::Uint(_) | ty::Infer(ty::InferTy::IntVar(_)),
             )
-            | (&ty::Float(_), &ty::Infer(ty::InferTy::FloatVar(_)))
+            | (ty::Float(_), ty::Infer(ty::InferTy::FloatVar(_)))
             | (
-                &ty::Infer(ty::InferTy::FloatVar(_)),
-                &ty::Float(_) | &ty::Infer(ty::InferTy::FloatVar(_)),
+                ty::Infer(ty::InferTy::FloatVar(_)),
+                ty::Float(_) | ty::Infer(ty::InferTy::FloatVar(_)),
             )
-            | (&ty::Infer(ty::InferTy::TyVar(_)), _)
-            | (_, &ty::Infer(ty::InferTy::TyVar(_))) => true,
-            (&ty::Ref(_, ty_a, mut_a), &ty::Ref(_, ty_b, mut_b)) => {
-                mut_a == mut_b && self.same_type_modulo_infer(ty_a, ty_b)
-            }
-            (&ty::RawPtr(a), &ty::RawPtr(b)) => {
-                a.mutbl == b.mutbl && self.same_type_modulo_infer(a.ty, b.ty)
-            }
-            (&ty::Slice(a), &ty::Slice(b)) => self.same_type_modulo_infer(a, b),
-            (&ty::Array(a_ty, a_ct), &ty::Array(b_ty, b_ct)) => {
-                self.same_type_modulo_infer(a_ty, b_ty) && a_ct == b_ct
-            }
-            (&ty::Tuple(a), &ty::Tuple(b)) => {
-                if a.len() != b.len() {
-                    return false;
-                }
-                std::iter::zip(a.iter(), b.iter()).all(|(a, b)| self.same_type_modulo_infer(a, b))
-            }
-            (&ty::FnPtr(a), &ty::FnPtr(b)) => {
-                let a = a.skip_binder().inputs_and_output;
-                let b = b.skip_binder().inputs_and_output;
-                if a.len() != b.len() {
-                    return false;
-                }
-                std::iter::zip(a.iter(), b.iter()).all(|(a, b)| self.same_type_modulo_infer(a, b))
-            }
-            // FIXME(compiler-errors): This needs to be generalized more
-            _ => a == b,
+            | (ty::Infer(ty::InferTy::TyVar(_)), _)
+            | (_, ty::Infer(ty::InferTy::TyVar(_))) => Ok(a),
+            (ty::Infer(_), _) | (_, ty::Infer(_)) => Err(TypeError::Mismatch),
+            _ => relate::super_relate_tys(self, a, b),
         }
     }
+
+    fn regions(
+        &mut self,
+        a: ty::Region<'tcx>,
+        b: ty::Region<'tcx>,
+    ) -> RelateResult<'tcx, ty::Region<'tcx>> {
+        if (a.is_var() && b.is_free_or_static()) || (b.is_var() && a.is_free_or_static()) || a == b
+        {
+            Ok(a)
+        } else {
+            Err(TypeError::Mismatch)
+        }
+    }
+
+    fn binders<T>(
+        &mut self,
+        a: ty::Binder<'tcx, T>,
+        b: ty::Binder<'tcx, T>,
+    ) -> relate::RelateResult<'tcx, ty::Binder<'tcx, T>>
+    where
+        T: relate::Relate<'tcx>,
+    {
+        Ok(ty::Binder::dummy(self.relate(a.skip_binder(), b.skip_binder())?))
+    }
+
+    fn consts(
+        &mut self,
+        a: ty::Const<'tcx>,
+        _b: ty::Const<'tcx>,
+    ) -> relate::RelateResult<'tcx, ty::Const<'tcx>> {
+        // FIXME(compiler-errors): This could at least do some first-order
+        // relation
+        Ok(a)
+    }
 }
 
 impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs
index 52c3a38861e..0070575f213 100644
--- a/compiler/rustc_middle/src/ty/sty.rs
+++ b/compiler/rustc_middle/src/ty/sty.rs
@@ -1617,6 +1617,10 @@ impl<'tcx> Region<'tcx> {
             _ => self.is_free(),
         }
     }
+
+    pub fn is_var(self) -> bool {
+        matches!(self.kind(), ty::ReVar(_))
+    }
 }
 
 /// Type utilities
diff --git a/src/test/ui/implied-bounds/issue-100690.rs b/src/test/ui/implied-bounds/issue-100690.rs
new file mode 100644
index 00000000000..5599cd410ba
--- /dev/null
+++ b/src/test/ui/implied-bounds/issue-100690.rs
@@ -0,0 +1,45 @@
+// This code (probably) _should_ compile, but it currently does not because we
+// are not smart enough about implied bounds.
+
+use std::io;
+
+fn real_dispatch<T, F>(f: F) -> Result<(), io::Error>
+//~^ NOTE required by a bound in this
+where
+    F: FnOnce(&mut UIView<T>) -> Result<(), io::Error> + Send + 'static,
+    //~^ NOTE required by this bound in `real_dispatch`
+    //~| NOTE required by a bound in `real_dispatch`
+{
+    todo!()
+}
+
+#[derive(Debug)]
+struct UIView<'a, T: 'a> {
+    _phantom: std::marker::PhantomData<&'a mut T>,
+}
+
+trait Handle<'a, T: 'a, V, R> {
+    fn dispatch<F>(&self, f: F) -> Result<(), io::Error>
+    where
+        F: FnOnce(&mut V) -> R + Send + 'static;
+}
+
+#[derive(Debug, Clone)]
+struct TUIHandle<T> {
+    _phantom: std::marker::PhantomData<T>,
+}
+
+impl<'a, T: 'a> Handle<'a, T, UIView<'a, T>, Result<(), io::Error>> for TUIHandle<T> {
+    fn dispatch<F>(&self, f: F) -> Result<(), io::Error>
+    where
+        F: FnOnce(&mut UIView<'a, T>) -> Result<(), io::Error> + Send + 'static,
+    {
+        real_dispatch(f)
+        //~^ ERROR expected a `FnOnce<(&mut UIView<'_, T>,)>` closure, found `F`
+        //~| NOTE expected an `FnOnce<(&mut UIView<'_, T>,)>` closure, found `F`
+        //~| NOTE expected a closure with arguments
+        //~| NOTE required by a bound introduced by this call
+    }
+}
+
+fn main() {}
diff --git a/src/test/ui/implied-bounds/issue-100690.stderr b/src/test/ui/implied-bounds/issue-100690.stderr
new file mode 100644
index 00000000000..3f6af70d8ed
--- /dev/null
+++ b/src/test/ui/implied-bounds/issue-100690.stderr
@@ -0,0 +1,22 @@
+error[E0277]: expected a `FnOnce<(&mut UIView<'_, T>,)>` closure, found `F`
+  --> $DIR/issue-100690.rs:37:23
+   |
+LL |         real_dispatch(f)
+   |         ------------- ^ expected an `FnOnce<(&mut UIView<'_, T>,)>` closure, found `F`
+   |         |
+   |         required by a bound introduced by this call
+   |
+   = note: expected a closure with arguments `(&mut UIView<'a, T>,)`
+              found a closure with arguments `(&mut UIView<'_, T>,)`
+note: required by a bound in `real_dispatch`
+  --> $DIR/issue-100690.rs:9:8
+   |
+LL | fn real_dispatch<T, F>(f: F) -> Result<(), io::Error>
+   |    ------------- required by a bound in this
+...
+LL |     F: FnOnce(&mut UIView<T>) -> Result<(), io::Error> + Send + 'static,
+   |        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `real_dispatch`
+
+error: aborting due to previous error
+
+For more information about this error, try `rustc --explain E0277`.