diff options
| author | Lukas Wirth <lukastw97@gmail.com> | 2024-12-04 06:40:40 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-12-04 06:40:40 +0000 |
| commit | 1b54eea983c1f89056a54ca2dce67d0d1700ae6d (patch) | |
| tree | 52ff93ab4dfd5410f409161bbc323bd2d3b72267 | |
| parent | 03705274564aca483f6c15b09b2d94063debe351 (diff) | |
| parent | 9296578960673092625dc4955f153a52a8a93750 (diff) | |
| download | rust-1b54eea983c1f89056a54ca2dce67d0d1700ae6d.tar.gz rust-1b54eea983c1f89056a54ca2dce67d0d1700ae6d.zip | |
Merge pull request #18609 from ChayimFriedman2/unsafe-coverage
feat: Extend reported unsafe operations
7 files changed, 499 insertions, 113 deletions
diff --git a/src/tools/rust-analyzer/crates/hir-def/src/body.rs b/src/tools/rust-analyzer/crates/hir-def/src/body.rs index 5a386f6cf8d..d4a1120908f 100644 --- a/src/tools/rust-analyzer/crates/hir-def/src/body.rs +++ b/src/tools/rust-analyzer/crates/hir-def/src/body.rs @@ -408,7 +408,8 @@ impl Body { f(else_branch); } } - Expr::Let { expr, .. } => { + Expr::Let { expr, pat } => { + self.walk_exprs_in_pat(*pat, &mut f); f(*expr); } Expr::Block { statements, tail, .. } @@ -444,7 +445,10 @@ impl Body { } Expr::Match { expr, arms } => { f(*expr); - arms.iter().map(|arm| arm.expr).for_each(f); + arms.iter().for_each(|arm| { + f(arm.expr); + self.walk_exprs_in_pat(arm.pat, &mut f); + }); } Expr::Break { expr, .. } | Expr::Return { expr } @@ -505,6 +509,131 @@ impl Body { } } + pub fn walk_child_exprs_without_pats(&self, expr_id: ExprId, mut f: impl FnMut(ExprId)) { + let expr = &self[expr_id]; + match expr { + Expr::Continue { .. } + | Expr::Const(_) + | Expr::Missing + | Expr::Path(_) + | Expr::OffsetOf(_) + | Expr::Literal(_) + | Expr::Underscore => {} + Expr::InlineAsm(it) => it.operands.iter().for_each(|(_, op)| match op { + AsmOperand::In { expr, .. } + | AsmOperand::Out { expr: Some(expr), .. } + | AsmOperand::InOut { expr, .. } => f(*expr), + AsmOperand::SplitInOut { in_expr, out_expr, .. } => { + f(*in_expr); + if let Some(out_expr) = out_expr { + f(*out_expr); + } + } + AsmOperand::Out { expr: None, .. } + | AsmOperand::Const(_) + | AsmOperand::Label(_) + | AsmOperand::Sym(_) => (), + }), + Expr::If { condition, then_branch, else_branch } => { + f(*condition); + f(*then_branch); + if let &Some(else_branch) = else_branch { + f(else_branch); + } + } + Expr::Let { expr, .. } => { + f(*expr); + } + Expr::Block { statements, tail, .. } + | Expr::Unsafe { statements, tail, .. } + | Expr::Async { statements, tail, .. } => { + for stmt in statements.iter() { + match stmt { + Statement::Let { initializer, else_branch, .. } => { + if let &Some(expr) = initializer { + f(expr); + } + if let &Some(expr) = else_branch { + f(expr); + } + } + Statement::Expr { expr: expression, .. } => f(*expression), + Statement::Item(_) => (), + } + } + if let &Some(expr) = tail { + f(expr); + } + } + Expr::Loop { body, .. } => f(*body), + Expr::Call { callee, args, .. } => { + f(*callee); + args.iter().copied().for_each(f); + } + Expr::MethodCall { receiver, args, .. } => { + f(*receiver); + args.iter().copied().for_each(f); + } + Expr::Match { expr, arms } => { + f(*expr); + arms.iter().map(|arm| arm.expr).for_each(f); + } + Expr::Break { expr, .. } + | Expr::Return { expr } + | Expr::Yield { expr } + | Expr::Yeet { expr } => { + if let &Some(expr) = expr { + f(expr); + } + } + Expr::Become { expr } => f(*expr), + Expr::RecordLit { fields, spread, .. } => { + for field in fields.iter() { + f(field.expr); + } + if let &Some(expr) = spread { + f(expr); + } + } + Expr::Closure { body, .. } => { + f(*body); + } + Expr::BinaryOp { lhs, rhs, .. } => { + f(*lhs); + f(*rhs); + } + Expr::Range { lhs, rhs, .. } => { + if let &Some(lhs) = rhs { + f(lhs); + } + if let &Some(rhs) = lhs { + f(rhs); + } + } + Expr::Index { base, index, .. } => { + f(*base); + f(*index); + } + Expr::Field { expr, .. } + | Expr::Await { expr } + | Expr::Cast { expr, .. } + | Expr::Ref { expr, .. } + | Expr::UnaryOp { expr, .. } + | Expr::Box { expr } => { + f(*expr); + } + Expr::Tuple { exprs, .. } => exprs.iter().copied().for_each(f), + Expr::Array(a) => match a { + Array::ElementList { elements, .. } => elements.iter().copied().for_each(f), + Array::Repeat { initializer, repeat } => { + f(*initializer); + f(*repeat) + } + }, + &Expr::Assignment { target: _, value } => f(value), + } + } + pub fn walk_exprs_in_pat(&self, pat_id: PatId, f: &mut impl FnMut(ExprId)) { self.walk_pats(pat_id, &mut |pat| { if let Pat::Expr(expr) | Pat::ConstBlock(expr) = self[pat] { diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics.rs b/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics.rs index af4d2c9fc04..30c02a2936d 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics.rs @@ -9,5 +9,5 @@ pub use crate::diagnostics::{ expr::{ record_literal_missing_fields, record_pattern_missing_fields, BodyValidationDiagnostic, }, - unsafe_check::{missing_unsafe, unsafe_expressions, UnsafeExpr}, + unsafe_check::{missing_unsafe, unsafe_expressions, InsideUnsafeBlock, UnsafetyReason}, }; diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics/unsafe_check.rs b/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics/unsafe_check.rs index c7f7fb7ad3d..193aaa52c26 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics/unsafe_check.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics/unsafe_check.rs @@ -1,12 +1,16 @@ //! Provides validations for unsafe code. Currently checks if unsafe functions are missing //! unsafe blocks. +use std::mem; + +use either::Either; use hir_def::{ body::Body, - hir::{Expr, ExprId, ExprOrPatId, Pat, UnaryOp}, - resolver::{resolver_for_expr, ResolveValueResult, Resolver, ValueNs}, + hir::{Expr, ExprId, ExprOrPatId, Pat, PatId, Statement, UnaryOp}, + path::Path, + resolver::{HasResolver, ResolveValueResult, Resolver, ValueNs}, type_ref::Rawness, - DefWithBodyId, + AdtId, DefWithBodyId, FieldId, VariantId, }; use crate::{ @@ -16,7 +20,10 @@ use crate::{ /// Returns `(unsafe_exprs, fn_is_unsafe)`. /// /// If `fn_is_unsafe` is false, `unsafe_exprs` are hard errors. If true, they're `unsafe_op_in_unsafe_fn`. -pub fn missing_unsafe(db: &dyn HirDatabase, def: DefWithBodyId) -> (Vec<ExprOrPatId>, bool) { +pub fn missing_unsafe( + db: &dyn HirDatabase, + def: DefWithBodyId, +) -> (Vec<(ExprOrPatId, UnsafetyReason)>, bool) { let _p = tracing::info_span!("missing_unsafe").entered(); let mut res = Vec::new(); @@ -30,111 +37,243 @@ pub fn missing_unsafe(db: &dyn HirDatabase, def: DefWithBodyId) -> (Vec<ExprOrPa let body = db.body(def); let infer = db.infer(def); - unsafe_expressions(db, &infer, def, &body, body.body_expr, &mut |expr| { - if !expr.inside_unsafe_block { - res.push(expr.node); + let mut callback = |node, inside_unsafe_block, reason| { + if inside_unsafe_block == InsideUnsafeBlock::No { + res.push((node, reason)); } - }); + }; + let mut visitor = UnsafeVisitor::new(db, &infer, &body, def, &mut callback); + visitor.walk_expr(body.body_expr); + + if !is_unsafe { + // Unsafety in function parameter patterns (that can only be union destructuring) + // cannot be inserted into an unsafe block, so even with `unsafe_op_in_unsafe_fn` + // it is turned off for unsafe functions. + for ¶m in &body.params { + visitor.walk_pat(param); + } + } (res, is_unsafe) } -pub struct UnsafeExpr { - pub node: ExprOrPatId, - pub inside_unsafe_block: bool, +#[derive(Debug, Clone, Copy)] +pub enum UnsafetyReason { + UnionField, + UnsafeFnCall, + InlineAsm, + RawPtrDeref, + MutableStatic, + ExternStatic, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum InsideUnsafeBlock { + No, + Yes, } -// FIXME: Move this out, its not a diagnostic only thing anymore, and handle unsafe pattern accesses as well pub fn unsafe_expressions( db: &dyn HirDatabase, infer: &InferenceResult, def: DefWithBodyId, body: &Body, current: ExprId, - unsafe_expr_cb: &mut dyn FnMut(UnsafeExpr), + unsafe_expr_cb: &mut dyn FnMut(ExprOrPatId, InsideUnsafeBlock, UnsafetyReason), ) { - walk_unsafe( - db, - infer, - body, - &mut resolver_for_expr(db.upcast(), def, current), - def, - current, - false, - unsafe_expr_cb, - ) + let mut visitor = UnsafeVisitor::new(db, infer, body, def, unsafe_expr_cb); + _ = visitor.resolver.update_to_inner_scope(db.upcast(), def, current); + visitor.walk_expr(current); } -fn walk_unsafe( - db: &dyn HirDatabase, - infer: &InferenceResult, - body: &Body, - resolver: &mut Resolver, +struct UnsafeVisitor<'a> { + db: &'a dyn HirDatabase, + infer: &'a InferenceResult, + body: &'a Body, + resolver: Resolver, def: DefWithBodyId, - current: ExprId, - inside_unsafe_block: bool, - unsafe_expr_cb: &mut dyn FnMut(UnsafeExpr), -) { - let mut mark_unsafe_path = |path, node| { - let g = resolver.update_to_inner_scope(db.upcast(), def, current); - let hygiene = body.expr_or_pat_path_hygiene(node); - let value_or_partial = resolver.resolve_path_in_value_ns(db.upcast(), path, hygiene); - if let Some(ResolveValueResult::ValueNs(ValueNs::StaticId(id), _)) = value_or_partial { - let static_data = db.static_data(id); - if static_data.mutable || (static_data.is_extern && !static_data.has_safe_kw) { - unsafe_expr_cb(UnsafeExpr { node, inside_unsafe_block }); + inside_unsafe_block: InsideUnsafeBlock, + inside_assignment: bool, + inside_union_destructure: bool, + unsafe_expr_cb: &'a mut dyn FnMut(ExprOrPatId, InsideUnsafeBlock, UnsafetyReason), +} + +impl<'a> UnsafeVisitor<'a> { + fn new( + db: &'a dyn HirDatabase, + infer: &'a InferenceResult, + body: &'a Body, + def: DefWithBodyId, + unsafe_expr_cb: &'a mut dyn FnMut(ExprOrPatId, InsideUnsafeBlock, UnsafetyReason), + ) -> Self { + let resolver = def.resolver(db.upcast()); + Self { + db, + infer, + body, + resolver, + def, + inside_unsafe_block: InsideUnsafeBlock::No, + inside_assignment: false, + inside_union_destructure: false, + unsafe_expr_cb, + } + } + + fn call_cb(&mut self, node: ExprOrPatId, reason: UnsafetyReason) { + (self.unsafe_expr_cb)(node, self.inside_unsafe_block, reason); + } + + fn walk_pats_top(&mut self, pats: impl Iterator<Item = PatId>, parent_expr: ExprId) { + let guard = self.resolver.update_to_inner_scope(self.db.upcast(), self.def, parent_expr); + pats.for_each(|pat| self.walk_pat(pat)); + self.resolver.reset_to_guard(guard); + } + + fn walk_pat(&mut self, current: PatId) { + let pat = &self.body.pats[current]; + + if self.inside_union_destructure { + match pat { + Pat::Tuple { .. } + | Pat::Record { .. } + | Pat::Range { .. } + | Pat::Slice { .. } + | Pat::Path(..) + | Pat::Lit(..) + | Pat::Bind { .. } + | Pat::TupleStruct { .. } + | Pat::Ref { .. } + | Pat::Box { .. } + | Pat::Expr(..) + | Pat::ConstBlock(..) => self.call_cb(current.into(), UnsafetyReason::UnionField), + // `Or` only wraps other patterns, and `Missing`/`Wild` do not constitute a read. + Pat::Missing | Pat::Wild | Pat::Or(_) => {} } } - resolver.reset_to_guard(g); - }; - let expr = &body.exprs[current]; - match expr { - &Expr::Call { callee, .. } => { - if let Some(func) = infer[callee].as_fn_def(db) { - if is_fn_unsafe_to_call(db, func) { - unsafe_expr_cb(UnsafeExpr { node: current.into(), inside_unsafe_block }); + match pat { + Pat::Record { .. } => { + if let Some((AdtId::UnionId(_), _)) = self.infer[current].as_adt() { + let old_inside_union_destructure = + mem::replace(&mut self.inside_union_destructure, true); + self.body.walk_pats_shallow(current, |pat| self.walk_pat(pat)); + self.inside_union_destructure = old_inside_union_destructure; + return; } } - } - Expr::Path(path) => mark_unsafe_path(path, current.into()), - Expr::Ref { expr, rawness: Rawness::RawPtr, mutability: _ } => { - if let Expr::Path(_) = body.exprs[*expr] { - // Do not report unsafe for `addr_of[_mut]!(EXTERN_OR_MUT_STATIC)`, - // see https://github.com/rust-lang/rust/pull/125834. - return; + Pat::Path(path) => self.mark_unsafe_path(current.into(), path), + &Pat::ConstBlock(expr) => { + let old_inside_assignment = mem::replace(&mut self.inside_assignment, false); + self.walk_expr(expr); + self.inside_assignment = old_inside_assignment; } + &Pat::Expr(expr) => self.walk_expr(expr), + _ => {} } - Expr::MethodCall { .. } => { - if infer - .method_resolution(current) - .map(|(func, _)| is_fn_unsafe_to_call(db, func)) - .unwrap_or(false) - { - unsafe_expr_cb(UnsafeExpr { node: current.into(), inside_unsafe_block }); + + self.body.walk_pats_shallow(current, |pat| self.walk_pat(pat)); + } + + fn walk_expr(&mut self, current: ExprId) { + let expr = &self.body.exprs[current]; + let inside_assignment = mem::replace(&mut self.inside_assignment, false); + match expr { + &Expr::Call { callee, .. } => { + if let Some(func) = self.infer[callee].as_fn_def(self.db) { + if is_fn_unsafe_to_call(self.db, func) { + self.call_cb(current.into(), UnsafetyReason::UnsafeFnCall); + } + } } - } - Expr::UnaryOp { expr, op: UnaryOp::Deref } => { - if let TyKind::Raw(..) = &infer[*expr].kind(Interner) { - unsafe_expr_cb(UnsafeExpr { node: current.into(), inside_unsafe_block }); + Expr::Path(path) => { + let guard = + self.resolver.update_to_inner_scope(self.db.upcast(), self.def, current); + self.mark_unsafe_path(current.into(), path); + self.resolver.reset_to_guard(guard); } - } - Expr::Unsafe { .. } => { - return body.walk_child_exprs(current, |child| { - walk_unsafe(db, infer, body, resolver, def, child, true, unsafe_expr_cb); - }); - } - &Expr::Assignment { target, value: _ } => { - body.walk_pats(target, &mut |pat| { - if let Pat::Path(path) = &body[pat] { - mark_unsafe_path(path, pat.into()); + Expr::Ref { expr, rawness: Rawness::RawPtr, mutability: _ } => { + if let Expr::Path(_) = self.body.exprs[*expr] { + // Do not report unsafe for `addr_of[_mut]!(EXTERN_OR_MUT_STATIC)`, + // see https://github.com/rust-lang/rust/pull/125834. + return; + } + } + Expr::MethodCall { .. } => { + if self + .infer + .method_resolution(current) + .map(|(func, _)| is_fn_unsafe_to_call(self.db, func)) + .unwrap_or(false) + { + self.call_cb(current.into(), UnsafetyReason::UnsafeFnCall); } - }); + } + Expr::UnaryOp { expr, op: UnaryOp::Deref } => { + if let TyKind::Raw(..) = &self.infer[*expr].kind(Interner) { + self.call_cb(current.into(), UnsafetyReason::RawPtrDeref); + } + } + Expr::Unsafe { .. } => { + let old_inside_unsafe_block = + mem::replace(&mut self.inside_unsafe_block, InsideUnsafeBlock::Yes); + self.body.walk_child_exprs_without_pats(current, |child| self.walk_expr(child)); + self.inside_unsafe_block = old_inside_unsafe_block; + return; + } + &Expr::Assignment { target, value: _ } => { + let old_inside_assignment = mem::replace(&mut self.inside_assignment, true); + self.walk_pats_top(std::iter::once(target), current); + self.inside_assignment = old_inside_assignment; + } + Expr::InlineAsm(_) => self.call_cb(current.into(), UnsafetyReason::InlineAsm), + // rustc allows union assignment to propagate through field accesses and casts. + Expr::Cast { .. } => self.inside_assignment = inside_assignment, + Expr::Field { .. } => { + self.inside_assignment = inside_assignment; + if !inside_assignment { + if let Some(Either::Left(FieldId { parent: VariantId::UnionId(_), .. })) = + self.infer.field_resolution(current) + { + self.call_cb(current.into(), UnsafetyReason::UnionField); + } + } + } + Expr::Block { statements, .. } | Expr::Async { statements, .. } => { + self.walk_pats_top( + statements.iter().filter_map(|statement| match statement { + &Statement::Let { pat, .. } => Some(pat), + _ => None, + }), + current, + ); + } + Expr::Match { arms, .. } => { + self.walk_pats_top(arms.iter().map(|arm| arm.pat), current); + } + &Expr::Let { pat, .. } => { + self.walk_pats_top(std::iter::once(pat), current); + } + Expr::Closure { args, .. } => { + self.walk_pats_top(args.iter().copied(), current); + } + _ => {} } - _ => {} + + self.body.walk_child_exprs_without_pats(current, |child| self.walk_expr(child)); } - body.walk_child_exprs(current, |child| { - walk_unsafe(db, infer, body, resolver, def, child, inside_unsafe_block, unsafe_expr_cb); - }); + fn mark_unsafe_path(&mut self, node: ExprOrPatId, path: &Path) { + let hygiene = self.body.expr_or_pat_path_hygiene(node); + let value_or_partial = + self.resolver.resolve_path_in_value_ns(self.db.upcast(), path, hygiene); + if let Some(ResolveValueResult::ValueNs(ValueNs::StaticId(id), _)) = value_or_partial { + let static_data = self.db.static_data(id); + if static_data.mutable { + self.call_cb(node, UnsafetyReason::MutableStatic); + } else if static_data.is_extern && !static_data.has_safe_kw { + self.call_cb(node, UnsafetyReason::ExternStatic); + } + } + } } diff --git a/src/tools/rust-analyzer/crates/hir/src/diagnostics.rs b/src/tools/rust-analyzer/crates/hir/src/diagnostics.rs index 8297acde857..9ca021027d5 100644 --- a/src/tools/rust-analyzer/crates/hir/src/diagnostics.rs +++ b/src/tools/rust-analyzer/crates/hir/src/diagnostics.rs @@ -5,7 +5,9 @@ //! be expressed in terms of hir types themselves. pub use hir_ty::diagnostics::{CaseType, IncorrectCase}; use hir_ty::{ - db::HirDatabase, diagnostics::BodyValidationDiagnostic, CastError, InferenceDiagnostic, + db::HirDatabase, + diagnostics::{BodyValidationDiagnostic, UnsafetyReason}, + CastError, InferenceDiagnostic, }; use cfg::{CfgExpr, CfgOptions}; @@ -258,9 +260,10 @@ pub struct PrivateField { #[derive(Debug)] pub struct MissingUnsafe { - pub expr: InFile<AstPtr<Either<ast::Expr, ast::Pat>>>, + pub node: InFile<AstPtr<Either<ast::Expr, ast::Pat>>>, /// If true, the diagnostics is an `unsafe_op_in_unsafe_fn` lint instead of a hard error. pub only_lint: bool, + pub reason: UnsafetyReason, } #[derive(Debug)] diff --git a/src/tools/rust-analyzer/crates/hir/src/lib.rs b/src/tools/rust-analyzer/crates/hir/src/lib.rs index c9498b3aead..0b2ba56b1ff 100644 --- a/src/tools/rust-analyzer/crates/hir/src/lib.rs +++ b/src/tools/rust-analyzer/crates/hir/src/lib.rs @@ -147,6 +147,7 @@ pub use { }, hir_ty::{ consteval::ConstEvalError, + diagnostics::UnsafetyReason, display::{ClosureStyle, HirDisplay, HirDisplayError, HirWrite}, dyn_compatibility::{DynCompatibilityViolation, MethodViolationCode}, layout::LayoutError, @@ -1890,10 +1891,10 @@ impl DefWithBody { ); } - let (unafe_exprs, only_lint) = hir_ty::diagnostics::missing_unsafe(db, self.into()); - for expr in unafe_exprs { - match source_map.expr_or_pat_syntax(expr) { - Ok(expr) => acc.push(MissingUnsafe { expr, only_lint }.into()), + let (unsafe_exprs, only_lint) = hir_ty::diagnostics::missing_unsafe(db, self.into()); + for (node, reason) in unsafe_exprs { + match source_map.expr_or_pat_syntax(node) { + Ok(node) => acc.push(MissingUnsafe { node, only_lint, reason }.into()), Err(SyntheticSyntax) => { // FIXME: Here and elsewhere in this file, the `expr` was // desugared, report or assert that this doesn't happen. diff --git a/src/tools/rust-analyzer/crates/hir/src/source_analyzer.rs b/src/tools/rust-analyzer/crates/hir/src/source_analyzer.rs index c16454cff68..56ed81f053c 100644 --- a/src/tools/rust-analyzer/crates/hir/src/source_analyzer.rs +++ b/src/tools/rust-analyzer/crates/hir/src/source_analyzer.rs @@ -36,7 +36,7 @@ use hir_expand::{ use hir_ty::{ diagnostics::{ record_literal_missing_fields, record_pattern_missing_fields, unsafe_expressions, - UnsafeExpr, + InsideUnsafeBlock, }, lang_items::lang_items_for_bin_op, method_resolution, Adjustment, InferenceResult, Interner, Substitution, Ty, TyExt, TyKind, @@ -939,8 +939,8 @@ impl SourceAnalyzer { *def, body, expr_id, - &mut |UnsafeExpr { inside_unsafe_block, .. }| { - is_unsafe |= !inside_unsafe_block + &mut |_, inside_unsafe_block, _| { + is_unsafe |= inside_unsafe_block == InsideUnsafeBlock::No }, ) }; diff --git a/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/missing_unsafe.rs b/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/missing_unsafe.rs index a630d3c7c36..2bfdda35659 100644 --- a/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/missing_unsafe.rs +++ b/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/missing_unsafe.rs @@ -1,5 +1,5 @@ use hir::db::ExpandDatabase; -use hir::HirFileIdExt; +use hir::{HirFileIdExt, UnsafetyReason}; use ide_db::text_edit::TextEdit; use ide_db::{assists::Assist, source_change::SourceChange}; use syntax::{ast, SyntaxNode}; @@ -16,23 +16,35 @@ pub(crate) fn missing_unsafe(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsaf } else { DiagnosticCode::RustcHardError("E0133") }; + let operation = display_unsafety_reason(d.reason); Diagnostic::new_with_syntax_node_ptr( ctx, code, - "this operation is unsafe and requires an unsafe function or block", - d.expr.map(|it| it.into()), + format!("{operation} is unsafe and requires an unsafe function or block"), + d.node.map(|it| it.into()), ) .with_fixes(fixes(ctx, d)) } +fn display_unsafety_reason(reason: UnsafetyReason) -> &'static str { + match reason { + UnsafetyReason::UnionField => "access to union field", + UnsafetyReason::UnsafeFnCall => "call to unsafe function", + UnsafetyReason::InlineAsm => "use of inline assembly", + UnsafetyReason::RawPtrDeref => "dereference of raw pointer", + UnsafetyReason::MutableStatic => "use of mutable static", + UnsafetyReason::ExternStatic => "use of extern static", + } +} + fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option<Vec<Assist>> { // The fixit will not work correctly for macro expansions, so we don't offer it in that case. - if d.expr.file_id.is_macro() { + if d.node.file_id.is_macro() { return None; } - let root = ctx.sema.db.parse_or_expand(d.expr.file_id); - let node = d.expr.value.to_node(&root); + let root = ctx.sema.db.parse_or_expand(d.node.file_id); + let node = d.node.value.to_node(&root); let expr = node.syntax().ancestors().find_map(ast::Expr::cast)?; let node_to_add_unsafe_block = pick_best_node_to_add_unsafe_block(&expr)?; @@ -40,7 +52,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option<Vec<Ass let replacement = format!("unsafe {{ {} }}", node_to_add_unsafe_block.text()); let edit = TextEdit::replace(node_to_add_unsafe_block.text_range(), replacement); let source_change = - SourceChange::from_text_edit(d.expr.file_id.original_file(ctx.sema.db), edit); + SourceChange::from_text_edit(d.node.file_id.original_file(ctx.sema.db), edit); Some(vec![fix("add_unsafe", "Add unsafe block", source_change, expr.syntax().text_range())]) } @@ -110,7 +122,7 @@ fn main() { let x = &5_usize as *const usize; unsafe { let _y = *x; } let _z = *x; -} //^^💡 error: this operation is unsafe and requires an unsafe function or block +} //^^💡 error: dereference of raw pointer is unsafe and requires an unsafe function or block "#, ) } @@ -136,9 +148,9 @@ unsafe fn unsafe_fn() { fn main() { unsafe_fn(); - //^^^^^^^^^^^💡 error: this operation is unsafe and requires an unsafe function or block + //^^^^^^^^^^^💡 error: call to unsafe function is unsafe and requires an unsafe function or block HasUnsafe.unsafe_fn(); - //^^^^^^^^^^^^^^^^^^^^^💡 error: this operation is unsafe and requires an unsafe function or block + //^^^^^^^^^^^^^^^^^^^^^💡 error: call to unsafe function is unsafe and requires an unsafe function or block unsafe { unsafe_fn(); HasUnsafe.unsafe_fn(); @@ -162,7 +174,7 @@ static mut STATIC_MUT: Ty = Ty { a: 0 }; fn main() { let _x = STATIC_MUT.a; - //^^^^^^^^^^💡 error: this operation is unsafe and requires an unsafe function or block + //^^^^^^^^^^💡 error: use of mutable static is unsafe and requires an unsafe function or block unsafe { let _x = STATIC_MUT.a; } @@ -184,9 +196,9 @@ extern "C" { fn main() { let _x = EXTERN; - //^^^^^^💡 error: this operation is unsafe and requires an unsafe function or block + //^^^^^^💡 error: use of extern static is unsafe and requires an unsafe function or block let _x = EXTERN_MUT; - //^^^^^^^^^^💡 error: this operation is unsafe and requires an unsafe function or block + //^^^^^^^^^^💡 error: use of mutable static is unsafe and requires an unsafe function or block unsafe { let _x = EXTERN; let _x = EXTERN_MUT; @@ -234,7 +246,7 @@ extern "rust-intrinsic" { fn main() { let _ = bitreverse(12); let _ = floorf32(12.0); - //^^^^^^^^^^^^^^💡 error: this operation is unsafe and requires an unsafe function or block + //^^^^^^^^^^^^^^💡 error: call to unsafe function is unsafe and requires an unsafe function or block } "#, ); @@ -567,7 +579,7 @@ unsafe fn not_safe() -> u8 { fn main() { ed2021::safe(); ed2024::not_safe(); - //^^^^^^^^^^^^^^^^^^💡 error: this operation is unsafe and requires an unsafe function or block + //^^^^^^^^^^^^^^^^^^💡 error: call to unsafe function is unsafe and requires an unsafe function or block } "#, ) @@ -591,7 +603,7 @@ unsafe fn foo(p: *mut i32) { #![warn(unsafe_op_in_unsafe_fn)] unsafe fn foo(p: *mut i32) { *p = 123; - //^^💡 warn: this operation is unsafe and requires an unsafe function or block + //^^💡 warn: dereference of raw pointer is unsafe and requires an unsafe function or block } "#, ) @@ -618,17 +630,119 @@ unsafe extern { fn main() { f(); g(); - //^^^💡 error: this operation is unsafe and requires an unsafe function or block + //^^^💡 error: call to unsafe function is unsafe and requires an unsafe function or block h(); - //^^^💡 error: this operation is unsafe and requires an unsafe function or block + //^^^💡 error: call to unsafe function is unsafe and requires an unsafe function or block let _ = S1; let _ = S2; - //^^💡 error: this operation is unsafe and requires an unsafe function or block + //^^💡 error: use of extern static is unsafe and requires an unsafe function or block let _ = S3; - //^^💡 error: this operation is unsafe and requires an unsafe function or block + //^^💡 error: use of extern static is unsafe and requires an unsafe function or block +} +"#, + ); + } + + #[test] + fn no_unsafe_diagnostic_when_destructuring_union_with_wildcard() { + check_diagnostics( + r#" +union Union { field: i32 } +fn foo(v: &Union) { + let Union { field: _ } = v; + let Union { field: _ | _ } = v; + Union { field: _ } = *v; +} +"#, + ); + } + + #[test] + fn union_destructuring() { + check_diagnostics( + r#" +union Union { field: u8 } +fn foo(v @ Union { field: _field }: &Union) { + // ^^^^^^ error: access to union field is unsafe and requires an unsafe function or block + let Union { mut field } = v; + // ^^^^^^^^^💡 error: access to union field is unsafe and requires an unsafe function or block + let Union { field: 0..=255 } = v; + // ^^^^^^^💡 error: access to union field is unsafe and requires an unsafe function or block + let Union { field: 0 + // ^💡 error: access to union field is unsafe and requires an unsafe function or block + | 1..=255 } = v; + // ^^^^^^^💡 error: access to union field is unsafe and requires an unsafe function or block + Union { field } = *v; + // ^^^^^💡 error: access to union field is unsafe and requires an unsafe function or block + match v { + Union { field: _field } => {} + // ^^^^^^💡 error: access to union field is unsafe and requires an unsafe function or block + } + if let Union { field: _field } = v {} + // ^^^^^^💡 error: access to union field is unsafe and requires an unsafe function or block + (|&Union { field }| { _ = field; })(v); + // ^^^^^💡 error: access to union field is unsafe and requires an unsafe function or block +} +"#, + ); + } + + #[test] + fn union_field_access() { + check_diagnostics( + r#" +union Union { field: u8 } +fn foo(v: &Union) { + v.field; + // ^^^^^^^💡 error: access to union field is unsafe and requires an unsafe function or block } "#, ); } + + #[test] + fn inline_asm() { + check_diagnostics( + r#" +//- minicore: asm +fn foo() { + core::arch::asm!(""); + // ^^^^ error: use of inline assembly is unsafe and requires an unsafe function or block +} +"#, + ); + } + + #[test] + fn unsafe_op_in_unsafe_fn_dismissed_in_signature() { + check_diagnostics( + r#" +#![warn(unsafe_op_in_unsafe_fn)] +union Union { field: u32 } +unsafe fn foo(Union { field: _field }: Union) {} + "#, + ) + } + + #[test] + fn union_assignment_allowed() { + check_diagnostics( + r#" +union Union { field: u32 } +fn foo(mut v: Union) { + v.field = 123; + (v.field,) = (123,); + *&mut v.field = 123; + // ^^^^^^^💡 error: access to union field is unsafe and requires an unsafe function or block +} +struct Struct { field: u32 } +union Union2 { field: Struct } +fn bar(mut v: Union2) { + v.field.field = 123; +} + + "#, + ) + } } |
