about summary refs log tree commit diff
diff options
context:
space:
mode:
authorRich Kadel <richkadel@google.com>2020-06-22 23:31:41 -0700
committerRich Kadel <richkadel@google.com>2020-06-22 23:50:30 -0700
commit977ce57d915914139c4aa643e63f368913e5f437 (patch)
tree2a66c67004ce3486958d045ebd0b9f7edf038c26
parenta04514026824f9342ab93d9b608e3ec5dab53dad (diff)
downloadrust-977ce57d915914139c4aa643e63f368913e5f437.tar.gz
rust-977ce57d915914139c4aa643e63f368913e5f437.zip
Updated query for num_counters to compute from max index
Also added FIXME comments to note the possible need to accommodate
counter increment calls in source-based functions that differ from the
function context of the caller instance (e.g., inline functions).
-rw-r--r--src/librustc_codegen_llvm/intrinsic.rs3
-rw-r--r--src/librustc_mir/transform/instrument_coverage.rs28
2 files changed, 27 insertions, 4 deletions
diff --git a/src/librustc_codegen_llvm/intrinsic.rs b/src/librustc_codegen_llvm/intrinsic.rs
index b9193a85b1e..dfe97b1ee2e 100644
--- a/src/librustc_codegen_llvm/intrinsic.rs
+++ b/src/librustc_codegen_llvm/intrinsic.rs
@@ -140,6 +140,9 @@ impl IntrinsicCallMethods<'tcx> for Builder<'a, 'll, 'tcx> {
                 self.call(llfn, &[], None)
             }
             "count_code_region" => {
+                // FIXME(richkadel): The current implementation assumes the MIR for the given
+                // caller_instance represents a single function. Validate and/or correct if inlining
+                // and/or monomorphization invalidates these assumptions.
                 let coverage_data = tcx.coverage_data(caller_instance.def_id());
                 let mangled_fn = tcx.symbol_name(caller_instance);
                 let (mangled_fn_name, _len_val) = self.const_str(mangled_fn.name);
diff --git a/src/librustc_mir/transform/instrument_coverage.rs b/src/librustc_mir/transform/instrument_coverage.rs
index 27aaf47bbf2..06b648ab5a9 100644
--- a/src/librustc_mir/transform/instrument_coverage.rs
+++ b/src/librustc_mir/transform/instrument_coverage.rs
@@ -5,15 +5,15 @@ use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
 use rustc_hir::lang_items;
 use rustc_middle::hir;
 use rustc_middle::ich::StableHashingContext;
-use rustc_middle::mir::interpret::Scalar;
+use rustc_middle::mir::interpret::{ConstValue, Scalar};
 use rustc_middle::mir::{
     self, traversal, BasicBlock, BasicBlockData, CoverageData, Operand, Place, SourceInfo,
     StatementKind, Terminator, TerminatorKind, START_BLOCK,
 };
 use rustc_middle::ty;
 use rustc_middle::ty::query::Providers;
-use rustc_middle::ty::FnDef;
 use rustc_middle::ty::TyCtxt;
+use rustc_middle::ty::{ConstKind, FnDef};
 use rustc_span::def_id::DefId;
 use rustc_span::Span;
 
@@ -26,16 +26,36 @@ pub struct InstrumentCoverage;
 pub(crate) fn provide(providers: &mut Providers<'_>) {
     providers.coverage_data = |tcx, def_id| {
         let mir_body = tcx.optimized_mir(def_id);
+        // FIXME(richkadel): The current implementation assumes the MIR for the given DefId
+        // represents a single function. Validate and/or correct if inlining and/or monomorphization
+        // invalidates these assumptions.
         let count_code_region_fn =
             tcx.require_lang_item(lang_items::CountCodeRegionFnLangItem, None);
         let mut num_counters: u32 = 0;
+        // The `num_counters` argument to `llvm.instrprof.increment` is the number of injected
+        // counters, with each counter having an index from `0..num_counters-1`. MIR optimization
+        // may split and duplicate some BasicBlock sequences. Simply counting the calls may not
+        // not work; but computing the num_counters by adding `1` to the highest index (for a given
+        // instrumented function) is valid.
         for (_, data) in traversal::preorder(mir_body) {
             if let Some(terminator) = &data.terminator {
-                if let TerminatorKind::Call { func: Operand::Constant(func), .. } = &terminator.kind
+                if let TerminatorKind::Call { func: Operand::Constant(func), args, .. } =
+                    &terminator.kind
                 {
                     if let FnDef(called_fn_def_id, _) = func.literal.ty.kind {
                         if called_fn_def_id == count_code_region_fn {
-                            num_counters += 1;
+                            if let Operand::Constant(constant) =
+                                args.get(0).expect("count_code_region has at least one arg")
+                            {
+                                if let ConstKind::Value(ConstValue::Scalar(value)) =
+                                    constant.literal.val
+                                {
+                                    let index = value
+                                        .to_u32()
+                                        .expect("count_code_region index at arg0 is u32");
+                                    num_counters = std::cmp::max(num_counters, index + 1);
+                                }
+                            }
                         }
                     }
                 }