about summary refs log tree commit diff
diff options
context:
space:
mode:
authorJason Newcomb <jsnewcomb@pm.me>2023-07-30 00:00:19 -0400
committerJason Newcomb <jsnewcomb@pm.me>2023-07-30 00:38:04 -0400
commit71cc39e1f2f4cd7e8cd6fa09c65ecab06ce9cc80 (patch)
tree5385b60cf738288cddeb7791ee62bf061c917c14
parent29730969b196781e92ca8c8b2d0a8bf9bb105e22 (diff)
downloadrust-71cc39e1f2f4cd7e8cd6fa09c65ecab06ce9cc80.tar.gz
rust-71cc39e1f2f4cd7e8cd6fa09c65ecab06ce9cc80.zip
Add debug assertions to `implements_trait`
Improve debug assertions for `make_projection`
-rw-r--r--clippy_lints/src/derive.rs4
-rw-r--r--clippy_lints/src/incorrect_impls.rs8
-rw-r--r--clippy_lints/src/loops/explicit_iter_loop.rs2
-rw-r--r--clippy_lints/src/methods/or_fun_call.rs2
-rw-r--r--clippy_lints/src/needless_pass_by_value.rs10
-rw-r--r--clippy_utils/src/ty.rs184
6 files changed, 117 insertions, 93 deletions
diff --git a/clippy_lints/src/derive.rs b/clippy_lints/src/derive.rs
index c343f248d06..ad7b3b27383 100644
--- a/clippy_lints/src/derive.rs
+++ b/clippy_lints/src/derive.rs
@@ -471,12 +471,12 @@ fn check_partial_eq_without_eq<'tcx>(cx: &LateContext<'tcx>, span: Span, trait_r
         if let Some(def_id) = trait_ref.trait_def_id();
         if cx.tcx.is_diagnostic_item(sym::PartialEq, def_id);
         let param_env = param_env_for_derived_eq(cx.tcx, adt.did(), eq_trait_def_id);
-        if !implements_trait_with_env(cx.tcx, param_env, ty, eq_trait_def_id, []);
+        if !implements_trait_with_env(cx.tcx, param_env, ty, eq_trait_def_id, &[]);
         // If all of our fields implement `Eq`, we can implement `Eq` too
         if adt
             .all_fields()
             .map(|f| f.ty(cx.tcx, args))
-            .all(|ty| implements_trait_with_env(cx.tcx, param_env, ty, eq_trait_def_id, []));
+            .all(|ty| implements_trait_with_env(cx.tcx, param_env, ty, eq_trait_def_id, &[]));
         then {
             span_lint_and_sugg(
                 cx,
diff --git a/clippy_lints/src/incorrect_impls.rs b/clippy_lints/src/incorrect_impls.rs
index bce010bb168..f4edb17ae80 100644
--- a/clippy_lints/src/incorrect_impls.rs
+++ b/clippy_lints/src/incorrect_impls.rs
@@ -189,12 +189,8 @@ impl LateLintPass<'_> for IncorrectImpls {
                 .diagnostic_items(trait_impl.def_id.krate)
                 .name_to_id
                 .get(&sym::Ord)
-            && implements_trait(
-                    cx,
-                    hir_ty_to_ty(cx.tcx, imp.self_ty),
-                    *ord_def_id,
-                    trait_impl.args,
-                )
+            && trait_impl.self_ty() == trait_impl.args.type_at(1)
+            && implements_trait(cx, hir_ty_to_ty(cx.tcx, imp.self_ty), *ord_def_id, &[])
         {
             // If the `cmp` call likely needs to be fully qualified in the suggestion
             // (like `std::cmp::Ord::cmp`). It's unfortunate we must put this here but we can't
diff --git a/clippy_lints/src/loops/explicit_iter_loop.rs b/clippy_lints/src/loops/explicit_iter_loop.rs
index a84a0a6eeb8..7b8c88235a9 100644
--- a/clippy_lints/src/loops/explicit_iter_loop.rs
+++ b/clippy_lints/src/loops/explicit_iter_loop.rs
@@ -109,7 +109,7 @@ fn is_ref_iterable<'tcx>(
         && let sig = cx.tcx.liberate_late_bound_regions(fn_id, cx.tcx.fn_sig(fn_id).skip_binder())
         && let &[req_self_ty, req_res_ty] = &**sig.inputs_and_output
         && let param_env = cx.tcx.param_env(fn_id)
-        && implements_trait_with_env(cx.tcx, param_env, req_self_ty, trait_id, [])
+        && implements_trait_with_env(cx.tcx, param_env, req_self_ty, trait_id, &[])
         && let Some(into_iter_ty) =
             make_normalized_projection_with_regions(cx.tcx, param_env, trait_id, sym!(IntoIter), [req_self_ty])
         && let req_res_ty = normalize_with_regions(cx.tcx, param_env, req_res_ty)
diff --git a/clippy_lints/src/methods/or_fun_call.rs b/clippy_lints/src/methods/or_fun_call.rs
index 23039f2f3f1..8b2f57160af 100644
--- a/clippy_lints/src/methods/or_fun_call.rs
+++ b/clippy_lints/src/methods/or_fun_call.rs
@@ -57,7 +57,7 @@ pub(super) fn check<'tcx>(
                 cx.tcx
                     .get_diagnostic_item(sym::Default)
                     .map_or(false, |default_trait_id| {
-                        implements_trait(cx, output_ty, default_trait_id, args)
+                        implements_trait(cx, output_ty, default_trait_id, &[])
                     })
             } else {
                 false
diff --git a/clippy_lints/src/needless_pass_by_value.rs b/clippy_lints/src/needless_pass_by_value.rs
index 5e26601537f..5ee26966fa7 100644
--- a/clippy_lints/src/needless_pass_by_value.rs
+++ b/clippy_lints/src/needless_pass_by_value.rs
@@ -2,7 +2,7 @@ use clippy_utils::diagnostics::{multispan_sugg, span_lint_and_then};
 use clippy_utils::ptr::get_spans;
 use clippy_utils::source::{snippet, snippet_opt};
 use clippy_utils::ty::{
-    implements_trait, implements_trait_with_env, is_copy, is_type_diagnostic_item, is_type_lang_item,
+    implements_trait, implements_trait_with_env_from_iter, is_copy, is_type_diagnostic_item, is_type_lang_item,
 };
 use clippy_utils::{get_trait_def_id, is_self, paths};
 use if_chain::if_chain;
@@ -182,7 +182,13 @@ impl<'tcx> LateLintPass<'tcx> for NeedlessPassByValue {
                 if !ty.is_mutable_ptr();
                 if !is_copy(cx, ty);
                 if ty.is_sized(cx.tcx, cx.param_env);
-                if !allowed_traits.iter().any(|&t| implements_trait_with_env(cx.tcx, cx.param_env, ty, t, [None]));
+                if !allowed_traits.iter().any(|&t| implements_trait_with_env_from_iter(
+                    cx.tcx,
+                    cx.param_env,
+                    ty,
+                    t,
+                    [Option::<ty::GenericArg<'tcx>>::None],
+                ));
                 if !implements_borrow_trait;
                 if !all_borrowable_trait;
 
diff --git a/clippy_utils/src/ty.rs b/clippy_utils/src/ty.rs
index 44ec8813f19..dbb6b986952 100644
--- a/clippy_utils/src/ty.rs
+++ b/clippy_utils/src/ty.rs
@@ -3,6 +3,7 @@
 #![allow(clippy::module_name_repetitions)]
 
 use core::ops::ControlFlow;
+use itertools::Itertools;
 use rustc_ast::ast::Mutability;
 use rustc_data_structures::fx::{FxHashMap, FxHashSet};
 use rustc_hir as hir;
@@ -13,17 +14,19 @@ use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKi
 use rustc_infer::infer::TyCtxtInferExt;
 use rustc_lint::LateContext;
 use rustc_middle::mir::interpret::{ConstValue, Scalar};
+use rustc_middle::traits::EvaluationResult;
 use rustc_middle::ty::layout::ValidityRequirement;
 use rustc_middle::ty::{
-    self, AdtDef, AliasTy, AssocKind, Binder, BoundRegion, FnSig, GenericArg, GenericArgKind, GenericArgsRef, IntTy,
-    List, ParamEnv, Region, RegionKind, Ty, TyCtxt, TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor,
-    UintTy, VariantDef, VariantDiscr,
+    self, AdtDef, AliasTy, AssocKind, Binder, BoundRegion, FnSig, GenericArg, GenericArgKind, GenericArgsRef,
+    GenericParamDefKind, IntTy, List, ParamEnv, Region, RegionKind, ToPredicate, TraitRef, Ty, TyCtxt,
+    TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor, UintTy, VariantDef, VariantDiscr,
 };
 use rustc_span::symbol::Ident;
 use rustc_span::{sym, Span, Symbol, DUMMY_SP};
 use rustc_target::abi::{Size, VariantIdx};
-use rustc_trait_selection::infer::InferCtxtExt;
+use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt as _;
 use rustc_trait_selection::traits::query::normalize::QueryNormalizeExt;
+use rustc_trait_selection::traits::{Obligation, ObligationCause};
 use std::iter;
 
 use crate::{match_def_path, path_res, paths};
@@ -207,15 +210,9 @@ pub fn implements_trait<'tcx>(
     cx: &LateContext<'tcx>,
     ty: Ty<'tcx>,
     trait_id: DefId,
-    ty_params: &[GenericArg<'tcx>],
+    args: &[GenericArg<'tcx>],
 ) -> bool {
-    implements_trait_with_env(
-        cx.tcx,
-        cx.param_env,
-        ty,
-        trait_id,
-        ty_params.iter().map(|&arg| Some(arg)),
-    )
+    implements_trait_with_env_from_iter(cx.tcx, cx.param_env, ty, trait_id, args.iter().map(|&x| Some(x)))
 }
 
 /// Same as `implements_trait` but allows using a `ParamEnv` different from the lint context.
@@ -224,7 +221,18 @@ pub fn implements_trait_with_env<'tcx>(
     param_env: ParamEnv<'tcx>,
     ty: Ty<'tcx>,
     trait_id: DefId,
-    ty_params: impl IntoIterator<Item = Option<GenericArg<'tcx>>>,
+    args: &[GenericArg<'tcx>],
+) -> bool {
+    implements_trait_with_env_from_iter(tcx, param_env, ty, trait_id, args.iter().map(|&x| Some(x)))
+}
+
+/// Same as `implements_trait_from_env` but takes the arguments as an iterator.
+pub fn implements_trait_with_env_from_iter<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    param_env: ParamEnv<'tcx>,
+    ty: Ty<'tcx>,
+    trait_id: DefId,
+    args: impl IntoIterator<Item = impl Into<Option<GenericArg<'tcx>>>>,
 ) -> bool {
     // Clippy shouldn't have infer types
     assert!(!ty.has_infer());
@@ -233,19 +241,37 @@ pub fn implements_trait_with_env<'tcx>(
     if ty.has_escaping_bound_vars() {
         return false;
     }
+
     let infcx = tcx.infer_ctxt().build();
-    let orig = TypeVariableOrigin {
-        kind: TypeVariableOriginKind::MiscVariable,
-        span: DUMMY_SP,
-    };
-    let ty_params = tcx.mk_args_from_iter(
-        ty_params
+    let trait_ref = TraitRef::new(
+        tcx,
+        trait_id,
+        Some(GenericArg::from(ty))
             .into_iter()
-            .map(|arg| arg.unwrap_or_else(|| infcx.next_ty_var(orig).into())),
+            .chain(args.into_iter().map(|arg| {
+                arg.into().unwrap_or_else(|| {
+                    let orig = TypeVariableOrigin {
+                        kind: TypeVariableOriginKind::MiscVariable,
+                        span: DUMMY_SP,
+                    };
+                    infcx.next_ty_var(orig).into()
+                })
+            })),
     );
+
+    debug_assert_eq!(tcx.def_kind(trait_id), DefKind::Trait);
+    #[cfg(debug_assertions)]
+    assert_generic_args_match(tcx, trait_id, trait_ref.args);
+
+    let obligation = Obligation {
+        cause: ObligationCause::dummy(),
+        param_env,
+        recursion_depth: 0,
+        predicate: ty::Binder::dummy(trait_ref).without_const().to_predicate(tcx),
+    };
     infcx
-        .type_implements_trait(trait_id, [ty.into()].into_iter().chain(ty_params), param_env)
-        .must_apply_modulo_regions()
+        .evaluate_obligation(&obligation)
+        .is_ok_and(EvaluationResult::must_apply_modulo_regions)
 }
 
 /// Checks whether this type implements `Drop`.
@@ -1014,12 +1040,60 @@ pub fn approx_ty_size<'tcx>(cx: &LateContext<'tcx>, ty: Ty<'tcx>) -> u64 {
     }
 }
 
+/// Asserts that the given arguments match the generic parameters of the given item.
+#[allow(dead_code)]
+fn assert_generic_args_match<'tcx>(tcx: TyCtxt<'tcx>, did: DefId, args: &[GenericArg<'tcx>]) {
+    let g = tcx.generics_of(did);
+    let parent = g.parent.map(|did| tcx.generics_of(did));
+    let count = g.parent_count + g.params.len();
+    let params = parent
+        .map_or([].as_slice(), |p| p.params.as_slice())
+        .iter()
+        .chain(&g.params)
+        .map(|x| &x.kind);
+
+    assert!(
+        count == args.len(),
+        "wrong number of arguments for `{did:?}`: expected `{count}`, found {}\n\
+            note: the expected arguments are: `[{}]`\n\
+            the given arguments are: `{args:#?}`",
+        args.len(),
+        params.clone().map(GenericParamDefKind::descr).format(", "),
+    );
+
+    if let Some((idx, (param, arg))) =
+        params
+            .clone()
+            .zip(args.iter().map(|&x| x.unpack()))
+            .enumerate()
+            .find(|(_, (param, arg))| match (param, arg) {
+                (GenericParamDefKind::Lifetime, GenericArgKind::Lifetime(_))
+                | (GenericParamDefKind::Type { .. }, GenericArgKind::Type(_))
+                | (GenericParamDefKind::Const { .. }, GenericArgKind::Const(_)) => false,
+                (
+                    GenericParamDefKind::Lifetime
+                    | GenericParamDefKind::Type { .. }
+                    | GenericParamDefKind::Const { .. },
+                    _,
+                ) => true,
+            })
+    {
+        panic!(
+            "incorrect argument for `{did:?}` at index `{idx}`: expected a {}, found `{arg:?}`\n\
+                note: the expected arguments are `[{}]`\n\
+                the given arguments are `{args:#?}`",
+            param.descr(),
+            params.clone().map(GenericParamDefKind::descr).format(", "),
+        );
+    }
+}
+
 /// Makes the projection type for the named associated type in the given impl or trait impl.
 ///
 /// This function is for associated types which are "known" to exist, and as such, will only return
 /// `None` when debug assertions are disabled in order to prevent ICE's. With debug assertions
 /// enabled this will check that the named associated type exists, the correct number of
-/// substitutions are given, and that the correct kinds of substitutions are given (lifetime,
+/// arguments are given, and that the correct kinds of arguments are given (lifetime,
 /// constant or type). This will not check if type normalization would succeed.
 pub fn make_projection<'tcx>(
     tcx: TyCtxt<'tcx>,
@@ -1043,49 +1117,7 @@ pub fn make_projection<'tcx>(
             return None;
         };
         #[cfg(debug_assertions)]
-        {
-            let generics = tcx.generics_of(assoc_item.def_id);
-            let generic_count = generics.parent_count + generics.params.len();
-            let params = generics
-                .parent
-                .map_or([].as_slice(), |id| &*tcx.generics_of(id).params)
-                .iter()
-                .chain(&generics.params)
-                .map(|x| &x.kind);
-
-            debug_assert!(
-                generic_count == args.len(),
-                "wrong number of args for `{:?}`: found `{}` expected `{generic_count}`.\n\
-                    note: the expected parameters are: {:#?}\n\
-                    the given arguments are: `{args:#?}`",
-                assoc_item.def_id,
-                args.len(),
-                params.map(ty::GenericParamDefKind::descr).collect::<Vec<_>>(),
-            );
-
-            if let Some((idx, (param, arg))) = params
-                .clone()
-                .zip(args.iter().map(GenericArg::unpack))
-                .enumerate()
-                .find(|(_, (param, arg))| {
-                    !matches!(
-                        (param, arg),
-                        (ty::GenericParamDefKind::Lifetime, GenericArgKind::Lifetime(_))
-                            | (ty::GenericParamDefKind::Type { .. }, GenericArgKind::Type(_))
-                            | (ty::GenericParamDefKind::Const { .. }, GenericArgKind::Const(_))
-                    )
-                })
-            {
-                debug_assert!(
-                    false,
-                    "mismatched subst type at index {idx}: expected a {}, found `{arg:?}`\n\
-                        note: the expected parameters are {:#?}\n\
-                        the given arguments are {args:#?}",
-                    param.descr(),
-                    params.map(ty::GenericParamDefKind::descr).collect::<Vec<_>>()
-                );
-            }
-        }
+        assert_generic_args_match(tcx, assoc_item.def_id, args);
 
         Some(tcx.mk_alias_ty(assoc_item.def_id, args))
     }
@@ -1100,7 +1132,7 @@ pub fn make_projection<'tcx>(
 /// Normalizes the named associated type in the given impl or trait impl.
 ///
 /// This function is for associated types which are "known" to be valid with the given
-/// substitutions, and as such, will only return `None` when debug assertions are disabled in order
+/// arguments, and as such, will only return `None` when debug assertions are disabled in order
 /// to prevent ICE's. With debug assertions enabled this will check that type normalization
 /// succeeds as well as everything checked by `make_projection`.
 pub fn make_normalized_projection<'tcx>(
@@ -1112,17 +1144,12 @@ pub fn make_normalized_projection<'tcx>(
 ) -> Option<Ty<'tcx>> {
     fn helper<'tcx>(tcx: TyCtxt<'tcx>, param_env: ParamEnv<'tcx>, ty: AliasTy<'tcx>) -> Option<Ty<'tcx>> {
         #[cfg(debug_assertions)]
-        if let Some((i, subst)) = ty
-            .args
-            .iter()
-            .enumerate()
-            .find(|(_, subst)| subst.has_late_bound_regions())
-        {
+        if let Some((i, arg)) = ty.args.iter().enumerate().find(|(_, arg)| arg.has_late_bound_regions()) {
             debug_assert!(
                 false,
                 "args contain late-bound region at index `{i}` which can't be normalized.\n\
                     use `TyCtxt::erase_late_bound_regions`\n\
-                    note: subst is `{subst:#?}`",
+                    note: arg is `{arg:#?}`",
             );
             return None;
         }
@@ -1190,17 +1217,12 @@ pub fn make_normalized_projection_with_regions<'tcx>(
 ) -> Option<Ty<'tcx>> {
     fn helper<'tcx>(tcx: TyCtxt<'tcx>, param_env: ParamEnv<'tcx>, ty: AliasTy<'tcx>) -> Option<Ty<'tcx>> {
         #[cfg(debug_assertions)]
-        if let Some((i, subst)) = ty
-            .args
-            .iter()
-            .enumerate()
-            .find(|(_, subst)| subst.has_late_bound_regions())
-        {
+        if let Some((i, arg)) = ty.args.iter().enumerate().find(|(_, arg)| arg.has_late_bound_regions()) {
             debug_assert!(
                 false,
                 "args contain late-bound region at index `{i}` which can't be normalized.\n\
                     use `TyCtxt::erase_late_bound_regions`\n\
-                    note: subst is `{subst:#?}`",
+                    note: arg is `{arg:#?}`",
             );
             return None;
         }