about summary refs log tree commit diff
diff options
context:
space:
mode:
authorAlex Macleod <alex@macleod.io>2023-02-01 21:31:09 +0000
committerAlex Macleod <alex@macleod.io>2023-03-06 21:38:32 +0000
commita2906a1598ca16ac080f95658a8a2eae10ff25d5 (patch)
treee66cfe3de55406f8e8f259f83ef4e4f88fcf7224
parent903595801e86c1d6bf8bd8aeeeb9127b04f19036 (diff)
downloadrust-a2906a1598ca16ac080f95658a8a2eae10ff25d5.tar.gz
rust-a2906a1598ca16ac080f95658a8a2eae10ff25d5.zip
Migrate `write.rs` to `rustc_ast::FormatArgs`
-rw-r--r--clippy_lints/src/lib.rs1
-rw-r--r--clippy_lints/src/write.rs158
-rw-r--r--clippy_utils/src/macros.rs79
3 files changed, 173 insertions, 65 deletions
diff --git a/clippy_lints/src/lib.rs b/clippy_lints/src/lib.rs
index c626e0bd998..a025e3cc12a 100644
--- a/clippy_lints/src/lib.rs
+++ b/clippy_lints/src/lib.rs
@@ -870,6 +870,7 @@ pub fn register_plugins(store: &mut rustc_lint::LintStore, sess: &Session, conf:
     let allow_dbg_in_tests = conf.allow_dbg_in_tests;
     store.register_late_pass(move |_| Box::new(dbg_macro::DbgMacro::new(allow_dbg_in_tests)));
     let allow_print_in_tests = conf.allow_print_in_tests;
+    store.register_early_pass(move || Box::new(write::Write::new(allow_print_in_tests)));
     store.register_late_pass(move |_| Box::new(write::Write::new(allow_print_in_tests)));
     let cargo_ignore_publish = conf.cargo_ignore_publish;
     store.register_late_pass(move |_| {
diff --git a/clippy_lints/src/write.rs b/clippy_lints/src/write.rs
index df335038881..ce419ec9373 100644
--- a/clippy_lints/src/write.rs
+++ b/clippy_lints/src/write.rs
@@ -1,11 +1,14 @@
 use clippy_utils::diagnostics::{span_lint, span_lint_and_then};
-use clippy_utils::macros::{root_macro_call_first_node, FormatArgsExpn, MacroCall};
+use clippy_utils::macros::{
+    find_format_args, format_arg_removal_span, populate_ast_format_args, root_macro_call_first_node, MacroCall,
+};
 use clippy_utils::source::{expand_past_previous_comma, snippet_opt};
 use clippy_utils::{is_in_cfg_test, is_in_test_function};
-use rustc_ast::LitKind;
+use rustc_ast::token::LitKind;
+use rustc_ast::{FormatArgPosition, FormatArgs, FormatArgsPiece, FormatOptions, FormatPlaceholder, FormatTrait};
 use rustc_errors::Applicability;
-use rustc_hir::{Expr, ExprKind, HirIdMap, Impl, Item, ItemKind};
-use rustc_lint::{LateContext, LateLintPass, LintContext};
+use rustc_hir::{Expr, Impl, Item, ItemKind};
+use rustc_lint::{EarlyLintPass, LateContext, LateLintPass, LintContext};
 use rustc_session::{declare_tool_lint, impl_lint_pass};
 use rustc_span::{sym, BytePos};
 
@@ -257,6 +260,12 @@ impl_lint_pass!(Write => [
     WRITE_LITERAL,
 ]);
 
+impl EarlyLintPass for Write {
+    fn check_expr(&mut self, _: &rustc_lint::EarlyContext<'_>, expr: &rustc_ast::Expr) {
+        populate_ast_format_args(expr);
+    }
+}
+
 impl<'tcx> LateLintPass<'tcx> for Write {
     fn check_item(&mut self, cx: &LateContext<'_>, item: &Item<'_>) {
         if is_debug_impl(cx, item) {
@@ -297,34 +306,40 @@ impl<'tcx> LateLintPass<'tcx> for Write {
             _ => return,
         }
 
-        let Some(format_args) = FormatArgsExpn::find_nested(cx, expr, macro_call.expn) else { return };
-
-        // ignore `writeln!(w)` and `write!(v, some_macro!())`
-        if format_args.format_string.span.from_expansion() {
-            return;
-        }
+        find_format_args(cx, expr, macro_call.expn, |format_args| {
+            // ignore `writeln!(w)` and `write!(v, some_macro!())`
+            if format_args.span.from_expansion() {
+                return;
+            }
 
-        match diag_name {
-            sym::print_macro | sym::eprint_macro | sym::write_macro => {
-                check_newline(cx, &format_args, &macro_call, name);
-            },
-            sym::println_macro | sym::eprintln_macro | sym::writeln_macro => {
-                check_empty_string(cx, &format_args, &macro_call, name);
-            },
-            _ => {},
-        }
+            match diag_name {
+                sym::print_macro | sym::eprint_macro | sym::write_macro => {
+                    check_newline(cx, format_args, &macro_call, name);
+                },
+                sym::println_macro | sym::eprintln_macro | sym::writeln_macro => {
+                    check_empty_string(cx, format_args, &macro_call, name);
+                },
+                _ => {},
+            }
 
-        check_literal(cx, &format_args, name);
+            check_literal(cx, format_args, name);
 
-        if !self.in_debug_impl {
-            for arg in &format_args.args {
-                if arg.format.r#trait == sym::Debug {
-                    span_lint(cx, USE_DEBUG, arg.span, "use of `Debug`-based formatting");
+            if !self.in_debug_impl {
+                for piece in &format_args.template {
+                    if let &FormatArgsPiece::Placeholder(FormatPlaceholder {
+                        span: Some(span),
+                        format_trait: FormatTrait::Debug,
+                        ..
+                    }) = piece
+                    {
+                        span_lint(cx, USE_DEBUG, span, "use of `Debug`-based formatting");
+                    }
                 }
             }
-        }
+        });
     }
 }
+
 fn is_debug_impl(cx: &LateContext<'_>, item: &Item<'_>) -> bool {
     if let ItemKind::Impl(Impl { of_trait: Some(trait_ref), .. }) = &item.kind
         && let Some(trait_id) = trait_ref.trait_def_id()
@@ -335,16 +350,18 @@ fn is_debug_impl(cx: &LateContext<'_>, item: &Item<'_>) -> bool {
     }
 }
 
-fn check_newline(cx: &LateContext<'_>, format_args: &FormatArgsExpn<'_>, macro_call: &MacroCall, name: &str) {
-    let format_string_parts = &format_args.format_string.parts;
-    let mut format_string_span = format_args.format_string.span;
-
-    let Some(last) = format_string_parts.last() else { return };
+fn check_newline(cx: &LateContext<'_>, format_args: &FormatArgs, macro_call: &MacroCall, name: &str) {
+    let Some(FormatArgsPiece::Literal(last)) = format_args.template.last() else { return };
 
     let count_vertical_whitespace = || {
-        format_string_parts
+        format_args
+            .template
             .iter()
-            .flat_map(|part| part.as_str().chars())
+            .filter_map(|piece| match piece {
+                FormatArgsPiece::Literal(literal) => Some(literal),
+                FormatArgsPiece::Placeholder(_) => None,
+            })
+            .flat_map(|literal| literal.as_str().chars())
             .filter(|ch| matches!(ch, '\r' | '\n'))
             .count()
     };
@@ -352,10 +369,9 @@ fn check_newline(cx: &LateContext<'_>, format_args: &FormatArgsExpn<'_>, macro_c
     if last.as_str().ends_with('\n')
         // ignore format strings with other internal vertical whitespace
         && count_vertical_whitespace() == 1
-
-        // ignore trailing arguments: `print!("Issue\n{}", 1265);`
-        && format_string_parts.len() > format_args.args.len()
     {
+        let mut format_string_span = format_args.span;
+
         let lint = if name == "write" {
             format_string_span = expand_past_previous_comma(cx, format_string_span);
 
@@ -373,7 +389,7 @@ fn check_newline(cx: &LateContext<'_>, format_args: &FormatArgsExpn<'_>, macro_c
                 let name_span = cx.sess().source_map().span_until_char(macro_call.span, '!');
                 let Some(format_snippet) = snippet_opt(cx, format_string_span) else { return };
 
-                if format_string_parts.len() == 1 && last.as_str() == "\n" {
+                if format_args.template.len() == 1 && last.as_str() == "\n" {
                     // print!("\n"), write!(f, "\n")
 
                     diag.multipart_suggestion(
@@ -398,11 +414,12 @@ fn check_newline(cx: &LateContext<'_>, format_args: &FormatArgsExpn<'_>, macro_c
     }
 }
 
-fn check_empty_string(cx: &LateContext<'_>, format_args: &FormatArgsExpn<'_>, macro_call: &MacroCall, name: &str) {
-    if let [part] = &format_args.format_string.parts[..]
-        && let mut span = format_args.format_string.span
-        && part.as_str() == "\n"
+fn check_empty_string(cx: &LateContext<'_>, format_args: &FormatArgs, macro_call: &MacroCall, name: &str) {
+    if let [FormatArgsPiece::Literal(literal)] = &format_args.template[..]
+        && literal.as_str() == "\n"
     {
+        let mut span = format_args.span;
+
         let lint = if name == "writeln" {
             span = expand_past_previous_comma(cx, span);
 
@@ -428,33 +445,43 @@ fn check_empty_string(cx: &LateContext<'_>, format_args: &FormatArgsExpn<'_>, ma
     }
 }
 
-fn check_literal(cx: &LateContext<'_>, format_args: &FormatArgsExpn<'_>, name: &str) {
-    let mut counts = HirIdMap::<usize>::default();
-    for param in format_args.params() {
-        *counts.entry(param.value.hir_id).or_default() += 1;
+fn check_literal(cx: &LateContext<'_>, format_args: &FormatArgs, name: &str) {
+    let arg_index = |argument: &FormatArgPosition| argument.index.unwrap_or_else(|pos| pos);
+
+    let mut counts = vec![0u32; format_args.arguments.all_args().len()];
+    for piece in &format_args.template {
+        if let FormatArgsPiece::Placeholder(placeholder) = piece {
+            counts[arg_index(&placeholder.argument)] += 1;
+        }
     }
 
-    for arg in &format_args.args {
-        let value = arg.param.value;
-
-        if counts[&value.hir_id] == 1
-            && arg.format.is_default()
-            && let ExprKind::Lit(lit) = &value.kind
-            && !value.span.from_expansion()
-            && let Some(value_string) = snippet_opt(cx, value.span)
-        {
-            let (replacement, replace_raw) = match lit.node {
-                LitKind::Str(..) => extract_str_literal(&value_string),
-                LitKind::Char(ch) => (
-                    match ch {
-                        '"' => "\\\"",
-                        '\'' => "'",
+    for piece in &format_args.template {
+        if let FormatArgsPiece::Placeholder(FormatPlaceholder {
+            argument,
+            span: Some(placeholder_span),
+            format_trait: FormatTrait::Display,
+            format_options,
+        }) = piece
+            && *format_options == FormatOptions::default()
+            && let index = arg_index(argument)
+            && counts[index] == 1
+            && let Some(arg) = format_args.arguments.by_index(index)
+            && let rustc_ast::ExprKind::Lit(lit) = &arg.expr.kind
+            && !arg.expr.span.from_expansion()
+            && let Some(value_string) = snippet_opt(cx, arg.expr.span)
+    {
+            let (replacement, replace_raw) = match lit.kind {
+                LitKind::Str | LitKind::StrRaw(_)  => extract_str_literal(&value_string),
+                LitKind::Char => (
+                    match lit.symbol.as_str() {
+                        "\"" => "\\\"",
+                        "\\'" => "'",
                         _ => &value_string[1..value_string.len() - 1],
                     }
                     .to_string(),
                     false,
                 ),
-                LitKind::Bool(b) => (b.to_string(), false),
+                LitKind::Bool => (lit.symbol.to_string(), false),
                 _ => continue,
             };
 
@@ -464,7 +491,9 @@ fn check_literal(cx: &LateContext<'_>, format_args: &FormatArgsExpn<'_>, name: &
                 PRINT_LITERAL
             };
 
-            let format_string_is_raw = format_args.format_string.style.is_some();
+            let Some(format_string_snippet) = snippet_opt(cx, format_args.span) else { continue };
+            let format_string_is_raw = format_string_snippet.starts_with('r');
+
             let replacement = match (format_string_is_raw, replace_raw) {
                 (false, false) => Some(replacement),
                 (false, true) => Some(replacement.replace('"', "\\\"").replace('\\', "\\\\")),
@@ -485,23 +514,24 @@ fn check_literal(cx: &LateContext<'_>, format_args: &FormatArgsExpn<'_>, name: &
             span_lint_and_then(
                 cx,
                 lint,
-                value.span,
+                arg.expr.span,
                 "literal with an empty format string",
                 |diag| {
                     if let Some(replacement) = replacement
                         // `format!("{}", "a")`, `format!("{named}", named = "b")
                         //              ~~~~~                      ~~~~~~~~~~~~~
-                        && let Some(value_span) = format_args.value_with_prev_comma_span(value.hir_id)
+                        && let Some(removal_span) = format_arg_removal_span(format_args, index)
                     {
                         let replacement = replacement.replace('{', "{{").replace('}', "}}");
                         diag.multipart_suggestion(
                             "try this",
-                            vec![(arg.span, replacement), (value_span, String::new())],
+                            vec![(*placeholder_span, replacement), (removal_span, String::new())],
                             Applicability::MachineApplicable,
                         );
                     }
                 },
             );
+
         }
     }
 }
diff --git a/clippy_utils/src/macros.rs b/clippy_utils/src/macros.rs
index be6133d3202..8e0edd72f92 100644
--- a/clippy_utils/src/macros.rs
+++ b/clippy_utils/src/macros.rs
@@ -6,6 +6,8 @@ use crate::visitors::{for_each_expr, Descend};
 use arrayvec::ArrayVec;
 use itertools::{izip, Either, Itertools};
 use rustc_ast::ast::LitKind;
+use rustc_ast::FormatArgs;
+use rustc_data_structures::fx::FxHashMap;
 use rustc_hir::intravisit::{walk_expr, Visitor};
 use rustc_hir::{self as hir, Expr, ExprField, ExprKind, HirId, LangItem, Node, QPath, TyKind};
 use rustc_lexer::unescape::unescape_literal;
@@ -15,8 +17,10 @@ use rustc_parse_format::{self as rpf, Alignment};
 use rustc_span::def_id::DefId;
 use rustc_span::hygiene::{self, MacroKind, SyntaxContext};
 use rustc_span::{sym, BytePos, ExpnData, ExpnId, ExpnKind, Pos, Span, SpanData, Symbol};
+use std::cell::RefCell;
 use std::iter::{once, zip};
-use std::ops::ControlFlow;
+use std::ops::{ControlFlow, Deref};
+use std::sync::atomic::{AtomicBool, Ordering};
 
 const FORMAT_MACRO_DIAG_ITEMS: &[Symbol] = &[
     sym::assert_eq_macro,
@@ -339,6 +343,79 @@ fn is_assert_arg(cx: &LateContext<'_>, expr: &Expr<'_>, assert_expn: ExpnId) ->
     }
 }
 
+thread_local! {
+    /// We preserve the [`FormatArgs`] structs from the early pass for use in the late pass to be
+    /// able to access the many features of a [`LateContext`].
+    ///
+    /// A thread local is used because [`FormatArgs`] is `!Send` and `!Sync`, we are making an
+    /// assumption that the early pass the populates the map and the later late passes will all be
+    /// running on the same thread.
+    static AST_FORMAT_ARGS: RefCell<FxHashMap<Span, FormatArgs>> = {
+        static CALLED: AtomicBool = AtomicBool::new(false);
+        debug_assert!(
+            !CALLED.swap(true, Ordering::SeqCst),
+            "incorrect assumption: `AST_FORMAT_ARGS` should only be accessed by a single thread",
+        );
+
+        RefCell::default()
+    };
+}
+
+/// Record [`rustc_ast::FormatArgs`] for use in late lint passes, this only needs to be called by
+/// one lint pass.
+pub fn populate_ast_format_args(expr: &rustc_ast::Expr) {
+    if let rustc_ast::ExprKind::FormatArgs(args) = &expr.kind {
+        AST_FORMAT_ARGS.with(|ast_format_args| {
+            ast_format_args.borrow_mut().insert(expr.span, args.deref().clone());
+        });
+    }
+}
+
+/// Calls `callback` with an AST [`FormatArgs`] node if one is found
+pub fn find_format_args(cx: &LateContext<'_>, start: &Expr<'_>, expn_id: ExpnId, callback: impl FnOnce(&FormatArgs)) {
+    let format_args_expr = for_each_expr(start, |expr| {
+        let ctxt = expr.span.ctxt();
+        if ctxt == start.span.ctxt() {
+            ControlFlow::Continue(Descend::Yes)
+        } else if ctxt.outer_expn().is_descendant_of(expn_id)
+            && macro_backtrace(expr.span)
+                .map(|macro_call| cx.tcx.item_name(macro_call.def_id))
+                .any(|name| matches!(name, sym::const_format_args | sym::format_args | sym::format_args_nl))
+        {
+            ControlFlow::Break(expr)
+        } else {
+            ControlFlow::Continue(Descend::No)
+        }
+    });
+
+    if let Some(format_args_expr) = format_args_expr {
+        AST_FORMAT_ARGS.with(|ast_format_args| {
+            ast_format_args.borrow().get(&format_args_expr.span).map(callback);
+        });
+    }
+}
+
+/// Returns the [`Span`] of the value at `index` extended to the previous comma, e.g. for the value
+/// `10`
+///
+/// ```ignore
+/// format("{}.{}", 10, 11)
+/// //            ^^^^
+/// ```
+pub fn format_arg_removal_span(format_args: &FormatArgs, index: usize) -> Option<Span> {
+    let ctxt = format_args.span.ctxt();
+
+    let current = hygiene::walk_chain(format_args.arguments.by_index(index)?.expr.span, ctxt);
+
+    let prev = if index == 0 {
+        format_args.span
+    } else {
+        hygiene::walk_chain(format_args.arguments.by_index(index - 1)?.expr.span, ctxt)
+    };
+
+    Some(current.with_lo(prev.hi()))
+}
+
 /// The format string doesn't exist in the HIR, so we reassemble it from source code
 #[derive(Debug)]
 pub struct FormatString {