diff options
| author | Bastian Kersting <bkersting@google.com> | 2024-12-17 13:00:22 +0000 |
|---|---|---|
| committer | Bastian Kersting <bkersting@google.com> | 2025-01-31 11:13:34 +0000 |
| commit | b151b513ba2b65c7506ec1a80f2712bbd09154d1 (patch) | |
| tree | fc1505e76399304e8da3aca3808ad1b494dfe4e8 /compiler | |
| parent | 851322b74db9ac91a1b9d206c5f80fc51a97f7c1 (diff) | |
| download | rust-b151b513ba2b65c7506ec1a80f2712bbd09154d1.tar.gz rust-b151b513ba2b65c7506ec1a80f2712bbd09154d1.zip | |
Insert null checks for pointer dereferences when debug assertions are enabled
Similar to how the alignment is already checked, this adds a check for null pointer dereferences in debug mode. It is implemented similarly to the alignment check as a MirPass. This is related to a 2025H1 project goal for better UB checks in debug mode: https://github.com/rust-lang/rust-project-goals/pull/177.
Diffstat (limited to 'compiler')
| -rw-r--r-- | compiler/rustc_codegen_cranelift/src/base.rs | 10 | ||||
| -rw-r--r-- | compiler/rustc_codegen_ssa/src/mir/block.rs | 5 | ||||
| -rw-r--r-- | compiler/rustc_const_eval/src/const_eval/machine.rs | 1 | ||||
| -rw-r--r-- | compiler/rustc_hir/src/lang_items.rs | 1 | ||||
| -rw-r--r-- | compiler/rustc_middle/messages.ftl | 3 | ||||
| -rw-r--r-- | compiler/rustc_middle/src/mir/syntax.rs | 1 | ||||
| -rw-r--r-- | compiler/rustc_middle/src/mir/terminator.rs | 6 | ||||
| -rw-r--r-- | compiler/rustc_middle/src/mir/visit.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/check_null.rs | 110 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/check_pointers.rs | 3 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/lib.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_monomorphize/src/collector.rs | 3 | ||||
| -rw-r--r-- | compiler/rustc_smir/src/rustc_smir/convert/mir.rs | 3 | ||||
| -rw-r--r-- | compiler/rustc_span/src/symbol.rs | 1 | ||||
| -rw-r--r-- | compiler/stable_mir/src/mir/body.rs | 2 | ||||
| -rw-r--r-- | compiler/stable_mir/src/mir/pretty.rs | 3 | ||||
| -rw-r--r-- | compiler/stable_mir/src/mir/visit.rs | 5 |
17 files changed, 156 insertions, 5 deletions
diff --git a/compiler/rustc_codegen_cranelift/src/base.rs b/compiler/rustc_codegen_cranelift/src/base.rs index 34066eb83fc..c7c9c6236d1 100644 --- a/compiler/rustc_codegen_cranelift/src/base.rs +++ b/compiler/rustc_codegen_cranelift/src/base.rs @@ -417,6 +417,16 @@ fn codegen_fn_body(fx: &mut FunctionCx<'_, '_, '_>, start_block: Block) { Some(source_info.span), ); } + AssertKind::NullPointerDereference => { + let location = fx.get_caller_location(source_info).load_scalar(fx); + + codegen_panic_inner( + fx, + rustc_hir::LangItem::PanicNullPointerDereference, + &[location], + Some(source_info.span), + ) + } _ => { let location = fx.get_caller_location(source_info).load_scalar(fx); diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs index b0a1dedd646..4be363ca9a2 100644 --- a/compiler/rustc_codegen_ssa/src/mir/block.rs +++ b/compiler/rustc_codegen_ssa/src/mir/block.rs @@ -713,6 +713,11 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { // and `#[track_caller]` adds an implicit third argument. (LangItem::PanicMisalignedPointerDereference, vec![required, found, location]) } + AssertKind::NullPointerDereference => { + // It's `fn panic_null_pointer_dereference()`, + // `#[track_caller]` adds an implicit argument. + (LangItem::PanicNullPointerDereference, vec![location]) + } _ => { // It's `pub fn panic_...()` and `#[track_caller]` adds an implicit argument. (msg.panic_function(), vec![location]) diff --git a/compiler/rustc_const_eval/src/const_eval/machine.rs b/compiler/rustc_const_eval/src/const_eval/machine.rs index 6a339d69542..82e0a6e6666 100644 --- a/compiler/rustc_const_eval/src/const_eval/machine.rs +++ b/compiler/rustc_const_eval/src/const_eval/machine.rs @@ -508,6 +508,7 @@ impl<'tcx> interpret::Machine<'tcx> for CompileTimeMachine<'tcx> { found: eval_to_int(found)?, } } + NullPointerDereference => NullPointerDereference, }; Err(ConstEvalErrKind::AssertFailure(err)).into() } diff --git a/compiler/rustc_hir/src/lang_items.rs b/compiler/rustc_hir/src/lang_items.rs index 02bc069fc5f..d9759580e8f 100644 --- a/compiler/rustc_hir/src/lang_items.rs +++ b/compiler/rustc_hir/src/lang_items.rs @@ -316,6 +316,7 @@ language_item_table! { PanicAsyncFnResumedPanic, sym::panic_const_async_fn_resumed_panic, panic_const_async_fn_resumed_panic, Target::Fn, GenericRequirement::None; PanicAsyncGenFnResumedPanic, sym::panic_const_async_gen_fn_resumed_panic, panic_const_async_gen_fn_resumed_panic, Target::Fn, GenericRequirement::None; PanicGenFnNonePanic, sym::panic_const_gen_fn_none_panic, panic_const_gen_fn_none_panic, Target::Fn, GenericRequirement::None; + PanicNullPointerDereference, sym::panic_null_pointer_dereference, panic_null_pointer_dereference, Target::Fn, GenericRequirement::None; /// libstd panic entry point. Necessary for const eval to be able to catch it BeginPanic, sym::begin_panic, begin_panic_fn, Target::Fn, GenericRequirement::None; diff --git a/compiler/rustc_middle/messages.ftl b/compiler/rustc_middle/messages.ftl index 8dc529b4d7a..f74cea23c24 100644 --- a/compiler/rustc_middle/messages.ftl +++ b/compiler/rustc_middle/messages.ftl @@ -17,6 +17,9 @@ middle_assert_gen_resume_after_panic = `gen` fn or block cannot be further itera middle_assert_misaligned_ptr_deref = misaligned pointer dereference: address must be a multiple of {$required} but is {$found} +middle_assert_null_ptr_deref = + null pointer dereference occurred + middle_assert_op_overflow = attempt to compute `{$left} {$op} {$right}`, which would overflow diff --git a/compiler/rustc_middle/src/mir/syntax.rs b/compiler/rustc_middle/src/mir/syntax.rs index 5868b64f6b5..90a2f175bc9 100644 --- a/compiler/rustc_middle/src/mir/syntax.rs +++ b/compiler/rustc_middle/src/mir/syntax.rs @@ -1076,6 +1076,7 @@ pub enum AssertKind<O> { ResumedAfterReturn(CoroutineKind), ResumedAfterPanic(CoroutineKind), MisalignedPointerDereference { required: O, found: O }, + NullPointerDereference, } #[derive(Clone, Debug, PartialEq, TyEncodable, TyDecodable, Hash, HashStable)] diff --git a/compiler/rustc_middle/src/mir/terminator.rs b/compiler/rustc_middle/src/mir/terminator.rs index c04a8251fbc..2fe116212eb 100644 --- a/compiler/rustc_middle/src/mir/terminator.rs +++ b/compiler/rustc_middle/src/mir/terminator.rs @@ -206,6 +206,7 @@ impl<O> AssertKind<O> { ResumedAfterPanic(CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) => { LangItem::PanicGenFnNonePanic } + NullPointerDereference => LangItem::PanicNullPointerDereference, BoundsCheck { .. } | MisalignedPointerDereference { .. } => { bug!("Unexpected AssertKind") @@ -271,6 +272,7 @@ impl<O> AssertKind<O> { "\"misaligned pointer dereference: address must be a multiple of {{}} but is {{}}\", {required:?}, {found:?}" ) } + NullPointerDereference => write!(f, "\"null pointer dereference occured\""), ResumedAfterReturn(CoroutineKind::Coroutine(_)) => { write!(f, "\"coroutine resumed after completion\"") } @@ -341,7 +343,7 @@ impl<O> AssertKind<O> { ResumedAfterPanic(CoroutineKind::Coroutine(_)) => { middle_assert_coroutine_resume_after_panic } - + NullPointerDereference => middle_assert_null_ptr_deref, MisalignedPointerDereference { .. } => middle_assert_misaligned_ptr_deref, } } @@ -374,7 +376,7 @@ impl<O> AssertKind<O> { add!("left", format!("{left:#?}")); add!("right", format!("{right:#?}")); } - ResumedAfterReturn(_) | ResumedAfterPanic(_) => {} + ResumedAfterReturn(_) | ResumedAfterPanic(_) | NullPointerDereference => {} MisalignedPointerDereference { required, found } => { add!("required", format!("{required:#?}")); add!("found", format!("{found:#?}")); diff --git a/compiler/rustc_middle/src/mir/visit.rs b/compiler/rustc_middle/src/mir/visit.rs index 95de08ce9c8..6319a7e65b9 100644 --- a/compiler/rustc_middle/src/mir/visit.rs +++ b/compiler/rustc_middle/src/mir/visit.rs @@ -636,7 +636,7 @@ macro_rules! make_mir_visitor { OverflowNeg(op) | DivisionByZero(op) | RemainderByZero(op) => { self.visit_operand(op, location); } - ResumedAfterReturn(_) | ResumedAfterPanic(_) => { + ResumedAfterReturn(_) | ResumedAfterPanic(_) | NullPointerDereference => { // Nothing to visit } MisalignedPointerDereference { required, found } => { diff --git a/compiler/rustc_mir_transform/src/check_null.rs b/compiler/rustc_mir_transform/src/check_null.rs new file mode 100644 index 00000000000..0b6c0ceaac1 --- /dev/null +++ b/compiler/rustc_mir_transform/src/check_null.rs @@ -0,0 +1,110 @@ +use rustc_index::IndexVec; +use rustc_middle::mir::interpret::Scalar; +use rustc_middle::mir::*; +use rustc_middle::ty::{Ty, TyCtxt}; +use rustc_session::Session; + +use crate::check_pointers::{BorrowCheckMode, PointerCheck, check_pointers}; + +pub(super) struct CheckNull; + +impl<'tcx> crate::MirPass<'tcx> for CheckNull { + fn is_enabled(&self, sess: &Session) -> bool { + sess.ub_checks() + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + check_pointers(tcx, body, &[], insert_null_check, BorrowCheckMode::IncludeBorrows); + } + + fn is_required(&self) -> bool { + true + } +} + +fn insert_null_check<'tcx>( + tcx: TyCtxt<'tcx>, + pointer: Place<'tcx>, + pointee_ty: Ty<'tcx>, + local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, + stmts: &mut Vec<Statement<'tcx>>, + source_info: SourceInfo, +) -> PointerCheck<'tcx> { + // Cast the pointer to a *const (). + let const_raw_ptr = Ty::new_imm_ptr(tcx, tcx.types.unit); + let rvalue = Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(pointer), const_raw_ptr); + let thin_ptr = local_decls.push(LocalDecl::with_source_info(const_raw_ptr, source_info)).into(); + stmts + .push(Statement { source_info, kind: StatementKind::Assign(Box::new((thin_ptr, rvalue))) }); + + // Transmute the pointer to a usize (equivalent to `ptr.addr()`). + let rvalue = Rvalue::Cast(CastKind::Transmute, Operand::Copy(thin_ptr), tcx.types.usize); + let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); + stmts.push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) }); + + // Get size of the pointee (zero-sized reads and writes are allowed). + let rvalue = Rvalue::NullaryOp(NullOp::SizeOf, pointee_ty); + let sizeof_pointee = + local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); + stmts.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new((sizeof_pointee, rvalue))), + }); + + // Check that the pointee is not a ZST. + let zero = Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val(ConstValue::Scalar(Scalar::from_target_usize(0, &tcx)), tcx.types.usize), + })); + let is_pointee_no_zst = + local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into(); + stmts.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + is_pointee_no_zst, + Rvalue::BinaryOp(BinOp::Ne, Box::new((Operand::Copy(sizeof_pointee), zero.clone()))), + ))), + }); + + // Check whether the pointer is null. + let is_null = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into(); + stmts.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + is_null, + Rvalue::BinaryOp(BinOp::Eq, Box::new((Operand::Copy(addr), zero))), + ))), + }); + + // We want to throw an exception if the pointer is null and doesn't point to a ZST. + let should_throw_exception = + local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into(); + stmts.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + should_throw_exception, + Rvalue::BinaryOp( + BinOp::BitAnd, + Box::new((Operand::Copy(is_null), Operand::Copy(is_pointee_no_zst))), + ), + ))), + }); + + // The final condition whether this pointer usage is ok or not. + let is_ok = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into(); + stmts.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + is_ok, + Rvalue::UnaryOp(UnOp::Not, Operand::Copy(should_throw_exception)), + ))), + }); + + // Emit a PointerCheck that asserts on the condition and otherwise triggers + // a AssertKind::NullPointerDereference. + PointerCheck { + cond: Operand::Copy(is_ok), + assert_kind: Box::new(AssertKind::NullPointerDereference), + } +} diff --git a/compiler/rustc_mir_transform/src/check_pointers.rs b/compiler/rustc_mir_transform/src/check_pointers.rs index 95d3ce05000..72460542f87 100644 --- a/compiler/rustc_mir_transform/src/check_pointers.rs +++ b/compiler/rustc_mir_transform/src/check_pointers.rs @@ -17,6 +17,7 @@ pub(crate) struct PointerCheck<'tcx> { /// [NonMutatingUseContext::SharedBorrow]. #[derive(Copy, Clone)] pub(crate) enum BorrowCheckMode { + IncludeBorrows, ExcludeBorrows, } @@ -168,7 +169,7 @@ impl<'a, 'tcx> PointerFinder<'a, 'tcx> { ) => true, PlaceContext::MutatingUse(MutatingUseContext::Borrow) | PlaceContext::NonMutatingUse(NonMutatingUseContext::SharedBorrow) => { - !matches!(self.borrow_check_mode, BorrowCheckMode::ExcludeBorrows) + matches!(self.borrow_check_mode, BorrowCheckMode::IncludeBorrows) } _ => false, } diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 0df9ce0c3d3..c6a9e3ee0e6 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -119,6 +119,7 @@ declare_passes! { mod check_call_recursion : CheckCallRecursion, CheckDropRecursion; mod check_alignment : CheckAlignment; mod check_const_item_mutation : CheckConstItemMutation; + mod check_null : CheckNull; mod check_packed_ref : CheckPackedRef; mod check_undefined_transmutes : CheckUndefinedTransmutes; // This pass is public to allow external drivers to perform MIR cleanup @@ -643,6 +644,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &[ // Add some UB checks before any UB gets optimized away. &check_alignment::CheckAlignment, + &check_null::CheckNull, // Before inlining: trim down MIR with passes to reduce inlining work. // Has to be done before inlining, otherwise actual call will be almost always inlined. diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index d53848f7461..d22feb81369 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -814,6 +814,9 @@ impl<'a, 'tcx> MirVisitor<'tcx> for MirUsedCollector<'a, 'tcx> { mir::AssertKind::MisalignedPointerDereference { .. } => { push_mono_lang_item(self, LangItem::PanicMisalignedPointerDereference); } + mir::AssertKind::NullPointerDereference => { + push_mono_lang_item(self, LangItem::PanicNullPointerDereference); + } _ => { push_mono_lang_item(self, msg.panic_function()); } diff --git a/compiler/rustc_smir/src/rustc_smir/convert/mir.rs b/compiler/rustc_smir/src/rustc_smir/convert/mir.rs index aee98d7d410..150ec02b7db 100644 --- a/compiler/rustc_smir/src/rustc_smir/convert/mir.rs +++ b/compiler/rustc_smir/src/rustc_smir/convert/mir.rs @@ -496,6 +496,9 @@ impl<'tcx> Stable<'tcx> for mir::AssertMessage<'tcx> { found: found.stable(tables), } } + AssertKind::NullPointerDereference => { + stable_mir::mir::AssertMessage::NullPointerDereference + } } } } diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 6f1d3a74a81..37ae59bf207 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -1476,6 +1476,7 @@ symbols! { panic_location, panic_misaligned_pointer_dereference, panic_nounwind, + panic_null_pointer_dereference, panic_runtime, panic_str_2015, panic_unwind, diff --git a/compiler/stable_mir/src/mir/body.rs b/compiler/stable_mir/src/mir/body.rs index 8686169c15d..eec6cd8d49b 100644 --- a/compiler/stable_mir/src/mir/body.rs +++ b/compiler/stable_mir/src/mir/body.rs @@ -251,6 +251,7 @@ pub enum AssertMessage { ResumedAfterReturn(CoroutineKind), ResumedAfterPanic(CoroutineKind), MisalignedPointerDereference { required: Operand, found: Operand }, + NullPointerDereference, } impl AssertMessage { @@ -306,6 +307,7 @@ impl AssertMessage { AssertMessage::MisalignedPointerDereference { .. } => { Ok("misaligned pointer dereference") } + AssertMessage::NullPointerDereference => Ok("null pointer dereference occured"), } } } diff --git a/compiler/stable_mir/src/mir/pretty.rs b/compiler/stable_mir/src/mir/pretty.rs index 81981bce202..6638f93e555 100644 --- a/compiler/stable_mir/src/mir/pretty.rs +++ b/compiler/stable_mir/src/mir/pretty.rs @@ -298,6 +298,9 @@ fn pretty_assert_message<W: Write>(writer: &mut W, msg: &AssertMessage) -> io::R "\"misaligned pointer dereference: address must be a multiple of {{}} but is {{}}\",{pretty_required}, {pretty_found}" ) } + AssertMessage::NullPointerDereference => { + write!(writer, "\"null pointer dereference occured.\"") + } AssertMessage::ResumedAfterReturn(_) | AssertMessage::ResumedAfterPanic(_) => { write!(writer, "{}", msg.description().unwrap()) } diff --git a/compiler/stable_mir/src/mir/visit.rs b/compiler/stable_mir/src/mir/visit.rs index 79efb83cebd..d985a98fcbf 100644 --- a/compiler/stable_mir/src/mir/visit.rs +++ b/compiler/stable_mir/src/mir/visit.rs @@ -426,7 +426,10 @@ pub trait MirVisitor { | AssertMessage::RemainderByZero(op) => { self.visit_operand(op, location); } - AssertMessage::ResumedAfterReturn(_) | AssertMessage::ResumedAfterPanic(_) => { //nothing to visit + AssertMessage::ResumedAfterReturn(_) + | AssertMessage::ResumedAfterPanic(_) + | AssertMessage::NullPointerDereference => { + //nothing to visit } AssertMessage::MisalignedPointerDereference { required, found } => { self.visit_operand(required, location); |
