about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src/check_null.rs
blob: 0b6c0ceaac19962c562dd7be38f786e66ab57a23 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
use rustc_index::IndexVec;
use rustc_middle::mir::interpret::Scalar;
use rustc_middle::mir::*;
use rustc_middle::ty::{Ty, TyCtxt};
use rustc_session::Session;

use crate::check_pointers::{BorrowCheckMode, PointerCheck, check_pointers};

/// MIR pass that instruments pointer uses with a null-pointer check.
///
/// The pass is only enabled when the session has UB checks turned on, and it
/// delegates the traversal to `check_pointers`, supplying `insert_null_check`
/// as the callback that builds the check for each pointer.
pub(super) struct CheckNull;

impl<'tcx> crate::MirPass<'tcx> for CheckNull {
    /// The pass is active exactly when the session reports UB checks as enabled.
    fn is_enabled(&self, sess: &Session) -> bool {
        sess.ub_checks()
    }

    /// This pass always reports itself as required.
    fn is_required(&self) -> bool {
        true
    }

    /// Runs the shared pointer-check driver over `body`, using
    /// `insert_null_check` to build the condition for every pointer it reports.
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        // Borrows are included in the set of pointer uses that get checked.
        let borrow_mode = BorrowCheckMode::IncludeBorrows;
        // NOTE(review): the empty slice is presumably the list of pointee types
        // exempt from checking — confirm against `check_pointers`.
        check_pointers(tcx, body, &[], insert_null_check, borrow_mode);
    }
}

fn insert_null_check<'tcx>(
    tcx: TyCtxt<'tcx>,
    pointer: Place<'tcx>,
    pointee_ty: Ty<'tcx>,
    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
    stmts: &mut Vec<Statement<'tcx>>,
    source_info: SourceInfo,
) -> PointerCheck<'tcx> {
    // Cast the pointer to a *const ().
    let const_raw_ptr = Ty::new_imm_ptr(tcx, tcx.types.unit);
    let rvalue = Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(pointer), const_raw_ptr);
    let thin_ptr = local_decls.push(LocalDecl::with_source_info(const_raw_ptr, source_info)).into();
    stmts
        .push(Statement { source_info, kind: StatementKind::Assign(Box::new((thin_ptr, rvalue))) });

    // Transmute the pointer to a usize (equivalent to `ptr.addr()`).
    let rvalue = Rvalue::Cast(CastKind::Transmute, Operand::Copy(thin_ptr), tcx.types.usize);
    let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
    stmts.push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) });

    // Get size of the pointee (zero-sized reads and writes are allowed).
    let rvalue = Rvalue::NullaryOp(NullOp::SizeOf, pointee_ty);
    let sizeof_pointee =
        local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
    stmts.push(Statement {
        source_info,
        kind: StatementKind::Assign(Box::new((sizeof_pointee, rvalue))),
    });

    // Check that the pointee is not a ZST.
    let zero = Operand::Constant(Box::new(ConstOperand {
        span: source_info.span,
        user_ty: None,
        const_: Const::Val(ConstValue::Scalar(Scalar::from_target_usize(0, &tcx)), tcx.types.usize),
    }));
    let is_pointee_no_zst =
        local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
    stmts.push(Statement {
        source_info,
        kind: StatementKind::Assign(Box::new((
            is_pointee_no_zst,
            Rvalue::BinaryOp(BinOp::Ne, Box::new((Operand::Copy(sizeof_pointee), zero.clone()))),
        ))),
    });

    // Check whether the pointer is null.
    let is_null = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
    stmts.push(Statement {
        source_info,
        kind: StatementKind::Assign(Box::new((
            is_null,
            Rvalue::BinaryOp(BinOp::Eq, Box::new((Operand::Copy(addr), zero))),
        ))),
    });

    // We want to throw an exception if the pointer is null and doesn't point to a ZST.
    let should_throw_exception =
        local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
    stmts.push(Statement {
        source_info,
        kind: StatementKind::Assign(Box::new((
            should_throw_exception,
            Rvalue::BinaryOp(
                BinOp::BitAnd,
                Box::new((Operand::Copy(is_null), Operand::Copy(is_pointee_no_zst))),
            ),
        ))),
    });

    // The final condition whether this pointer usage is ok or not.
    let is_ok = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
    stmts.push(Statement {
        source_info,
        kind: StatementKind::Assign(Box::new((
            is_ok,
            Rvalue::UnaryOp(UnOp::Not, Operand::Copy(should_throw_exception)),
        ))),
    });

    // Emit a PointerCheck that asserts on the condition and otherwise triggers
    // a AssertKind::NullPointerDereference.
    PointerCheck {
        cond: Operand::Copy(is_ok),
        assert_kind: Box::new(AssertKind::NullPointerDereference),
    }
}