about summary refs log tree commit diff
path: root/compiler/rustc_mir/src/transform/instrument_coverage.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir/src/transform/instrument_coverage.rs')
-rw-r--r--compiler/rustc_mir/src/transform/instrument_coverage.rs247
1 files changed, 247 insertions, 0 deletions
diff --git a/compiler/rustc_mir/src/transform/instrument_coverage.rs b/compiler/rustc_mir/src/transform/instrument_coverage.rs
new file mode 100644
index 00000000000..f60e6da714a
--- /dev/null
+++ b/compiler/rustc_mir/src/transform/instrument_coverage.rs
@@ -0,0 +1,247 @@
+use crate::transform::{MirPass, MirSource};
+use rustc_data_structures::fingerprint::Fingerprint;
+use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
+use rustc_middle::hir;
+use rustc_middle::ich::StableHashingContext;
+use rustc_middle::mir;
+use rustc_middle::mir::coverage::*;
+use rustc_middle::mir::visit::Visitor;
+use rustc_middle::mir::{BasicBlock, Coverage, CoverageInfo, Location, Statement, StatementKind};
+use rustc_middle::ty::query::Providers;
+use rustc_middle::ty::TyCtxt;
+use rustc_span::def_id::DefId;
+use rustc_span::{FileName, Pos, RealFileName, Span, Symbol};
+
+/// Inserts call to count_code_region() as a placeholder to be replaced during code generation with
+/// the intrinsic llvm.instrprof.increment.
+pub struct InstrumentCoverage;
+
+/// The `query` provider for `CoverageInfo`, requested by `codegen_intrinsic_call()` when
+/// constructing the arguments for `llvm.instrprof.increment`.
+pub(crate) fn provide(providers: &mut Providers) {
+    providers.coverageinfo = |tcx, def_id| coverageinfo_from_mir(tcx, def_id);
+}
+
+struct CoverageVisitor {
+    info: CoverageInfo,
+}
+
+impl Visitor<'_> for CoverageVisitor {
+    fn visit_coverage(&mut self, coverage: &Coverage, _location: Location) {
+        match coverage.kind {
+            CoverageKind::Counter { id, .. } => {
+                let counter_id = u32::from(id);
+                self.info.num_counters = std::cmp::max(self.info.num_counters, counter_id + 1);
+            }
+            CoverageKind::Expression { id, .. } => {
+                let expression_index = u32::MAX - u32::from(id);
+                self.info.num_expressions =
+                    std::cmp::max(self.info.num_expressions, expression_index + 1);
+            }
+            _ => {}
+        }
+    }
+}
+
+fn coverageinfo_from_mir<'tcx>(tcx: TyCtxt<'tcx>, mir_def_id: DefId) -> CoverageInfo {
+    let mir_body = tcx.optimized_mir(mir_def_id);
+
+    // The `num_counters` argument to `llvm.instrprof.increment` is the number of injected
+    // counters, with each counter having a counter ID from `0..num_counters-1`. MIR optimization
+    // may split and duplicate some BasicBlock sequences. Simply counting the calls may not
+    // work; but computing the num_counters by adding `1` to the highest counter_id (for a given
+    // instrumented function) is valid.
+    //
+    // `num_expressions` is the number of counter expressions added to the MIR body. Both
+    // `num_counters` and `num_expressions` are used to initialize new vectors, during backend
+    // code generate, to lookup counters and expressions by simple u32 indexes.
+    let mut coverage_visitor =
+        CoverageVisitor { info: CoverageInfo { num_counters: 0, num_expressions: 0 } };
+
+    coverage_visitor.visit_body(mir_body);
+    coverage_visitor.info
+}
+
+impl<'tcx> MirPass<'tcx> for InstrumentCoverage {
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, src: MirSource<'tcx>, mir_body: &mut mir::Body<'tcx>) {
+        // If the InstrumentCoverage pass is called on promoted MIRs, skip them.
+        // See: https://github.com/rust-lang/rust/pull/73011#discussion_r438317601
+        if src.promoted.is_none() {
+            Instrumentor::new(tcx, src, mir_body).inject_counters();
+        }
+    }
+}
+
+struct Instrumentor<'a, 'tcx> {
+    tcx: TyCtxt<'tcx>,
+    mir_def_id: DefId,
+    mir_body: &'a mut mir::Body<'tcx>,
+    hir_body: &'tcx rustc_hir::Body<'tcx>,
+    function_source_hash: Option<u64>,
+    num_counters: u32,
+    num_expressions: u32,
+}
+
+impl<'a, 'tcx> Instrumentor<'a, 'tcx> {
+    fn new(tcx: TyCtxt<'tcx>, src: MirSource<'tcx>, mir_body: &'a mut mir::Body<'tcx>) -> Self {
+        let mir_def_id = src.def_id();
+        let hir_body = hir_body(tcx, mir_def_id);
+        Self {
+            tcx,
+            mir_def_id,
+            mir_body,
+            hir_body,
+            function_source_hash: None,
+            num_counters: 0,
+            num_expressions: 0,
+        }
+    }
+
+    /// Counter IDs start from zero and go up.
+    fn next_counter(&mut self) -> CounterValueReference {
+        assert!(self.num_counters < u32::MAX - self.num_expressions);
+        let next = self.num_counters;
+        self.num_counters += 1;
+        CounterValueReference::from(next)
+    }
+
+    /// Expression IDs start from u32::MAX and go down because a CounterExpression can reference
+    /// (add or subtract counts) of both Counter regions and CounterExpression regions. The counter
+    /// expression operand IDs must be unique across both types.
+    fn next_expression(&mut self) -> InjectedExpressionIndex {
+        assert!(self.num_counters < u32::MAX - self.num_expressions);
+        let next = u32::MAX - self.num_expressions;
+        self.num_expressions += 1;
+        InjectedExpressionIndex::from(next)
+    }
+
+    fn function_source_hash(&mut self) -> u64 {
+        match self.function_source_hash {
+            Some(hash) => hash,
+            None => {
+                let hash = hash_mir_source(self.tcx, self.hir_body);
+                self.function_source_hash.replace(hash);
+                hash
+            }
+        }
+    }
+
+    fn inject_counters(&mut self) {
+        let body_span = self.hir_body.value.span;
+        debug!("instrumenting {:?}, span: {:?}", self.mir_def_id, body_span);
+
+        // FIXME(richkadel): As a first step, counters are only injected at the top of each
+        // function. The complete solution will inject counters at each conditional code branch.
+        let block = rustc_middle::mir::START_BLOCK;
+        let counter = self.make_counter();
+        self.inject_statement(counter, body_span, block);
+
+        // FIXME(richkadel): The next step to implement source based coverage analysis will be
+        // instrumenting branches within functions, and some regions will be counted by "counter
+        // expression". The function to inject counter expression is implemented. Replace this
+        // "fake use" with real use.
+        let fake_use = false;
+        if fake_use {
+            let add = false;
+            let fake_counter = CoverageKind::Counter {
+                function_source_hash: self.function_source_hash(),
+                id: CounterValueReference::from_u32(1),
+            };
+            let fake_expression = CoverageKind::Expression {
+                id: InjectedExpressionIndex::from(u32::MAX - 1),
+                lhs: ExpressionOperandId::from_u32(1),
+                op: Op::Add,
+                rhs: ExpressionOperandId::from_u32(2),
+            };
+
+            let lhs = fake_counter.as_operand_id();
+            let op = if add { Op::Add } else { Op::Subtract };
+            let rhs = fake_expression.as_operand_id();
+
+            let block = rustc_middle::mir::START_BLOCK;
+
+            let expression = self.make_expression(lhs, op, rhs);
+            self.inject_statement(expression, body_span, block);
+        }
+    }
+
+    fn make_counter(&mut self) -> CoverageKind {
+        CoverageKind::Counter {
+            function_source_hash: self.function_source_hash(),
+            id: self.next_counter(),
+        }
+    }
+
+    fn make_expression(
+        &mut self,
+        lhs: ExpressionOperandId,
+        op: Op,
+        rhs: ExpressionOperandId,
+    ) -> CoverageKind {
+        CoverageKind::Expression { id: self.next_expression(), lhs, op, rhs }
+    }
+
+    fn inject_statement(&mut self, coverage_kind: CoverageKind, span: Span, block: BasicBlock) {
+        let code_region = make_code_region(self.tcx, &span);
+        debug!("  injecting statement {:?} covering {:?}", coverage_kind, code_region);
+
+        let data = &mut self.mir_body[block];
+        let source_info = data.terminator().source_info;
+        let statement = Statement {
+            source_info,
+            kind: StatementKind::Coverage(box Coverage { kind: coverage_kind, code_region }),
+        };
+        data.statements.push(statement);
+    }
+}
+
+/// Convert the Span into its file name, start line and column, and end line and column
+fn make_code_region<'tcx>(tcx: TyCtxt<'tcx>, span: &Span) -> CodeRegion {
+    let source_map = tcx.sess.source_map();
+    let start = source_map.lookup_char_pos(span.lo());
+    let end = if span.hi() == span.lo() {
+        start.clone()
+    } else {
+        let end = source_map.lookup_char_pos(span.hi());
+        debug_assert_eq!(
+            start.file.name,
+            end.file.name,
+            "Region start ({:?} -> {:?}) and end ({:?} -> {:?}) don't come from the same source file!",
+            span.lo(),
+            start,
+            span.hi(),
+            end
+        );
+        end
+    };
+    match &start.file.name {
+        FileName::Real(RealFileName::Named(path)) => CodeRegion {
+            file_name: Symbol::intern(&path.to_string_lossy()),
+            start_line: start.line as u32,
+            start_col: start.col.to_u32() + 1,
+            end_line: end.line as u32,
+            end_col: end.col.to_u32() + 1,
+        },
+        _ => bug!("start.file.name should be a RealFileName, but it was: {:?}", start.file.name),
+    }
+}
+
+fn hir_body<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> &'tcx rustc_hir::Body<'tcx> {
+    let hir_node = tcx.hir().get_if_local(def_id).expect("DefId is local");
+    let fn_body_id = hir::map::associated_body(hir_node).expect("HIR node is a function with body");
+    tcx.hir().body(fn_body_id)
+}
+
+fn hash_mir_source<'tcx>(tcx: TyCtxt<'tcx>, hir_body: &'tcx rustc_hir::Body<'tcx>) -> u64 {
+    let mut hcx = tcx.create_no_span_stable_hashing_context();
+    hash(&mut hcx, &hir_body.value).to_smaller_hash()
+}
+
+fn hash(
+    hcx: &mut StableHashingContext<'tcx>,
+    node: &impl HashStable<StableHashingContext<'tcx>>,
+) -> Fingerprint {
+    let mut stable_hasher = StableHasher::new();
+    node.hash_stable(hcx, &mut stable_hasher);
+    stable_hasher.finish()
+}