about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--library/core/src/str/mod.rs2
-rw-r--r--library/std/src/sys/windows/stdio.rs104
2 files changed, 91 insertions, 15 deletions
diff --git a/library/core/src/str/mod.rs b/library/core/src/str/mod.rs
index 947afbdc68d..ca4e2e6b7f3 100644
--- a/library/core/src/str/mod.rs
+++ b/library/core/src/str/mod.rs
@@ -69,7 +69,7 @@ pub use iter::SplitAsciiWhitespace;
 pub use iter::SplitInclusive;
 
 #[unstable(feature = "str_internals", issue = "none")]
-pub use validations::next_code_point;
+pub use validations::{next_code_point, utf8_char_width};
 
 use iter::MatchIndicesInternal;
 use iter::SplitInternal;
diff --git a/library/std/src/sys/windows/stdio.rs b/library/std/src/sys/windows/stdio.rs
index 6f2618c63b5..2719a530dfd 100644
--- a/library/std/src/sys/windows/stdio.rs
+++ b/library/std/src/sys/windows/stdio.rs
@@ -9,14 +9,25 @@ use crate::str;
 use crate::sys::c;
 use crate::sys::cvt;
 use crate::sys::handle::Handle;
+use core::str::utf8_char_width;
 
 // Don't cache handles but get them fresh for every read/write. This allows us to track changes to
 // the value over time (such as if a process calls `SetStdHandle` while it's running). See #40490.
 pub struct Stdin {
     surrogate: u16,
 }
-pub struct Stdout;
-pub struct Stderr;
+pub struct Stdout {
+    incomplete_utf8: IncompleteUtf8,
+}
+
+pub struct Stderr {
+    incomplete_utf8: IncompleteUtf8,
+}
+
+struct IncompleteUtf8 {
+    bytes: [u8; 4],
+    len: u8,
+}
 
 // Apparently Windows doesn't handle large reads on stdin or writes to stdout/stderr well (see
 // #13304 for details).
@@ -51,7 +62,15 @@ fn is_console(handle: c::HANDLE) -> bool {
     unsafe { c::GetConsoleMode(handle, &mut mode) != 0 }
 }
 
-fn write(handle_id: c::DWORD, data: &[u8]) -> io::Result<usize> {
+fn write(
+    handle_id: c::DWORD,
+    data: &[u8],
+    incomplete_utf8: &mut IncompleteUtf8,
+) -> io::Result<usize> {
+    if data.is_empty() {
+        return Ok(0);
+    }
+
     let handle = get_handle(handle_id)?;
     if !is_console(handle) {
         unsafe {
@@ -62,22 +81,73 @@ fn write(handle_id: c::DWORD, data: &[u8]) -> io::Result<usize> {
         }
     }
 
-    // As the console is meant for presenting text, we assume bytes of `data` come from a string
-    // and are encoded as UTF-8, which needs to be encoded as UTF-16.
+    if incomplete_utf8.len > 0 {
+        assert!(
+            incomplete_utf8.len < 4,
+            "Unexpected number of bytes for incomplete UTF-8 codepoint."
+        );
+        if data[0] >> 6 != 0b10 {
+            // not a continuation byte - reject
+            incomplete_utf8.len = 0;
+            return Err(io::Error::new_const(
+                io::ErrorKind::InvalidData,
+                &"Windows stdio in console mode does not support writing non-UTF-8 byte sequences",
+            ));
+        }
+        incomplete_utf8.bytes[incomplete_utf8.len as usize] = data[0];
+        incomplete_utf8.len += 1;
+        let char_width = utf8_char_width(incomplete_utf8.bytes[0]);
+        if (incomplete_utf8.len as usize) < char_width {
+            // more bytes needed
+            return Ok(1);
+        }
+        let s = str::from_utf8(&incomplete_utf8.bytes[0..incomplete_utf8.len as usize]);
+        incomplete_utf8.len = 0;
+        match s {
+            Ok(s) => {
+                assert_eq!(char_width, s.len());
+                let written = write_valid_utf8_to_console(handle, s)?;
+                assert_eq!(written, s.len()); // guaranteed by write_valid_utf8_to_console() for single codepoint writes
+                return Ok(1);
+            }
+            Err(_) => {
+                return Err(io::Error::new_const(
+                    io::ErrorKind::InvalidData,
+                    &"Windows stdio in console mode does not support writing non-UTF-8 byte sequences",
+                ));
+            }
+        }
+    }
+
+    // As the console is meant for presenting text, we assume bytes of `data` are encoded as UTF-8,
+    // which needs to be encoded as UTF-16.
     //
     // If the data is not valid UTF-8 we write out as many bytes as are valid.
-    // Only when there are no valid bytes (which will happen on the next call), return an error.
+    // If the first byte is invalid it is either first byte of a multi-byte sequence but the
+    // provided byte slice is too short or it is the first byte of an invalide multi-byte sequence.
     let len = cmp::min(data.len(), MAX_BUFFER_SIZE / 2);
     let utf8 = match str::from_utf8(&data[..len]) {
         Ok(s) => s,
         Err(ref e) if e.valid_up_to() == 0 => {
-            return Err(io::Error::new_const(
-                io::ErrorKind::InvalidData,
-                &"Windows stdio in console mode does not support writing non-UTF-8 byte sequences",
-            ));
+            let first_byte_char_width = utf8_char_width(data[0]);
+            if first_byte_char_width > 1 && data.len() < first_byte_char_width {
+                incomplete_utf8.bytes[0] = data[0];
+                incomplete_utf8.len = 1;
+                return Ok(1);
+            } else {
+                return Err(io::Error::new_const(
+                    io::ErrorKind::InvalidData,
+                    &"Windows stdio in console mode does not support writing non-UTF-8 byte sequences",
+                ));
+            }
         }
         Err(e) => str::from_utf8(&data[..e.valid_up_to()]).unwrap(),
     };
+
+    write_valid_utf8_to_console(handle, utf8)
+}
+
+fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result<usize> {
     let mut utf16 = [0u16; MAX_BUFFER_SIZE / 2];
     let mut len_utf16 = 0;
     for (chr, dest) in utf8.encode_utf16().zip(utf16.iter_mut()) {
@@ -259,15 +329,21 @@ fn utf16_to_utf8(utf16: &[u16], utf8: &mut [u8]) -> io::Result<usize> {
     Ok(written)
 }
 
+impl IncompleteUtf8 {
+    pub const fn new() -> IncompleteUtf8 {
+        IncompleteUtf8 { bytes: [0; 4], len: 0 }
+    }
+}
+
 impl Stdout {
     pub const fn new() -> Stdout {
-        Stdout
+        Stdout { incomplete_utf8: IncompleteUtf8::new() }
     }
 }
 
 impl io::Write for Stdout {
     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
-        write(c::STD_OUTPUT_HANDLE, buf)
+        write(c::STD_OUTPUT_HANDLE, buf, &mut self.incomplete_utf8)
     }
 
     fn flush(&mut self) -> io::Result<()> {
@@ -277,13 +353,13 @@ impl io::Write for Stdout {
 
 impl Stderr {
     pub const fn new() -> Stderr {
-        Stderr
+        Stderr { incomplete_utf8: IncompleteUtf8::new() }
     }
 }
 
 impl io::Write for Stderr {
     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
-        write(c::STD_ERROR_HANDLE, buf)
+        write(c::STD_ERROR_HANDLE, buf, &mut self.incomplete_utf8)
     }
 
     fn flush(&mut self) -> io::Result<()> {