about summary refs log tree commit diff
diff options
context:
space:
mode:
authorTyson Nottingham <tgnottingham@gmail.com>2020-12-15 14:18:02 -0800
committerTyson Nottingham <tgnottingham@gmail.com>2021-04-13 09:48:58 -0700
commit72aecbfd01e73da710ee35826f28f41d5d0cfebe (patch)
tree9159f11256658e28b87e14ecdd3db99a246f4152
parent5fd9372c1118b213b15a0e75bec7fab3b2e00260 (diff)
downloadrust-72aecbfd01e73da710ee35826f28f41d5d0cfebe.tar.gz
rust-72aecbfd01e73da710ee35826f28f41d5d0cfebe.zip
BufWriter: handle possibility of overflow
-rw-r--r--library/std/src/io/buffered/bufwriter.rs54
1 files changed, 39 insertions, 15 deletions
diff --git a/library/std/src/io/buffered/bufwriter.rs b/library/std/src/io/buffered/bufwriter.rs
index f0ff186d99b..a9fc450de31 100644
--- a/library/std/src/io/buffered/bufwriter.rs
+++ b/library/std/src/io/buffered/bufwriter.rs
@@ -190,7 +190,7 @@ impl<W: Write> BufWriter<W> {
     /// data. Writes as much as possible without exceeding capacity. Returns
     /// the number of bytes written.
     pub(super) fn write_to_buf(&mut self, buf: &[u8]) -> usize {
-        let available = self.buf.capacity() - self.buf.len();
+        let available = self.spare_capacity();
         let amt_to_buffer = available.min(buf.len());
 
         // SAFETY: `amt_to_buffer` is <= buffer's spare capacity by construction.
@@ -353,7 +353,7 @@ impl<W: Write> BufWriter<W> {
     // or their write patterns are somewhat pathological.
     #[inline(never)]
     fn write_cold(&mut self, buf: &[u8]) -> io::Result<usize> {
-        if self.buf.len() + buf.len() > self.buf.capacity() {
+        if buf.len() > self.spare_capacity() {
             self.flush_buf()?;
         }
 
@@ -371,7 +371,7 @@ impl<W: Write> BufWriter<W> {
 
             // SAFETY: We just called `self.flush_buf()`, so `self.buf.len()` is 0, and
             // we entered this else block because `buf.len() < self.buf.capacity()`.
-            // Therefore, `self.buf.len() + buf.len() <= self.buf.capacity()`.
+            // Therefore, `buf.len() <= self.buf.capacity() - self.buf.len()`.
             unsafe {
                 self.write_to_buffer_unchecked(buf);
             }
@@ -391,7 +391,8 @@ impl<W: Write> BufWriter<W> {
         // by calling `self.get_mut().write_all()` directly, which avoids
         // round trips through the buffer in the event of a series of partial
         // writes in some circumstances.
-        if self.buf.len() + buf.len() > self.buf.capacity() {
+
+        if buf.len() > self.spare_capacity() {
             self.flush_buf()?;
         }
 
@@ -409,7 +410,7 @@ impl<W: Write> BufWriter<W> {
 
             // SAFETY: We just called `self.flush_buf()`, so `self.buf.len()` is 0, and
             // we entered this else block because `buf.len() < self.buf.capacity()`.
-            // Therefore, `self.buf.len() + buf.len() <= self.buf.capacity()`.
+            // Therefore, `buf.len() <= self.buf.capacity() - self.buf.len()`.
             unsafe {
                 self.write_to_buffer_unchecked(buf);
             }
@@ -418,11 +419,11 @@ impl<W: Write> BufWriter<W> {
         }
     }
 
-    // SAFETY: Requires `self.buf.len() + buf.len() <= self.buf.capacity()`,
+    // SAFETY: Requires `buf.len() <= self.buf.capacity() - self.buf.len()`,
     // i.e., that input buffer length is less than or equal to spare capacity.
     #[inline(always)]
     unsafe fn write_to_buffer_unchecked(&mut self, buf: &[u8]) {
-        debug_assert!(self.buf.len() + buf.len() <= self.buf.capacity());
+        debug_assert!(buf.len() <= self.spare_capacity());
         let old_len = self.buf.len();
         let buf_len = buf.len();
         let src = buf.as_ptr();
@@ -430,6 +431,11 @@ impl<W: Write> BufWriter<W> {
         ptr::copy_nonoverlapping(src, dst, buf_len);
         self.buf.set_len(old_len + buf_len);
     }
+
+    #[inline]
+    fn spare_capacity(&self) -> usize {
+        self.buf.capacity() - self.buf.len()
+    }
 }
 
 #[unstable(feature = "bufwriter_into_raw_parts", issue = "80690")]
@@ -505,7 +511,7 @@ impl<W: Write> Write for BufWriter<W> {
     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
         // Use < instead of <= to avoid a needless trip through the buffer in some cases.
         // See `write_cold` for details.
-        if self.buf.len() + buf.len() < self.buf.capacity() {
+        if buf.len() < self.spare_capacity() {
             // SAFETY: safe by above conditional.
             unsafe {
                 self.write_to_buffer_unchecked(buf);
@@ -521,7 +527,7 @@ impl<W: Write> Write for BufWriter<W> {
     fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
         // Use < instead of <= to avoid a needless trip through the buffer in some cases.
         // See `write_all_cold` for details.
-        if self.buf.len() + buf.len() < self.buf.capacity() {
+        if buf.len() < self.spare_capacity() {
             // SAFETY: safe by above conditional.
             unsafe {
                 self.write_to_buffer_unchecked(buf);
@@ -537,16 +543,31 @@ impl<W: Write> Write for BufWriter<W> {
         // FIXME: Consider applying `#[inline]` / `#[inline(never)]` optimizations already applied
         // to `write` and `write_all`. The performance benefits can be significant. See #79930.
         if self.get_ref().is_write_vectored() {
-            let total_len = bufs.iter().map(|b| b.len()).sum::<usize>();
-            if self.buf.len() + total_len > self.buf.capacity() {
+            // We have to handle the possibility that the total length of the buffers overflows
+            // `usize` (even though this can only happen if multiple `IoSlice`s reference the
+            // same underlying buffer, as otherwise the buffers wouldn't fit in memory). If the
+            // computation overflows, then surely the input cannot fit in our buffer, so we forward
+            // to the inner writer's `write_vectored` method to let it handle it appropriately.
+            let saturated_total_len =
+                bufs.iter().fold(0usize, |acc, b| acc.saturating_add(b.len()));
+
+            if saturated_total_len > self.spare_capacity() {
+                // Flush if the total length of the input exceeds our buffer's spare capacity.
+                // If we would have overflowed, this condition also holds, and we need to flush.
                 self.flush_buf()?;
             }
-            if total_len >= self.buf.capacity() {
+
+            if saturated_total_len >= self.buf.capacity() {
+                // Forward to our inner writer if the total length of the input is greater than or
+                // equal to our buffer capacity. If we would have overflowed, this condition also
+                // holds, and we punt to the inner writer.
                 self.panicked = true;
                 let r = self.get_mut().write_vectored(bufs);
                 self.panicked = false;
                 r
             } else {
+                // `saturated_total_len < self.buf.capacity()` implies that we did not saturate.
+
                 // SAFETY: We checked whether or not the spare capacity was large enough above. If
                 // it was, then we're safe already. If it wasn't, we flushed, making sufficient
                 // room for any input <= the buffer size, which includes this input.
@@ -554,14 +575,14 @@ impl<W: Write> Write for BufWriter<W> {
                     bufs.iter().for_each(|b| self.write_to_buffer_unchecked(b));
                 };
 
-                Ok(total_len)
+                Ok(saturated_total_len)
             }
         } else {
             let mut iter = bufs.iter();
             let mut total_written = if let Some(buf) = iter.by_ref().find(|&buf| !buf.is_empty()) {
                 // This is the first non-empty slice to write, so if it does
                 // not fit in the buffer, we still get to flush and proceed.
-                if self.buf.len() + buf.len() > self.buf.capacity() {
+                if buf.len() > self.spare_capacity() {
                     self.flush_buf()?;
                 }
                 if buf.len() >= self.buf.capacity() {
@@ -586,12 +607,15 @@ impl<W: Write> Write for BufWriter<W> {
             };
             debug_assert!(total_written != 0);
             for buf in iter {
-                if self.buf.len() + buf.len() <= self.buf.capacity() {
+                if buf.len() <= self.spare_capacity() {
                     // SAFETY: safe by above conditional.
                     unsafe {
                         self.write_to_buffer_unchecked(buf);
                     }
 
+                    // This cannot overflow `usize`. If we are here, we've written all of the bytes
+                    // so far to our buffer, and we've ensured that we never exceed the buffer's
+                    // capacity. Therefore, `total_written` <= `self.buf.capacity()` <= `usize::MAX`.
                     total_written += buf.len();
                 } else {
                     break;