about summary refs log tree commit diff
path: root/compiler/rustc_mir/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir/src')
-rw-r--r--compiler/rustc_mir/src/transform/function_references.rs204
-rw-r--r--compiler/rustc_mir/src/transform/mod.rs2
2 files changed, 113 insertions, 93 deletions
diff --git a/compiler/rustc_mir/src/transform/function_references.rs b/compiler/rustc_mir/src/transform/function_references.rs
index 2daa468136f..8adeac7623b 100644
--- a/compiler/rustc_mir/src/transform/function_references.rs
+++ b/compiler/rustc_mir/src/transform/function_references.rs
@@ -1,134 +1,154 @@
 use rustc_hir::def_id::DefId;
 use rustc_middle::mir::visit::Visitor;
 use rustc_middle::mir::*;
-use rustc_middle::ty::{self, Ty, TyCtxt};
-use rustc_session::lint::builtin::FUNCTION_REFERENCES;
-use rustc_span::Span;
+use rustc_middle::ty::{self, subst::GenericArgKind, PredicateAtom, Ty, TyCtxt, TyS};
+use rustc_session::lint::builtin::FUNCTION_ITEM_REFERENCES;
+use rustc_span::{symbol::sym, Span};
 use rustc_target::spec::abi::Abi;
 
-use crate::transform::{MirPass, MirSource};
+use crate::transform::MirPass;
 
-pub struct FunctionReferences;
+pub struct FunctionItemReferences;
 
-impl<'tcx> MirPass<'tcx> for FunctionReferences {
-    fn run_pass(&self, tcx: TyCtxt<'tcx>, _src: MirSource<'tcx>, body: &mut Body<'tcx>) {
-        let source_info = SourceInfo::outermost(body.span);
-        let mut checker = FunctionRefChecker {
-            tcx,
-            body,
-            potential_lints: Vec::new(),
-            casts: Vec::new(),
-            calls: Vec::new(),
-            source_info,
-        };
+impl<'tcx> MirPass<'tcx> for FunctionItemReferences {
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        let mut checker = FunctionItemRefChecker { tcx, body };
         checker.visit_body(&body);
     }
 }
 
-struct FunctionRefChecker<'a, 'tcx> {
+struct FunctionItemRefChecker<'a, 'tcx> {
     tcx: TyCtxt<'tcx>,
     body: &'a Body<'tcx>,
-    potential_lints: Vec<FunctionRefLint>,
-    casts: Vec<Span>,
-    calls: Vec<Span>,
-    source_info: SourceInfo,
 }
 
-impl<'a, 'tcx> Visitor<'tcx> for FunctionRefChecker<'a, 'tcx> {
-    fn visit_basic_block_data(&mut self, block: BasicBlock, data: &BasicBlockData<'tcx>) {
-        self.super_basic_block_data(block, data);
-        for cast_span in self.casts.drain(..) {
-            self.potential_lints.retain(|lint| lint.source_info.span != cast_span);
-        }
-        for call_span in self.calls.drain(..) {
-            self.potential_lints.retain(|lint| lint.source_info.span != call_span);
-        }
-        for lint in self.potential_lints.drain(..) {
-            lint.emit(self.tcx, self.body);
-        }
-    }
-    fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
-        self.source_info = statement.source_info;
-        self.super_statement(statement, location);
-    }
+impl<'a, 'tcx> Visitor<'tcx> for FunctionItemRefChecker<'a, 'tcx> {
     fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
-        self.source_info = terminator.source_info;
         if let TerminatorKind::Call {
             func,
-            args: _,
+            args,
             destination: _,
             cleanup: _,
             from_hir_call: _,
             fn_span: _,
         } = &terminator.kind
         {
-            let span = match func {
-                Operand::Copy(place) | Operand::Move(place) => {
-                    self.body.local_decls[place.local].source_info.span
+            let func_ty = func.ty(self.body, self.tcx);
+            if let ty::FnDef(def_id, substs_ref) = *func_ty.kind() {
+                //check arguments for `std::mem::transmute`
+                if self.tcx.is_diagnostic_item(sym::transmute, def_id) {
+                    let arg_ty = args[0].ty(self.body, self.tcx);
+                    for generic_inner_ty in arg_ty.walk() {
+                        if let GenericArgKind::Type(inner_ty) = generic_inner_ty.unpack() {
+                            if let Some(fn_id) = FunctionItemRefChecker::is_fn_ref(inner_ty) {
+                                let ident = self.tcx.item_name(fn_id).to_ident_string();
+                                let source_info = *self.body.source_info(location);
+                                let span = self.nth_arg_span(&args, 0);
+                                self.emit_lint(ident, fn_id, source_info, span);
+                            }
+                        }
+                    }
+                } else {
+                    //check arguments for any function with `std::fmt::Pointer` as a bound trait
+                    let param_env = self.tcx.param_env(def_id);
+                    let bounds = param_env.caller_bounds();
+                    for bound in bounds {
+                        if let Some(bound_ty) = self.is_pointer_trait(&bound.skip_binders()) {
+                            let arg_defs = self.tcx.fn_sig(def_id).skip_binder().inputs();
+                            for (arg_num, arg_def) in arg_defs.iter().enumerate() {
+                                for generic_inner_ty in arg_def.walk() {
+                                    if let GenericArgKind::Type(inner_ty) =
+                                        generic_inner_ty.unpack()
+                                    {
+                                        //if any type reachable from the argument types in the fn sig matches the type bound by `Pointer`
+                                        if TyS::same_type(inner_ty, bound_ty) {
+                                            //check if this type is a function reference in the function call
+                                            let norm_ty =
+                                                self.tcx.subst_and_normalize_erasing_regions(
+                                                    substs_ref, param_env, &inner_ty,
+                                                );
+                                            if let Some(fn_id) =
+                                                FunctionItemRefChecker::is_fn_ref(norm_ty)
+                                            {
+                                                let ident =
+                                                    self.tcx.item_name(fn_id).to_ident_string();
+                                                let source_info = *self.body.source_info(location);
+                                                let span = self.nth_arg_span(&args, arg_num);
+                                                self.emit_lint(ident, fn_id, source_info, span);
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    }
                 }
-                Operand::Constant(constant) => constant.span,
-            };
-            self.calls.push(span);
-        };
+            }
+        }
         self.super_terminator(terminator, location);
     }
-    fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
-        match rvalue {
-            Rvalue::Ref(_, _, place) | Rvalue::AddressOf(_, place) => {
-                let decl = &self.body.local_decls[place.local];
-                if let ty::FnDef(def_id, _) = decl.ty.kind {
-                    let ident = self
-                        .body
-                        .var_debug_info
-                        .iter()
-                        .find(|info| info.source_info.span == decl.source_info.span)
-                        .map(|info| info.name.to_ident_string())
-                        .unwrap_or(self.tcx.def_path_str(def_id));
-                    let lint = FunctionRefLint { ident, def_id, source_info: self.source_info };
-                    self.potential_lints.push(lint);
-                }
-            }
-            Rvalue::Cast(_, op, _) => {
-                let op_ty = op.ty(self.body, self.tcx);
-                if self.is_fn_ref(op_ty) {
-                    self.casts.push(self.source_info.span);
+    //check for `std::fmt::Pointer::<T>::fmt` where T is a function reference
+    //this is used in formatting macros, but doesn't rely on the specific expansion
+    fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) {
+        let op_ty = operand.ty(self.body, self.tcx);
+        if let ty::FnDef(def_id, substs_ref) = *op_ty.kind() {
+            if self.tcx.is_diagnostic_item(sym::pointer_trait_fmt, def_id) {
+                let param_ty = substs_ref.type_at(0);
+                if let Some(fn_id) = FunctionItemRefChecker::is_fn_ref(param_ty) {
+                    let source_info = *self.body.source_info(location);
+                    let callsite_ctxt = source_info.span.source_callsite().ctxt();
+                    let span = source_info.span.with_ctxt(callsite_ctxt);
+                    let ident = self.tcx.item_name(fn_id).to_ident_string();
+                    self.emit_lint(ident, fn_id, source_info, span);
                 }
             }
-            _ => {}
         }
-        self.super_rvalue(rvalue, location);
+        self.super_operand(operand, location);
     }
 }
 
-impl<'a, 'tcx> FunctionRefChecker<'a, 'tcx> {
-    fn is_fn_ref(&self, ty: Ty<'tcx>) -> bool {
-        let referent_ty = match ty.kind {
+impl<'a, 'tcx> FunctionItemRefChecker<'a, 'tcx> {
+    //return the bound parameter type if the trait is `std::fmt::Pointer`
+    fn is_pointer_trait(&self, bound: &PredicateAtom<'tcx>) -> Option<Ty<'tcx>> {
+        if let ty::PredicateAtom::Trait(predicate, _) = bound {
+            if self.tcx.is_diagnostic_item(sym::pointer_trait, predicate.def_id()) {
+                Some(predicate.trait_ref.self_ty())
+            } else {
+                None
+            }
+        } else {
+            None
+        }
+    }
+    fn is_fn_ref(ty: Ty<'tcx>) -> Option<DefId> {
+        let referent_ty = match ty.kind() {
             ty::Ref(_, referent_ty, _) => Some(referent_ty),
-            ty::RawPtr(ty_and_mut) => Some(ty_and_mut.ty),
+            ty::RawPtr(ty_and_mut) => Some(&ty_and_mut.ty),
             _ => None,
         };
         referent_ty
-            .map(|ref_ty| if let ty::FnDef(..) = ref_ty.kind { true } else { false })
-            .unwrap_or(false)
+            .map(
+                |ref_ty| {
+                    if let ty::FnDef(def_id, _) = *ref_ty.kind() { Some(def_id) } else { None }
+                },
+            )
+            .unwrap_or(None)
     }
-}
-
-struct FunctionRefLint {
-    ident: String,
-    def_id: DefId,
-    source_info: SourceInfo,
-}
-
-impl<'tcx> FunctionRefLint {
-    fn emit(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
-        let def_id = self.def_id;
-        let source_info = self.source_info;
-        let lint_root = body.source_scopes[source_info.scope]
+    fn nth_arg_span(&self, args: &Vec<Operand<'tcx>>, n: usize) -> Span {
+        match &args[n] {
+            Operand::Copy(place) | Operand::Move(place) => {
+                self.body.local_decls[place.local].source_info.span
+            }
+            Operand::Constant(constant) => constant.span,
+        }
+    }
+    fn emit_lint(&self, ident: String, fn_id: DefId, source_info: SourceInfo, span: Span) {
+        let lint_root = self.body.source_scopes[source_info.scope]
             .local_data
             .as_ref()
             .assert_crate_local()
             .lint_root;
-        let fn_sig = tcx.fn_sig(def_id);
+        let fn_sig = self.tcx.fn_sig(fn_id);
         let unsafety = fn_sig.unsafety().prefix_str();
         let abi = match fn_sig.abi() {
             Abi::Rust => String::from(""),
@@ -142,17 +162,17 @@ impl<'tcx> FunctionRefLint {
         let num_args = fn_sig.inputs().map_bound(|inputs| inputs.len()).skip_binder();
         let variadic = if fn_sig.c_variadic() { ", ..." } else { "" };
         let ret = if fn_sig.output().skip_binder().is_unit() { "" } else { " -> _" };
-        tcx.struct_span_lint_hir(FUNCTION_REFERENCES, lint_root, source_info.span, |lint| {
+        self.tcx.struct_span_lint_hir(FUNCTION_ITEM_REFERENCES, lint_root, span, |lint| {
             lint.build(&format!(
                 "cast `{}` with `as {}{}fn({}{}){}` to use it as a pointer",
-                self.ident,
+                ident,
                 unsafety,
                 abi,
                 vec!["_"; num_args].join(", "),
                 variadic,
                 ret,
             ))
-            .emit()
+            .emit();
         });
     }
 }
diff --git a/compiler/rustc_mir/src/transform/mod.rs b/compiler/rustc_mir/src/transform/mod.rs
index 3f50420b86b..e43a238c1ba 100644
--- a/compiler/rustc_mir/src/transform/mod.rs
+++ b/compiler/rustc_mir/src/transform/mod.rs
@@ -267,7 +267,7 @@ fn mir_const<'tcx>(
             // MIR-level lints.
             &check_packed_ref::CheckPackedRef,
             &check_const_item_mutation::CheckConstItemMutation,
-            &function_references::FunctionReferences,
+            &function_references::FunctionItemReferences,
             // What we need to do constant evaluation.
             &simplify::SimplifyCfg::new("initial"),
             &rustc_peek::SanityCheck,