about summary refs log tree commit diff
path: root/compiler/rustc_hir_analysis/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_hir_analysis/src')
-rw-r--r--compiler/rustc_hir_analysis/src/check/entry.rs54
-rw-r--r--compiler/rustc_hir_analysis/src/check/intrinsic.rs13
-rw-r--r--compiler/rustc_hir_analysis/src/check/mod.rs91
-rw-r--r--compiler/rustc_hir_analysis/src/lib.rs26
4 files changed, 120 insertions, 64 deletions
diff --git a/compiler/rustc_hir_analysis/src/check/entry.rs b/compiler/rustc_hir_analysis/src/check/entry.rs
index fcaefe0261b..27af10af8e6 100644
--- a/compiler/rustc_hir_analysis/src/check/entry.rs
+++ b/compiler/rustc_hir_analysis/src/check/entry.rs
@@ -11,8 +11,8 @@ use rustc_trait_selection::traits::{self, ObligationCause, ObligationCauseCode};
 
 use std::ops::Not;
 
+use super::check_function_signature;
 use crate::errors;
-use crate::require_same_types;
 
 pub(crate) fn check_for_entry_fn(tcx: TyCtxt<'_>) {
     match tcx.entry_fn(()) {
@@ -162,33 +162,33 @@ fn check_main_fn_ty(tcx: TyCtxt<'_>, main_def_id: DefId) {
             error = true;
         }
         // now we can take the return type of the given main function
-        expected_return_type = main_fnsig.output();
+        expected_return_type = norm_return_ty;
     } else {
         // standard () main return type
-        expected_return_type = ty::Binder::dummy(Ty::new_unit(tcx));
+        expected_return_type = tcx.types.unit;
     }
 
     if error {
         return;
     }
 
-    let se_ty = Ty::new_fn_ptr(
-        tcx,
-        expected_return_type.map_bound(|expected_return_type| {
-            tcx.mk_fn_sig([], expected_return_type, false, hir::Unsafety::Normal, Abi::Rust)
-        }),
-    );
+    let expected_sig = ty::Binder::dummy(tcx.mk_fn_sig(
+        [],
+        expected_return_type,
+        false,
+        hir::Unsafety::Normal,
+        Abi::Rust,
+    ));
 
-    require_same_types(
+    check_function_signature(
         tcx,
-        &ObligationCause::new(
+        ObligationCause::new(
             main_span,
             main_diagnostics_def_id,
             ObligationCauseCode::MainFunctionType,
         ),
-        param_env,
-        se_ty,
-        Ty::new_fn_ptr(tcx, main_fnsig),
+        main_def_id,
+        expected_sig,
     );
 }
 
@@ -247,27 +247,23 @@ fn check_start_fn_ty(tcx: TyCtxt<'_>, start_def_id: DefId) {
                 }
             }
 
-            let se_ty = Ty::new_fn_ptr(
-                tcx,
-                ty::Binder::dummy(tcx.mk_fn_sig(
-                    [tcx.types.isize, Ty::new_imm_ptr(tcx, Ty::new_imm_ptr(tcx, tcx.types.u8))],
-                    tcx.types.isize,
-                    false,
-                    hir::Unsafety::Normal,
-                    Abi::Rust,
-                )),
-            );
+            let expected_sig = ty::Binder::dummy(tcx.mk_fn_sig(
+                [tcx.types.isize, Ty::new_imm_ptr(tcx, Ty::new_imm_ptr(tcx, tcx.types.u8))],
+                tcx.types.isize,
+                false,
+                hir::Unsafety::Normal,
+                Abi::Rust,
+            ));
 
-            require_same_types(
+            check_function_signature(
                 tcx,
-                &ObligationCause::new(
+                ObligationCause::new(
                     start_span,
                     start_def_id,
                     ObligationCauseCode::StartFunctionType,
                 ),
-                ty::ParamEnv::empty(), // start should not have any where bounds.
-                se_ty,
-                Ty::new_fn_ptr(tcx, tcx.fn_sig(start_def_id).instantiate_identity()),
+                start_def_id.into(),
+                expected_sig,
             );
         }
         _ => {
diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs
index f89e2e5c25b..444103ffe8f 100644
--- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs
+++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs
@@ -1,11 +1,11 @@
 //! Type-checking for the rust-intrinsic and platform-intrinsic
 //! intrinsics that the compiler exposes.
 
+use crate::check::check_function_signature;
 use crate::errors::{
     UnrecognizedAtomicOperation, UnrecognizedIntrinsicFunction,
     WrongNumberOfGenericArgumentsToIntrinsic,
 };
-use crate::require_same_types;
 
 use hir::def_id::DefId;
 use rustc_errors::{struct_span_err, DiagnosticMessage};
@@ -53,15 +53,12 @@ fn equate_intrinsic_type<'tcx>(
         && gen_count_ok(own_counts.types, n_tps, "type")
         && gen_count_ok(own_counts.consts, 0, "const")
     {
-        let fty = Ty::new_fn_ptr(tcx, sig);
         let it_def_id = it.owner_id.def_id;
-        let cause = ObligationCause::new(it.span, it_def_id, ObligationCauseCode::IntrinsicType);
-        require_same_types(
+        check_function_signature(
             tcx,
-            &cause,
-            ty::ParamEnv::empty(), // FIXME: do all intrinsics have an empty param env?
-            Ty::new_fn_ptr(tcx, tcx.fn_sig(it.owner_id).instantiate_identity()),
-            fty,
+            ObligationCause::new(it.span, it_def_id, ObligationCauseCode::IntrinsicType),
+            it_def_id.into(),
+            sig,
         );
     }
 }
diff --git a/compiler/rustc_hir_analysis/src/check/mod.rs b/compiler/rustc_hir_analysis/src/check/mod.rs
index d8331661366..b9ce4c07922 100644
--- a/compiler/rustc_hir_analysis/src/check/mod.rs
+++ b/compiler/rustc_hir_analysis/src/check/mod.rs
@@ -73,23 +73,31 @@ pub mod wfcheck;
 
 pub use check::check_abi;
 
+use std::num::NonZeroU32;
+
 use check::check_mod_item_types;
 use rustc_data_structures::fx::{FxHashMap, FxHashSet};
 use rustc_errors::{pluralize, struct_span_err, Diagnostic, DiagnosticBuilder};
 use rustc_hir::def_id::{DefId, LocalDefId};
 use rustc_hir::intravisit::Visitor;
 use rustc_index::bit_set::BitSet;
+use rustc_infer::infer::error_reporting::ObligationCauseExt as _;
+use rustc_infer::infer::outlives::env::OutlivesEnvironment;
+use rustc_infer::infer::{self, TyCtxtInferExt as _};
+use rustc_infer::traits::ObligationCause;
 use rustc_middle::query::Providers;
+use rustc_middle::ty::error::{ExpectedFound, TypeError};
 use rustc_middle::ty::{self, Ty, TyCtxt};
 use rustc_middle::ty::{GenericArgs, GenericArgsRef};
 use rustc_session::parse::feature_err;
 use rustc_span::source_map::DUMMY_SP;
 use rustc_span::symbol::{kw, Ident};
-use rustc_span::{self, BytePos, Span, Symbol};
+use rustc_span::{self, def_id::CRATE_DEF_ID, BytePos, Span, Symbol};
 use rustc_target::abi::VariantIdx;
 use rustc_target::spec::abi::Abi;
 use rustc_trait_selection::traits::error_reporting::suggestions::ReturnsVisitor;
-use std::num::NonZeroU32;
+use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt as _;
+use rustc_trait_selection::traits::ObligationCtxt;
 
 use crate::errors;
 use crate::require_c_abi_if_c_variadic;
@@ -546,3 +554,82 @@ fn bad_non_zero_sized_fields<'tcx>(
 pub fn potentially_plural_count(count: usize, word: &str) -> String {
     format!("{} {}{}", count, word, pluralize!(count))
 }
+
+pub fn check_function_signature<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    mut cause: ObligationCause<'tcx>,
+    fn_id: DefId,
+    expected_sig: ty::PolyFnSig<'tcx>,
+) {
+    let local_id = fn_id.as_local().unwrap_or(CRATE_DEF_ID);
+
+    let param_env = ty::ParamEnv::empty();
+
+    let infcx = &tcx.infer_ctxt().build();
+    let ocx = ObligationCtxt::new(infcx);
+
+    let unnormalized_actual_sig = infcx.instantiate_binder_with_fresh_vars(
+        cause.span,
+        infer::HigherRankedType,
+        tcx.fn_sig(fn_id).instantiate_identity(),
+    );
+
+    let norm_cause = ObligationCause::misc(cause.span, local_id);
+    let actual_sig = ocx.normalize(&norm_cause, param_env, unnormalized_actual_sig);
+
+    let expected_sig = tcx.liberate_late_bound_regions(fn_id, expected_sig);
+
+    match ocx.eq(&cause, param_env, expected_sig, actual_sig) {
+        Ok(()) => {
+            let errors = ocx.select_all_or_error();
+            if !errors.is_empty() {
+                infcx.err_ctxt().report_fulfillment_errors(&errors);
+                return;
+            }
+        }
+        Err(err) => {
+            let err_ctxt = infcx.err_ctxt();
+            if fn_id.is_local() {
+                cause.span = extract_span_for_error_reporting(tcx, err, &cause, local_id);
+            }
+            let failure_code = cause.as_failure_code_diag(err, cause.span, vec![]);
+            let mut diag = tcx.sess.create_err(failure_code);
+            err_ctxt.note_type_err(
+                &mut diag,
+                &cause,
+                None,
+                Some(infer::ValuePairs::Sigs(ExpectedFound {
+                    expected: expected_sig,
+                    found: actual_sig,
+                })),
+                err,
+                false,
+                false,
+            );
+            diag.emit();
+            return;
+        }
+    }
+
+    let outlives_env = OutlivesEnvironment::new(param_env);
+    let _ = ocx.resolve_regions_and_report_errors(local_id, &outlives_env);
+
+    fn extract_span_for_error_reporting<'tcx>(
+        tcx: TyCtxt<'tcx>,
+        err: TypeError<'_>,
+        cause: &ObligationCause<'tcx>,
+        fn_id: LocalDefId,
+    ) -> rustc_span::Span {
+        let mut args = {
+            let node = tcx.hir().expect_owner(fn_id);
+            let decl = node.fn_decl().unwrap_or_else(|| bug!("expected fn decl, found {:?}", node));
+            decl.inputs.iter().map(|t| t.span).chain(std::iter::once(decl.output.span()))
+        };
+
+        match err {
+            TypeError::ArgumentMutability(i)
+            | TypeError::ArgumentSorts(ExpectedFound { .. }, i) => args.nth(i).unwrap(),
+            _ => cause.span(),
+        }
+    }
+}
diff --git a/compiler/rustc_hir_analysis/src/lib.rs b/compiler/rustc_hir_analysis/src/lib.rs
index ef788935efb..03963925d3d 100644
--- a/compiler/rustc_hir_analysis/src/lib.rs
+++ b/compiler/rustc_hir_analysis/src/lib.rs
@@ -99,7 +99,6 @@ use rustc_errors::ErrorGuaranteed;
 use rustc_errors::{DiagnosticMessage, SubdiagnosticMessage};
 use rustc_fluent_macro::fluent_messages;
 use rustc_hir as hir;
-use rustc_infer::infer::TyCtxtInferExt;
 use rustc_middle::middle;
 use rustc_middle::query::Providers;
 use rustc_middle::ty::{self, Ty, TyCtxt};
@@ -107,8 +106,7 @@ use rustc_middle::util;
 use rustc_session::parse::feature_err;
 use rustc_span::{symbol::sym, Span, DUMMY_SP};
 use rustc_target::spec::abi::Abi;
-use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt as _;
-use rustc_trait_selection::traits::{self, ObligationCause, ObligationCtxt};
+use rustc_trait_selection::traits;
 
 use astconv::{AstConv, OnlySelfBounds};
 use bounds::Bounds;
@@ -151,28 +149,6 @@ fn require_c_abi_if_c_variadic(tcx: TyCtxt<'_>, decl: &hir::FnDecl<'_>, abi: Abi
     tcx.sess.emit_err(errors::VariadicFunctionCompatibleConvention { span, conventions });
 }
 
-fn require_same_types<'tcx>(
-    tcx: TyCtxt<'tcx>,
-    cause: &ObligationCause<'tcx>,
-    param_env: ty::ParamEnv<'tcx>,
-    expected: Ty<'tcx>,
-    actual: Ty<'tcx>,
-) {
-    let infcx = &tcx.infer_ctxt().build();
-    let ocx = ObligationCtxt::new(infcx);
-    match ocx.eq(cause, param_env, expected, actual) {
-        Ok(()) => {
-            let errors = ocx.select_all_or_error();
-            if !errors.is_empty() {
-                infcx.err_ctxt().report_fulfillment_errors(&errors);
-            }
-        }
-        Err(err) => {
-            infcx.err_ctxt().report_mismatched_types(cause, expected, actual, err).emit();
-        }
-    }
-}
-
 pub fn provide(providers: &mut Providers) {
     collect::provide(providers);
     coherence::provide(providers);