about summary refs log tree commit diff
diff options
context:
space:
mode:
authorDylan DPC <99973273+Dylan-DPC@users.noreply.github.com>2022-08-30 16:56:08 +0530
committerGitHub <noreply@github.com>2022-08-30 16:56:08 +0530
commit15e2e5185a22207b18d2cbc47a48b39e63e84cd0 (patch)
tree93ddc8c688b1caa448a6ebab2ea58d6b8f7c0aaa
parent9cfd161cd5b2e3b53c488086f8000aea0c21b0b2 (diff)
parente5602cb2a0e114729625cf27db819ef56a79d86e (diff)
downloadrust-15e2e5185a22207b18d2cbc47a48b39e63e84cd0.tar.gz
rust-15e2e5185a22207b18d2cbc47a48b39e63e84cd0.zip
Rollup merge of #100473 - compiler-errors:normalize-the-fn-def-sig-plz, r=lcnr
Attempt to normalize `FnDef` signature in `InferCtxt::cmp`

Stashes a normalization callback in `InferCtxt` so that the signature we get from `tcx.fn_sig(..).subst(..)` in `InferCtxt::cmp` can be properly normalized, since we cannot expect for it to have normalized types since it comes straight from astconv.

This is kind of a hack, but I will say that `@jyn514` found the fact that we present unnormalized types to be very confusing in real life code, and I agree with that feeling. Though altogether I am still a bit unsure about whether this PR is worth the effort, so I'm open to alternatives and/or just closing it outright.

On the other hand, this isn't a ridiculously heavy implementation anyways -- it's less than a hundred lines of changes, and half of that is just miscellaneous cleanup.

This is stacked onto #100471 which is basically unrelated, and it can be rebased off of that when that lands or if needed.

---

The code:
```rust
trait Foo { type Bar; }

impl<T> Foo for T {
    type Bar = i32;
}

fn foo<T>(_: <T as Foo>::Bar) {}

fn needs_i32_ref_fn(f: fn(&'static i32)) {}

fn main() {
    needs_i32_ref_fn(foo::<()>);
}
```

Before:
```
   = note: expected fn pointer `fn(&'static i32)`
                 found fn item `fn(<() as Foo>::Bar) {foo::<()>}`
```

After:
```
   = note: expected fn pointer `fn(&'static i32)`
                 found fn item `fn(i32) {foo::<()>}`
```
-rw-r--r--compiler/rustc_infer/src/infer/at.rs4
-rw-r--r--compiler/rustc_infer/src/infer/error_reporting/mod.rs11
-rw-r--r--compiler/rustc_infer/src/infer/mod.rs18
-rw-r--r--compiler/rustc_trait_selection/src/traits/engine.rs13
-rw-r--r--compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs29
-rw-r--r--compiler/rustc_typeck/src/check/inherited.rs29
-rw-r--r--src/test/ui/mismatched_types/normalize-fn-sig.rs16
-rw-r--r--src/test/ui/mismatched_types/normalize-fn-sig.stderr19
8 files changed, 125 insertions, 14 deletions
diff --git a/compiler/rustc_infer/src/infer/at.rs b/compiler/rustc_infer/src/infer/at.rs
index e37c0cf0fd0..00e23864871 100644
--- a/compiler/rustc_infer/src/infer/at.rs
+++ b/compiler/rustc_infer/src/infer/at.rs
@@ -78,6 +78,10 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
             err_count_on_creation: self.err_count_on_creation,
             in_snapshot: self.in_snapshot.clone(),
             universe: self.universe.clone(),
+            normalize_fn_sig_for_diagnostic: self
+                .normalize_fn_sig_for_diagnostic
+                .as_ref()
+                .map(|f| f.clone()),
         }
     }
 }
diff --git a/compiler/rustc_infer/src/infer/error_reporting/mod.rs b/compiler/rustc_infer/src/infer/error_reporting/mod.rs
index ecf75411e5f..7dc4934db09 100644
--- a/compiler/rustc_infer/src/infer/error_reporting/mod.rs
+++ b/compiler/rustc_infer/src/infer/error_reporting/mod.rs
@@ -961,12 +961,23 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
         }
     }
 
+    fn normalize_fn_sig_for_diagnostic(&self, sig: ty::PolyFnSig<'tcx>) -> ty::PolyFnSig<'tcx> {
+        if let Some(normalize) = &self.normalize_fn_sig_for_diagnostic {
+            normalize(self, sig)
+        } else {
+            sig
+        }
+    }
+
     /// Given two `fn` signatures highlight only sub-parts that are different.
     fn cmp_fn_sig(
         &self,
         sig1: &ty::PolyFnSig<'tcx>,
         sig2: &ty::PolyFnSig<'tcx>,
     ) -> (DiagnosticStyledString, DiagnosticStyledString) {
+        let sig1 = &self.normalize_fn_sig_for_diagnostic(*sig1);
+        let sig2 = &self.normalize_fn_sig_for_diagnostic(*sig2);
+
         let get_lifetimes = |sig| {
             use rustc_hir::def::Namespace;
             let (_, sig, reg) = ty::print::FmtPrinter::new(self.tcx, Namespace::TypeNS)
diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs
index 4689ebb6cee..60ebf8b949d 100644
--- a/compiler/rustc_infer/src/infer/mod.rs
+++ b/compiler/rustc_infer/src/infer/mod.rs
@@ -337,6 +337,9 @@ pub struct InferCtxt<'a, 'tcx> {
     /// when we enter into a higher-ranked (`for<..>`) type or trait
     /// bound.
     universe: Cell<ty::UniverseIndex>,
+
+    normalize_fn_sig_for_diagnostic:
+        Option<Lrc<dyn Fn(&InferCtxt<'_, 'tcx>, ty::PolyFnSig<'tcx>) -> ty::PolyFnSig<'tcx>>>,
 }
 
 /// See the `error_reporting` module for more details.
@@ -540,6 +543,8 @@ pub struct InferCtxtBuilder<'tcx> {
     defining_use_anchor: DefiningAnchor,
     considering_regions: bool,
     fresh_typeck_results: Option<RefCell<ty::TypeckResults<'tcx>>>,
+    normalize_fn_sig_for_diagnostic:
+        Option<Lrc<dyn Fn(&InferCtxt<'_, 'tcx>, ty::PolyFnSig<'tcx>) -> ty::PolyFnSig<'tcx>>>,
 }
 
 pub trait TyCtxtInferExt<'tcx> {
@@ -553,6 +558,7 @@ impl<'tcx> TyCtxtInferExt<'tcx> for TyCtxt<'tcx> {
             defining_use_anchor: DefiningAnchor::Error,
             considering_regions: true,
             fresh_typeck_results: None,
+            normalize_fn_sig_for_diagnostic: None,
         }
     }
 }
@@ -582,6 +588,14 @@ impl<'tcx> InferCtxtBuilder<'tcx> {
         self
     }
 
+    pub fn with_normalize_fn_sig_for_diagnostic(
+        mut self,
+        fun: Lrc<dyn Fn(&InferCtxt<'_, 'tcx>, ty::PolyFnSig<'tcx>) -> ty::PolyFnSig<'tcx>>,
+    ) -> Self {
+        self.normalize_fn_sig_for_diagnostic = Some(fun);
+        self
+    }
+
     /// Given a canonical value `C` as a starting point, create an
     /// inference context that contains each of the bound values
     /// within instantiated as a fresh variable. The `f` closure is
@@ -611,6 +625,7 @@ impl<'tcx> InferCtxtBuilder<'tcx> {
             defining_use_anchor,
             considering_regions,
             ref fresh_typeck_results,
+            ref normalize_fn_sig_for_diagnostic,
         } = *self;
         let in_progress_typeck_results = fresh_typeck_results.as_ref();
         f(InferCtxt {
@@ -629,6 +644,9 @@ impl<'tcx> InferCtxtBuilder<'tcx> {
             in_snapshot: Cell::new(false),
             skip_leak_check: Cell::new(false),
             universe: Cell::new(ty::UniverseIndex::ROOT),
+            normalize_fn_sig_for_diagnostic: normalize_fn_sig_for_diagnostic
+                .as_ref()
+                .map(|f| f.clone()),
         })
     }
 }
diff --git a/compiler/rustc_trait_selection/src/traits/engine.rs b/compiler/rustc_trait_selection/src/traits/engine.rs
index 72533a42d80..dba4d4f69da 100644
--- a/compiler/rustc_trait_selection/src/traits/engine.rs
+++ b/compiler/rustc_trait_selection/src/traits/engine.rs
@@ -17,6 +17,7 @@ use rustc_span::Span;
 
 pub trait TraitEngineExt<'tcx> {
     fn new(tcx: TyCtxt<'tcx>) -> Box<Self>;
+    fn new_in_snapshot(tcx: TyCtxt<'tcx>) -> Box<Self>;
 }
 
 impl<'tcx> TraitEngineExt<'tcx> for dyn TraitEngine<'tcx> {
@@ -27,6 +28,14 @@ impl<'tcx> TraitEngineExt<'tcx> for dyn TraitEngine<'tcx> {
             Box::new(FulfillmentContext::new())
         }
     }
+
+    fn new_in_snapshot(tcx: TyCtxt<'tcx>) -> Box<Self> {
+        if tcx.sess.opts.unstable_opts.chalk {
+            Box::new(ChalkFulfillmentContext::new())
+        } else {
+            Box::new(FulfillmentContext::new_in_snapshot())
+        }
+    }
 }
 
 /// Used if you want to have pleasant experience when dealing
@@ -41,6 +50,10 @@ impl<'a, 'tcx> ObligationCtxt<'a, 'tcx> {
         Self { infcx, engine: RefCell::new(<dyn TraitEngine<'_>>::new(infcx.tcx)) }
     }
 
+    pub fn new_in_snapshot(infcx: &'a InferCtxt<'a, 'tcx>) -> Self {
+        Self { infcx, engine: RefCell::new(<dyn TraitEngine<'_>>::new_in_snapshot(infcx.tcx)) }
+    }
+
     pub fn register_obligation(&self, obligation: PredicateObligation<'tcx>) {
         self.engine.borrow_mut().register_predicate_obligation(self.infcx, obligation);
     }
diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
index 54f01577c5e..02adae5bde1 100644
--- a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
+++ b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
@@ -20,7 +20,7 @@ use rustc_hir::def_id::DefId;
 use rustc_hir::intravisit::Visitor;
 use rustc_hir::lang_items::LangItem;
 use rustc_hir::{AsyncGeneratorKind, GeneratorKind, Node};
-use rustc_infer::infer::TyCtxtInferExt;
+use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
 use rustc_middle::hir::map;
 use rustc_middle::ty::{
     self, suggest_arbitrary_trait_bound, suggest_constraining_type_param, AdtKind, DefIdTree,
@@ -1589,32 +1589,38 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
         expected: ty::PolyTraitRef<'tcx>,
     ) -> DiagnosticBuilder<'tcx, ErrorGuaranteed> {
         pub(crate) fn build_fn_sig_ty<'tcx>(
-            tcx: TyCtxt<'tcx>,
+            infcx: &InferCtxt<'_, 'tcx>,
             trait_ref: ty::PolyTraitRef<'tcx>,
         ) -> Ty<'tcx> {
             let inputs = trait_ref.skip_binder().substs.type_at(1);
             let sig = match inputs.kind() {
                 ty::Tuple(inputs)
-                    if tcx.fn_trait_kind_from_lang_item(trait_ref.def_id()).is_some() =>
+                    if infcx.tcx.fn_trait_kind_from_lang_item(trait_ref.def_id()).is_some() =>
                 {
-                    tcx.mk_fn_sig(
+                    infcx.tcx.mk_fn_sig(
                         inputs.iter(),
-                        tcx.mk_ty_infer(ty::TyVar(ty::TyVid::from_u32(0))),
+                        infcx.next_ty_var(TypeVariableOrigin {
+                            span: DUMMY_SP,
+                            kind: TypeVariableOriginKind::MiscVariable,
+                        }),
                         false,
                         hir::Unsafety::Normal,
                         abi::Abi::Rust,
                     )
                 }
-                _ => tcx.mk_fn_sig(
+                _ => infcx.tcx.mk_fn_sig(
                     std::iter::once(inputs),
-                    tcx.mk_ty_infer(ty::TyVar(ty::TyVid::from_u32(0))),
+                    infcx.next_ty_var(TypeVariableOrigin {
+                        span: DUMMY_SP,
+                        kind: TypeVariableOriginKind::MiscVariable,
+                    }),
                     false,
                     hir::Unsafety::Normal,
                     abi::Abi::Rust,
                 ),
             };
 
-            tcx.mk_fn_ptr(trait_ref.rebind(sig))
+            infcx.tcx.mk_fn_ptr(trait_ref.rebind(sig))
         }
 
         let argument_kind = match expected.skip_binder().self_ty().kind() {
@@ -1634,11 +1640,10 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
         let found_span = found_span.unwrap_or(span);
         err.span_label(found_span, "found signature defined here");
 
-        let expected = build_fn_sig_ty(self.tcx, expected);
-        let found = build_fn_sig_ty(self.tcx, found);
+        let expected = build_fn_sig_ty(self, expected);
+        let found = build_fn_sig_ty(self, found);
 
-        let (expected_str, found_str) =
-            self.tcx.infer_ctxt().enter(|infcx| infcx.cmp(expected, found));
+        let (expected_str, found_str) = self.cmp(expected, found);
 
         let signature_kind = format!("{argument_kind} signature");
         err.note_expected_found(&signature_kind, expected_str, &signature_kind, found_str);
diff --git a/compiler/rustc_typeck/src/check/inherited.rs b/compiler/rustc_typeck/src/check/inherited.rs
index f3115fc5c02..1439baf5440 100644
--- a/compiler/rustc_typeck/src/check/inherited.rs
+++ b/compiler/rustc_typeck/src/check/inherited.rs
@@ -1,6 +1,7 @@
 use super::callee::DeferredCallResolution;
 
 use rustc_data_structures::fx::FxHashSet;
+use rustc_data_structures::sync::Lrc;
 use rustc_hir as hir;
 use rustc_hir::def_id::LocalDefId;
 use rustc_hir::HirIdMap;
@@ -12,7 +13,9 @@ use rustc_middle::ty::{self, Ty, TyCtxt};
 use rustc_span::def_id::LocalDefIdMap;
 use rustc_span::{self, Span};
 use rustc_trait_selection::infer::InferCtxtExt as _;
-use rustc_trait_selection::traits::{self, ObligationCause, TraitEngine, TraitEngineExt};
+use rustc_trait_selection::traits::{
+    self, ObligationCause, ObligationCtxt, TraitEngine, TraitEngineExt as _,
+};
 
 use std::cell::RefCell;
 use std::ops::Deref;
@@ -84,7 +87,29 @@ impl<'tcx> Inherited<'_, 'tcx> {
             infcx: tcx
                 .infer_ctxt()
                 .ignoring_regions()
-                .with_fresh_in_progress_typeck_results(hir_owner),
+                .with_fresh_in_progress_typeck_results(hir_owner)
+                .with_normalize_fn_sig_for_diagnostic(Lrc::new(move |infcx, fn_sig| {
+                    if fn_sig.has_escaping_bound_vars() {
+                        return fn_sig;
+                    }
+                    infcx.probe(|_| {
+                        let ocx = ObligationCtxt::new_in_snapshot(infcx);
+                        let normalized_fn_sig = ocx.normalize(
+                            ObligationCause::dummy(),
+                            // FIXME(compiler-errors): This is probably not the right param-env...
+                            infcx.tcx.param_env(def_id),
+                            fn_sig,
+                        );
+                        if ocx.select_all_or_error().is_empty() {
+                            let normalized_fn_sig =
+                                infcx.resolve_vars_if_possible(normalized_fn_sig);
+                            if !normalized_fn_sig.needs_infer() {
+                                return normalized_fn_sig;
+                            }
+                        }
+                        fn_sig
+                    })
+                })),
             def_id,
         }
     }
diff --git a/src/test/ui/mismatched_types/normalize-fn-sig.rs b/src/test/ui/mismatched_types/normalize-fn-sig.rs
new file mode 100644
index 00000000000..1a2093c44f0
--- /dev/null
+++ b/src/test/ui/mismatched_types/normalize-fn-sig.rs
@@ -0,0 +1,16 @@
+trait Foo {
+    type Bar;
+}
+
+impl<T> Foo for T {
+    type Bar = i32;
+}
+
+fn foo<T>(_: <T as Foo>::Bar, _: &'static <T as Foo>::Bar) {}
+
+fn needs_i32_ref_fn(_: fn(&'static i32, i32)) {}
+
+fn main() {
+    needs_i32_ref_fn(foo::<()>);
+    //~^ ERROR mismatched types
+}
diff --git a/src/test/ui/mismatched_types/normalize-fn-sig.stderr b/src/test/ui/mismatched_types/normalize-fn-sig.stderr
new file mode 100644
index 00000000000..6c55f29c5d1
--- /dev/null
+++ b/src/test/ui/mismatched_types/normalize-fn-sig.stderr
@@ -0,0 +1,19 @@
+error[E0308]: mismatched types
+  --> $DIR/normalize-fn-sig.rs:14:22
+   |
+LL |     needs_i32_ref_fn(foo::<()>);
+   |     ---------------- ^^^^^^^^^ expected `&i32`, found `i32`
+   |     |
+   |     arguments to this function are incorrect
+   |
+   = note: expected fn pointer `fn(&'static i32, i32)`
+                 found fn item `fn(i32, &'static i32) {foo::<()>}`
+note: function defined here
+  --> $DIR/normalize-fn-sig.rs:11:4
+   |
+LL | fn needs_i32_ref_fn(_: fn(&'static i32, i32)) {}
+   |    ^^^^^^^^^^^^^^^^ ------------------------
+
+error: aborting due to previous error
+
+For more information about this error, try `rustc --explain E0308`.