about summary refs log tree commit diff
path: root/library/std/src
diff options
context:
space:
mode:
authorThe8472 <git@infinite-source.de>2020-12-02 23:35:40 +0100
committerThe8472 <git@infinite-source.de>2020-12-03 00:02:01 +0100
commita9b1381b8dd0704966a5c2ab7e4ca0ff96d05a18 (patch)
tree10f207747b37b100d9fa8ce6d4bac0c69bd363c1 /library/std/src
parent9b390e73db26f17415dfeab2e49f13f04ad2474b (diff)
downloadrust-a9b1381b8dd0704966a5c2ab7e4ca0ff96d05a18.tar.gz
rust-a9b1381b8dd0704966a5c2ab7e4ca0ff96d05a18.zip
fix copy specialization not updating Take wrappers
Diffstat (limited to 'library/std/src')
-rw-r--r--library/std/src/sys/unix/fs.rs3
-rw-r--r--library/std/src/sys/unix/kernel_copy.rs54
-rw-r--r--library/std/src/sys/unix/kernel_copy/tests.rs2
3 files changed, 45 insertions, 14 deletions
diff --git a/library/std/src/sys/unix/fs.rs b/library/std/src/sys/unix/fs.rs
index e2f0870ef0e..d1b0ad9e5f8 100644
--- a/library/std/src/sys/unix/fs.rs
+++ b/library/std/src/sys/unix/fs.rs
@@ -1211,7 +1211,8 @@ pub fn copy(from: &Path, to: &Path) -> io::Result<u64> {
     use super::kernel_copy::{copy_regular_files, CopyResult};
 
     match copy_regular_files(reader.as_raw_fd(), writer.as_raw_fd(), max_len) {
-        CopyResult::Ended(result) => result,
+        CopyResult::Ended(bytes) => Ok(bytes),
+        CopyResult::Error(e, _) => Err(e),
         CopyResult::Fallback(written) => match io::copy::generic_copy(&mut reader, &mut writer) {
             Ok(bytes) => Ok(bytes + written),
             Err(e) => Err(e),
diff --git a/library/std/src/sys/unix/kernel_copy.rs b/library/std/src/sys/unix/kernel_copy.rs
index f08c828ecff..5bfac803153 100644
--- a/library/std/src/sys/unix/kernel_copy.rs
+++ b/library/std/src/sys/unix/kernel_copy.rs
@@ -167,10 +167,11 @@ impl<R: CopyRead, W: CopyWrite> SpecCopy for Copier<'_, '_, R, W> {
 
             if input_meta.copy_file_range_candidate() && output_meta.copy_file_range_candidate() {
                 let result = copy_regular_files(readfd, writefd, max_write);
+                result.update_take(reader);
 
                 match result {
-                    CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written),
-                    CopyResult::Ended(err) => return err,
+                    CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
+                    CopyResult::Error(e, _) => return Err(e),
                     CopyResult::Fallback(bytes) => written += bytes,
                 }
             }
@@ -182,20 +183,22 @@ impl<R: CopyRead, W: CopyWrite> SpecCopy for Copier<'_, '_, R, W> {
             // fall back to the generic copy loop.
             if input_meta.potential_sendfile_source() {
                 let result = sendfile_splice(SpliceMode::Sendfile, readfd, writefd, max_write);
+                result.update_take(reader);
 
                 match result {
-                    CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written),
-                    CopyResult::Ended(err) => return err,
+                    CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
+                    CopyResult::Error(e, _) => return Err(e),
                     CopyResult::Fallback(bytes) => written += bytes,
                 }
             }
 
             if input_meta.maybe_fifo() || output_meta.maybe_fifo() {
                 let result = sendfile_splice(SpliceMode::Splice, readfd, writefd, max_write);
+                result.update_take(reader);
 
                 match result {
-                    CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written),
-                    CopyResult::Ended(err) => return err,
+                    CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
+                    CopyResult::Error(e, _) => return Err(e),
                     CopyResult::Fallback(0) => { /* use the fallback below */ }
                     CopyResult::Fallback(_) => {
                         unreachable!("splice should not return > 0 bytes on the fallback path")
@@ -225,6 +228,9 @@ trait CopyRead: Read {
         Ok(0)
     }
 
+    /// Updates `Take` wrappers to remove the number of bytes copied.
+    fn taken(&mut self, _bytes: u64) {}
+
     /// The minimum of the limit of all `Take<_>` wrappers, `u64::MAX` otherwise.
     /// This method does not account for data `BufReader` buffers and would underreport
     /// the limit of a `Take<BufReader<Take<_>>>` type. Thus its result is only valid
@@ -251,6 +257,10 @@ where
         (**self).drain_to(writer, limit)
     }
 
+    fn taken(&mut self, bytes: u64) {
+        (**self).taken(bytes);
+    }
+
     fn min_limit(&self) -> u64 {
         (**self).min_limit()
     }
@@ -407,6 +417,11 @@ impl<T: CopyRead> CopyRead for Take<T> {
         Ok(bytes_drained)
     }
 
+    fn taken(&mut self, bytes: u64) {
+        self.set_limit(self.limit() - bytes);
+        self.get_mut().taken(bytes);
+    }
+
     fn min_limit(&self) -> u64 {
         min(Take::limit(self), self.get_ref().min_limit())
     }
@@ -432,6 +447,10 @@ impl<T: CopyRead> CopyRead for BufReader<T> {
         Ok(bytes as u64 + inner_bytes)
     }
 
+    fn taken(&mut self, bytes: u64) {
+        self.get_mut().taken(bytes);
+    }
+
     fn min_limit(&self) -> u64 {
         self.get_ref().min_limit()
     }
@@ -457,10 +476,21 @@ fn fd_to_meta<T: AsRawFd>(fd: &T) -> FdMeta {
 }
 
 pub(super) enum CopyResult {
-    Ended(Result<u64>),
+    Ended(u64),
+    Error(Error, u64),
     Fallback(u64),
 }
 
+impl CopyResult {
+    fn update_take(&self, reader: &mut impl CopyRead) {
+        match *self {
+            CopyResult::Fallback(bytes)
+            | CopyResult::Ended(bytes)
+            | CopyResult::Error(_, bytes) => reader.taken(bytes),
+        }
+    }
+}
+
 /// linux-specific implementation that will attempt to use copy_file_range for copy offloading
 /// as the name says, it only works on regular files
 ///
@@ -527,7 +557,7 @@ pub(super) fn copy_regular_files(reader: RawFd, writer: RawFd, max_len: u64) ->
                 // - copying from an overlay filesystem in docker. reported to occur on fedora 32.
                 return CopyResult::Fallback(0);
             }
-            Ok(0) => return CopyResult::Ended(Ok(written)), // reached EOF
+            Ok(0) => return CopyResult::Ended(written), // reached EOF
             Ok(ret) => written += ret as u64,
             Err(err) => {
                 return match err.raw_os_error() {
@@ -545,12 +575,12 @@ pub(super) fn copy_regular_files(reader: RawFd, writer: RawFd, max_len: u64) ->
                         assert_eq!(written, 0);
                         CopyResult::Fallback(0)
                     }
-                    _ => CopyResult::Ended(Err(err)),
+                    _ => CopyResult::Error(err, written),
                 };
             }
         }
     }
-    CopyResult::Ended(Ok(written))
+    CopyResult::Ended(written)
 }
 
 #[derive(PartialEq)]
@@ -623,10 +653,10 @@ fn sendfile_splice(mode: SpliceMode, reader: RawFd, writer: RawFd, len: u64) ->
                     Some(os_err) if mode == SpliceMode::Sendfile && os_err == libc::EOVERFLOW => {
                         CopyResult::Fallback(written)
                     }
-                    _ => CopyResult::Ended(Err(err)),
+                    _ => CopyResult::Error(err, written),
                 };
             }
         }
     }
-    CopyResult::Ended(Ok(written))
+    CopyResult::Ended(written)
 }
diff --git a/library/std/src/sys/unix/kernel_copy/tests.rs b/library/std/src/sys/unix/kernel_copy/tests.rs
index 79e99382591..3937a1ffa38 100644
--- a/library/std/src/sys/unix/kernel_copy/tests.rs
+++ b/library/std/src/sys/unix/kernel_copy/tests.rs
@@ -217,7 +217,7 @@ fn bench_socket_pipe_socket_copy(b: &mut test::Bencher) {
     );
 
     match probe {
-        CopyResult::Ended(Ok(1)) => {
+        CopyResult::Ended(1) => {
             // splice works
         }
         _ => {