about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMatthias Krüger <matthias.krueger@famsik.de>2024-01-20 09:37:26 +0100
committerGitHub <noreply@github.com>2024-01-20 09:37:26 +0100
commit2de5ca25d2aa658553e75eedcdb6968a0d53d969 (patch)
tree08c2927194da5a18f629842beaba46cc38a18205
parent6f67208d725cce91f9eeb03e39bda5bfba68a303 (diff)
parent130b7e713e879c4c989186d94643be8c834de355 (diff)
downloadrust-2de5ca25d2aa658553e75eedcdb6968a0d53d969.tar.gz
rust-2de5ca25d2aa658553e75eedcdb6968a0d53d969.zip
Rollup merge of #119613 - gavinleroy:expose-obligations, r=lcnr
Expose Obligations created during type inference.

This PR is a first pass at exposing the trait obligations generated and solved for during the type-check progress. Exposing these obligations allows for rustc plugins to use the public interface for proof trees (provided by the next gen trait solver).

The changes proposed track *all* obligations during the type-check process, this is desirable to not only look at the trees of failed obligations, but also those of successfully proved obligations. This feature is placed behind an unstable compiler option `track-trait-obligations` which should be used together with the `next-solver` option. I should note that the main interface is the function `inspect_typeck` made public in `rustc_hir_typeck/src/lib.rs` which allows the caller to provide a callback granting access to the `FnCtxt`.

r? `@lcnr`
-rw-r--r--compiler/rustc_hir_typeck/src/lib.rs24
-rw-r--r--compiler/rustc_infer/src/infer/at.rs1
-rw-r--r--compiler/rustc_infer/src/infer/mod.rs16
-rw-r--r--compiler/rustc_infer/src/traits/mod.rs8
-rw-r--r--compiler/rustc_trait_selection/src/solve/fulfill.rs136
5 files changed, 121 insertions, 64 deletions
diff --git a/compiler/rustc_hir_typeck/src/lib.rs b/compiler/rustc_hir_typeck/src/lib.rs
index 67c35d717a1..80467ca9381 100644
--- a/compiler/rustc_hir_typeck/src/lib.rs
+++ b/compiler/rustc_hir_typeck/src/lib.rs
@@ -60,6 +60,7 @@ use rustc_hir::{HirIdMap, Node};
 use rustc_hir_analysis::astconv::AstConv;
 use rustc_hir_analysis::check::check_abi;
 use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
+use rustc_infer::traits::ObligationInspector;
 use rustc_middle::query::Providers;
 use rustc_middle::traits;
 use rustc_middle::ty::{self, Ty, TyCtxt};
@@ -139,7 +140,7 @@ fn used_trait_imports(tcx: TyCtxt<'_>, def_id: LocalDefId) -> &UnordSet<LocalDef
 
 fn typeck<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> &ty::TypeckResults<'tcx> {
     let fallback = move || tcx.type_of(def_id.to_def_id()).instantiate_identity();
-    typeck_with_fallback(tcx, def_id, fallback)
+    typeck_with_fallback(tcx, def_id, fallback, None)
 }
 
 /// Used only to get `TypeckResults` for type inference during error recovery.
@@ -149,14 +150,28 @@ fn diagnostic_only_typeck<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> &ty::T
         let span = tcx.hir().span(tcx.local_def_id_to_hir_id(def_id));
         Ty::new_error_with_message(tcx, span, "diagnostic only typeck table used")
     };
-    typeck_with_fallback(tcx, def_id, fallback)
+    typeck_with_fallback(tcx, def_id, fallback, None)
 }
 
-#[instrument(level = "debug", skip(tcx, fallback), ret)]
+/// Same as `typeck` but `inspect` is invoked on evaluation of each root obligation.
+/// Inspecting obligations only works with the new trait solver.
+/// This function is *only to be used* by external tools, it should not be
+/// called from within rustc. Note, this is not a query, and thus is not cached.
+pub fn inspect_typeck<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    def_id: LocalDefId,
+    inspect: ObligationInspector<'tcx>,
+) -> &'tcx ty::TypeckResults<'tcx> {
+    let fallback = move || tcx.type_of(def_id.to_def_id()).instantiate_identity();
+    typeck_with_fallback(tcx, def_id, fallback, Some(inspect))
+}
+
+#[instrument(level = "debug", skip(tcx, fallback, inspector), ret)]
 fn typeck_with_fallback<'tcx>(
     tcx: TyCtxt<'tcx>,
     def_id: LocalDefId,
     fallback: impl Fn() -> Ty<'tcx> + 'tcx,
+    inspector: Option<ObligationInspector<'tcx>>,
 ) -> &'tcx ty::TypeckResults<'tcx> {
     // Closures' typeck results come from their outermost function,
     // as they are part of the same "inference environment".
@@ -178,6 +193,9 @@ fn typeck_with_fallback<'tcx>(
     let param_env = tcx.param_env(def_id);
 
     let inh = Inherited::new(tcx, def_id);
+    if let Some(inspector) = inspector {
+        inh.infcx.attach_obligation_inspector(inspector);
+    }
     let mut fcx = FnCtxt::new(&inh, param_env, def_id);
 
     if let Some(hir::FnSig { header, decl, .. }) = fn_sig {
diff --git a/compiler/rustc_infer/src/infer/at.rs b/compiler/rustc_infer/src/infer/at.rs
index e60e3ffeaa7..0f1af81d9f0 100644
--- a/compiler/rustc_infer/src/infer/at.rs
+++ b/compiler/rustc_infer/src/infer/at.rs
@@ -90,6 +90,7 @@ impl<'tcx> InferCtxt<'tcx> {
             universe: self.universe.clone(),
             intercrate,
             next_trait_solver: self.next_trait_solver,
+            obligation_inspector: self.obligation_inspector.clone(),
         }
     }
 }
diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs
index 002aad19c49..1eab8575fc0 100644
--- a/compiler/rustc_infer/src/infer/mod.rs
+++ b/compiler/rustc_infer/src/infer/mod.rs
@@ -13,7 +13,9 @@ use rustc_middle::infer::unify_key::{ConstVidKey, EffectVidKey};
 use self::opaque_types::OpaqueTypeStorage;
 pub(crate) use self::undo_log::{InferCtxtUndoLogs, Snapshot, UndoLog};
 
-use crate::traits::{self, ObligationCause, PredicateObligations, TraitEngine, TraitEngineExt};
+use crate::traits::{
+    self, ObligationCause, ObligationInspector, PredicateObligations, TraitEngine, TraitEngineExt,
+};
 
 use rustc_data_structures::fx::FxIndexMap;
 use rustc_data_structures::fx::{FxHashMap, FxHashSet};
@@ -334,6 +336,8 @@ pub struct InferCtxt<'tcx> {
     pub intercrate: bool,
 
     next_trait_solver: bool,
+
+    pub obligation_inspector: Cell<Option<ObligationInspector<'tcx>>>,
 }
 
 impl<'tcx> ty::InferCtxtLike for InferCtxt<'tcx> {
@@ -708,6 +712,7 @@ impl<'tcx> InferCtxtBuilder<'tcx> {
             universe: Cell::new(ty::UniverseIndex::ROOT),
             intercrate,
             next_trait_solver,
+            obligation_inspector: Cell::new(None),
         }
     }
 }
@@ -1718,6 +1723,15 @@ impl<'tcx> InferCtxt<'tcx> {
             }
         }
     }
+
+    /// Attach a callback to be invoked on each root obligation evaluated in the new trait solver.
+    pub fn attach_obligation_inspector(&self, inspector: ObligationInspector<'tcx>) {
+        debug_assert!(
+            self.obligation_inspector.get().is_none(),
+            "shouldn't override a set obligation inspector"
+        );
+        self.obligation_inspector.set(Some(inspector));
+    }
 }
 
 impl<'tcx> TypeErrCtxt<'_, 'tcx> {
diff --git a/compiler/rustc_infer/src/traits/mod.rs b/compiler/rustc_infer/src/traits/mod.rs
index fdae093aac8..72ec07375ac 100644
--- a/compiler/rustc_infer/src/traits/mod.rs
+++ b/compiler/rustc_infer/src/traits/mod.rs
@@ -13,12 +13,15 @@ use std::hash::{Hash, Hasher};
 
 use hir::def_id::LocalDefId;
 use rustc_hir as hir;
+use rustc_middle::traits::query::NoSolution;
+use rustc_middle::traits::solve::Certainty;
 use rustc_middle::ty::error::{ExpectedFound, TypeError};
 use rustc_middle::ty::{self, Const, ToPredicate, Ty, TyCtxt};
 use rustc_span::Span;
 
 pub use self::ImplSource::*;
 pub use self::SelectionError::*;
+use crate::infer::InferCtxt;
 
 pub use self::engine::{TraitEngine, TraitEngineExt};
 pub use self::project::MismatchedProjectionTypes;
@@ -116,6 +119,11 @@ pub type PredicateObligations<'tcx> = Vec<PredicateObligation<'tcx>>;
 
 pub type Selection<'tcx> = ImplSource<'tcx, PredicateObligation<'tcx>>;
 
+/// A callback that can be provided to `inspect_typeck`. Invoked on evaluation
+/// of root obligations.
+pub type ObligationInspector<'tcx> =
+    fn(&InferCtxt<'tcx>, &PredicateObligation<'tcx>, Result<Certainty, NoSolution>);
+
 pub struct FulfillmentError<'tcx> {
     pub obligation: PredicateObligation<'tcx>,
     pub code: FulfillmentErrorCode<'tcx>,
diff --git a/compiler/rustc_trait_selection/src/solve/fulfill.rs b/compiler/rustc_trait_selection/src/solve/fulfill.rs
index c847425ebf4..f08622816ec 100644
--- a/compiler/rustc_trait_selection/src/solve/fulfill.rs
+++ b/compiler/rustc_trait_selection/src/solve/fulfill.rs
@@ -11,7 +11,7 @@ use rustc_middle::ty;
 use rustc_middle::ty::error::{ExpectedFound, TypeError};
 
 use super::eval_ctxt::GenerateProofTree;
-use super::{Certainty, InferCtxtEvalExt};
+use super::{Certainty, Goal, InferCtxtEvalExt};
 
 /// A trait engine using the new trait solver.
 ///
@@ -43,6 +43,21 @@ impl<'tcx> FulfillmentCtxt<'tcx> {
         );
         FulfillmentCtxt { obligations: Vec::new(), usable_in_snapshot: infcx.num_open_snapshots() }
     }
+
+    fn inspect_evaluated_obligation(
+        &self,
+        infcx: &InferCtxt<'tcx>,
+        obligation: &PredicateObligation<'tcx>,
+        result: &Result<(bool, Certainty, Vec<Goal<'tcx, ty::Predicate<'tcx>>>), NoSolution>,
+    ) {
+        if let Some(inspector) = infcx.obligation_inspector.get() {
+            let result = match result {
+                Ok((_, c, _)) => Ok(*c),
+                Err(NoSolution) => Err(NoSolution),
+            };
+            (inspector)(infcx, &obligation, result);
+        }
+    }
 }
 
 impl<'tcx> TraitEngine<'tcx> for FulfillmentCtxt<'tcx> {
@@ -100,65 +115,66 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentCtxt<'tcx> {
             let mut has_changed = false;
             for obligation in mem::take(&mut self.obligations) {
                 let goal = obligation.clone().into();
-                let (changed, certainty, nested_goals) =
-                    match infcx.evaluate_root_goal(goal, GenerateProofTree::IfEnabled).0 {
-                        Ok(result) => result,
-                        Err(NoSolution) => {
-                            errors.push(FulfillmentError {
-                                obligation: obligation.clone(),
-                                code: match goal.predicate.kind().skip_binder() {
-                                    ty::PredicateKind::Clause(ty::ClauseKind::Projection(_)) => {
-                                        FulfillmentErrorCode::ProjectionError(
-                                            // FIXME: This could be a `Sorts` if the term is a type
-                                            MismatchedProjectionTypes { err: TypeError::Mismatch },
-                                        )
-                                    }
-                                    ty::PredicateKind::NormalizesTo(..) => {
-                                        FulfillmentErrorCode::ProjectionError(
-                                            MismatchedProjectionTypes { err: TypeError::Mismatch },
-                                        )
-                                    }
-                                    ty::PredicateKind::AliasRelate(_, _, _) => {
-                                        FulfillmentErrorCode::ProjectionError(
-                                            MismatchedProjectionTypes { err: TypeError::Mismatch },
-                                        )
-                                    }
-                                    ty::PredicateKind::Subtype(pred) => {
-                                        let (a, b) = infcx.instantiate_binder_with_placeholders(
-                                            goal.predicate.kind().rebind((pred.a, pred.b)),
-                                        );
-                                        let expected_found = ExpectedFound::new(true, a, b);
-                                        FulfillmentErrorCode::SubtypeError(
-                                            expected_found,
-                                            TypeError::Sorts(expected_found),
-                                        )
-                                    }
-                                    ty::PredicateKind::Coerce(pred) => {
-                                        let (a, b) = infcx.instantiate_binder_with_placeholders(
-                                            goal.predicate.kind().rebind((pred.a, pred.b)),
-                                        );
-                                        let expected_found = ExpectedFound::new(false, a, b);
-                                        FulfillmentErrorCode::SubtypeError(
-                                            expected_found,
-                                            TypeError::Sorts(expected_found),
-                                        )
-                                    }
-                                    ty::PredicateKind::Clause(_)
-                                    | ty::PredicateKind::ObjectSafe(_)
-                                    | ty::PredicateKind::Ambiguous => {
-                                        FulfillmentErrorCode::SelectionError(
-                                            SelectionError::Unimplemented,
-                                        )
-                                    }
-                                    ty::PredicateKind::ConstEquate(..) => {
-                                        bug!("unexpected goal: {goal:?}")
-                                    }
-                                },
-                                root_obligation: obligation,
-                            });
-                            continue;
-                        }
-                    };
+                let result = infcx.evaluate_root_goal(goal, GenerateProofTree::IfEnabled).0;
+                self.inspect_evaluated_obligation(infcx, &obligation, &result);
+                let (changed, certainty, nested_goals) = match result {
+                    Ok(result) => result,
+                    Err(NoSolution) => {
+                        errors.push(FulfillmentError {
+                            obligation: obligation.clone(),
+                            code: match goal.predicate.kind().skip_binder() {
+                                ty::PredicateKind::Clause(ty::ClauseKind::Projection(_)) => {
+                                    FulfillmentErrorCode::ProjectionError(
+                                        // FIXME: This could be a `Sorts` if the term is a type
+                                        MismatchedProjectionTypes { err: TypeError::Mismatch },
+                                    )
+                                }
+                                ty::PredicateKind::NormalizesTo(..) => {
+                                    FulfillmentErrorCode::ProjectionError(
+                                        MismatchedProjectionTypes { err: TypeError::Mismatch },
+                                    )
+                                }
+                                ty::PredicateKind::AliasRelate(_, _, _) => {
+                                    FulfillmentErrorCode::ProjectionError(
+                                        MismatchedProjectionTypes { err: TypeError::Mismatch },
+                                    )
+                                }
+                                ty::PredicateKind::Subtype(pred) => {
+                                    let (a, b) = infcx.instantiate_binder_with_placeholders(
+                                        goal.predicate.kind().rebind((pred.a, pred.b)),
+                                    );
+                                    let expected_found = ExpectedFound::new(true, a, b);
+                                    FulfillmentErrorCode::SubtypeError(
+                                        expected_found,
+                                        TypeError::Sorts(expected_found),
+                                    )
+                                }
+                                ty::PredicateKind::Coerce(pred) => {
+                                    let (a, b) = infcx.instantiate_binder_with_placeholders(
+                                        goal.predicate.kind().rebind((pred.a, pred.b)),
+                                    );
+                                    let expected_found = ExpectedFound::new(false, a, b);
+                                    FulfillmentErrorCode::SubtypeError(
+                                        expected_found,
+                                        TypeError::Sorts(expected_found),
+                                    )
+                                }
+                                ty::PredicateKind::Clause(_)
+                                | ty::PredicateKind::ObjectSafe(_)
+                                | ty::PredicateKind::Ambiguous => {
+                                    FulfillmentErrorCode::SelectionError(
+                                        SelectionError::Unimplemented,
+                                    )
+                                }
+                                ty::PredicateKind::ConstEquate(..) => {
+                                    bug!("unexpected goal: {goal:?}")
+                                }
+                            },
+                            root_obligation: obligation,
+                        });
+                        continue;
+                    }
+                };
                 // Push any nested goals that we get from unifying our canonical response
                 // with our obligation onto the fulfillment context.
                 self.obligations.extend(nested_goals.into_iter().map(|goal| {