about summary refs log tree commit diff
path: root/library/std/src
diff options
context:
space:
mode:
Diffstat (limited to 'library/std/src')
-rw-r--r--library/std/src/sys/unix/fs.rs3
-rw-r--r--library/std/src/sys/unix/kernel_copy.rs54
-rw-r--r--library/std/src/sys/unix/kernel_copy/tests.rs13
3 files changed, 54 insertions, 16 deletions
diff --git a/library/std/src/sys/unix/fs.rs b/library/std/src/sys/unix/fs.rs
index e2f0870ef0e..d1b0ad9e5f8 100644
--- a/library/std/src/sys/unix/fs.rs
+++ b/library/std/src/sys/unix/fs.rs
@@ -1211,7 +1211,8 @@ pub fn copy(from: &Path, to: &Path) -> io::Result<u64> {
     use super::kernel_copy::{copy_regular_files, CopyResult};
 
     match copy_regular_files(reader.as_raw_fd(), writer.as_raw_fd(), max_len) {
-        CopyResult::Ended(result) => result,
+        CopyResult::Ended(bytes) => Ok(bytes),
+        CopyResult::Error(e, _) => Err(e),
         CopyResult::Fallback(written) => match io::copy::generic_copy(&mut reader, &mut writer) {
             Ok(bytes) => Ok(bytes + written),
             Err(e) => Err(e),
diff --git a/library/std/src/sys/unix/kernel_copy.rs b/library/std/src/sys/unix/kernel_copy.rs
index f08c828ecff..5bfac803153 100644
--- a/library/std/src/sys/unix/kernel_copy.rs
+++ b/library/std/src/sys/unix/kernel_copy.rs
@@ -167,10 +167,11 @@ impl<R: CopyRead, W: CopyWrite> SpecCopy for Copier<'_, '_, R, W> {
 
             if input_meta.copy_file_range_candidate() && output_meta.copy_file_range_candidate() {
                 let result = copy_regular_files(readfd, writefd, max_write);
+                result.update_take(reader);
 
                 match result {
-                    CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written),
-                    CopyResult::Ended(err) => return err,
+                    CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
+                    CopyResult::Error(e, _) => return Err(e),
                     CopyResult::Fallback(bytes) => written += bytes,
                 }
             }
@@ -182,20 +183,22 @@ impl<R: CopyRead, W: CopyWrite> SpecCopy for Copier<'_, '_, R, W> {
             // fall back to the generic copy loop.
             if input_meta.potential_sendfile_source() {
                 let result = sendfile_splice(SpliceMode::Sendfile, readfd, writefd, max_write);
+                result.update_take(reader);
 
                 match result {
-                    CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written),
-                    CopyResult::Ended(err) => return err,
+                    CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
+                    CopyResult::Error(e, _) => return Err(e),
                     CopyResult::Fallback(bytes) => written += bytes,
                 }
             }
 
             if input_meta.maybe_fifo() || output_meta.maybe_fifo() {
                 let result = sendfile_splice(SpliceMode::Splice, readfd, writefd, max_write);
+                result.update_take(reader);
 
                 match result {
-                    CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written),
-                    CopyResult::Ended(err) => return err,
+                    CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
+                    CopyResult::Error(e, _) => return Err(e),
                     CopyResult::Fallback(0) => { /* use the fallback below */ }
                     CopyResult::Fallback(_) => {
                         unreachable!("splice should not return > 0 bytes on the fallback path")
@@ -225,6 +228,9 @@ trait CopyRead: Read {
         Ok(0)
     }
 
+    /// Updates `Take` wrappers to remove the number of bytes copied.
+    fn taken(&mut self, _bytes: u64) {}
+
     /// The minimum of the limit of all `Take<_>` wrappers, `u64::MAX` otherwise.
     /// This method does not account for data `BufReader` buffers and would underreport
     /// the limit of a `Take<BufReader<Take<_>>>` type. Thus its result is only valid
@@ -251,6 +257,10 @@ where
         (**self).drain_to(writer, limit)
     }
 
+    fn taken(&mut self, bytes: u64) {
+        (**self).taken(bytes);
+    }
+
     fn min_limit(&self) -> u64 {
         (**self).min_limit()
     }
@@ -407,6 +417,11 @@ impl<T: CopyRead> CopyRead for Take<T> {
         Ok(bytes_drained)
     }
 
+    fn taken(&mut self, bytes: u64) {
+        self.set_limit(self.limit() - bytes);
+        self.get_mut().taken(bytes);
+    }
+
     fn min_limit(&self) -> u64 {
         min(Take::limit(self), self.get_ref().min_limit())
     }
@@ -432,6 +447,10 @@ impl<T: CopyRead> CopyRead for BufReader<T> {
         Ok(bytes as u64 + inner_bytes)
     }
 
+    fn taken(&mut self, bytes: u64) {
+        self.get_mut().taken(bytes);
+    }
+
     fn min_limit(&self) -> u64 {
         self.get_ref().min_limit()
     }
@@ -457,10 +476,21 @@ fn fd_to_meta<T: AsRawFd>(fd: &T) -> FdMeta {
 }
 
 pub(super) enum CopyResult {
-    Ended(Result<u64>),
+    Ended(u64),
+    Error(Error, u64),
     Fallback(u64),
 }
 
+impl CopyResult {
+    fn update_take(&self, reader: &mut impl CopyRead) {
+        match *self {
+            CopyResult::Fallback(bytes)
+            | CopyResult::Ended(bytes)
+            | CopyResult::Error(_, bytes) => reader.taken(bytes),
+        }
+    }
+}
+
 /// linux-specific implementation that will attempt to use copy_file_range for copy offloading
 /// as the name says, it only works on regular files
 ///
@@ -527,7 +557,7 @@ pub(super) fn copy_regular_files(reader: RawFd, writer: RawFd, max_len: u64) ->
                 // - copying from an overlay filesystem in docker. reported to occur on fedora 32.
                 return CopyResult::Fallback(0);
             }
-            Ok(0) => return CopyResult::Ended(Ok(written)), // reached EOF
+            Ok(0) => return CopyResult::Ended(written), // reached EOF
             Ok(ret) => written += ret as u64,
             Err(err) => {
                 return match err.raw_os_error() {
@@ -545,12 +575,12 @@ pub(super) fn copy_regular_files(reader: RawFd, writer: RawFd, max_len: u64) ->
                         assert_eq!(written, 0);
                         CopyResult::Fallback(0)
                     }
-                    _ => CopyResult::Ended(Err(err)),
+                    _ => CopyResult::Error(err, written),
                 };
             }
         }
     }
-    CopyResult::Ended(Ok(written))
+    CopyResult::Ended(written)
 }
 
 #[derive(PartialEq)]
@@ -623,10 +653,10 @@ fn sendfile_splice(mode: SpliceMode, reader: RawFd, writer: RawFd, len: u64) ->
                     Some(os_err) if mode == SpliceMode::Sendfile && os_err == libc::EOVERFLOW => {
                         CopyResult::Fallback(written)
                     }
-                    _ => CopyResult::Ended(Err(err)),
+                    _ => CopyResult::Error(err, written),
                 };
             }
         }
     }
-    CopyResult::Ended(Ok(written))
+    CopyResult::Ended(written)
 }
diff --git a/library/std/src/sys/unix/kernel_copy/tests.rs b/library/std/src/sys/unix/kernel_copy/tests.rs
index 554ebd94022..3937a1ffa38 100644
--- a/library/std/src/sys/unix/kernel_copy/tests.rs
+++ b/library/std/src/sys/unix/kernel_copy/tests.rs
@@ -42,8 +42,15 @@ fn copy_specialization() -> Result<()> {
         assert_eq!(sink.buffer(), b"wxyz");
 
         let copied = crate::io::copy(&mut source, &mut sink)?;
-        assert_eq!(copied, 10);
-        assert_eq!(sink.buffer().len(), 0);
+        assert_eq!(copied, 10, "copy obeyed limit imposed by Take");
+        assert_eq!(sink.buffer().len(), 0, "sink buffer was flushed");
+        assert_eq!(source.limit(), 0, "outer Take was exhausted");
+        assert_eq!(source.get_ref().buffer().len(), 0, "source buffer should be drained");
+        assert_eq!(
+            source.get_ref().get_ref().limit(),
+            1,
+            "inner Take allowed reading beyond end of file, some bytes should be left"
+        );
 
         let mut sink = sink.into_inner()?;
         sink.seek(SeekFrom::Start(0))?;
@@ -210,7 +217,7 @@ fn bench_socket_pipe_socket_copy(b: &mut test::Bencher) {
     );
 
     match probe {
-        CopyResult::Ended(Ok(1)) => {
+        CopyResult::Ended(1) => {
             // splice works
         }
         _ => {