about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2013-07-04 10:25:59 -0700
committerbors <bors@rust-lang.org>2013-07-04 10:25:59 -0700
commit3d7c1ddf74f9ea678c9f22d609d67cfdd3a2ad3f (patch)
tree7513b65d352255d3238a329e0456368f3f6bc53d /src
parente07e9bbf3698c7e1491129f58cb02858dde96337 (diff)
parente9988c1e2d9ec0a1442f8591804c471ee11e1924 (diff)
downloadrust-3d7c1ddf74f9ea678c9f22d609d67cfdd3a2ad3f.tar.gz
rust-3d7c1ddf74f9ea678c9f22d609d67cfdd3a2ad3f.zip
auto merge of #7513 : sfackler/rust/master, r=msullivan
The Base64 package previously had extremely basic functionality. It only
suported the standard encoding character set, didn't support line breaks
and always padded output. This commit makes it significantly more
powerful.

The FromBase64 impl now supports all of the standard variants of Base64.
It ignores newlines,interprets '-' and '_' as well as '+' and '/' and
doesn't require padding. It isn't incredibly pedantic and will
successfully parse strings that are not strictly valid, but I don't
think the extra complexity required to make it accept _only_ valid
strings is worth it.

The ToBase64 trait has been modified such that to_base64 now takes a
base64::Config struct which contains the output format configuration.
This currently includes the selection of character set (standard or
url safe), whether or not to pad and an optional line break width. The
package comes with three static Config structs for the RFC 4648
standard, RFC 4648 url safe and RFC 2045 MIME formats.

The other option for configuring ToBase64 output would be to have one
method with the configuration flags passed and other traits with default
impls for the common cases, but I think that's a little messier.

FromBase64 still kills the task if you pass it invalid input, which isn't
particularly appropriate for a function into which you'll be passing
unvalidated input. Would it be worth changing its signature to return a
Result?
Diffstat (limited to 'src')
-rw-r--r--src/libextra/base64.rs354
1 files changed, 233 insertions, 121 deletions
diff --git a/src/libextra/base64.rs b/src/libextra/base64.rs
index a53a22ee831..3c1fc72e957 100644
--- a/src/libextra/base64.rs
+++ b/src/libextra/base64.rs
@@ -10,17 +10,37 @@
 
 //! Base64 binary-to-text encoding
 
+/// Available encoding character sets
+pub enum CharacterSet {
+    /// The standard character set (uses '+' and '/')
+    Standard,
+    /// The URL safe character set (uses '-' and '_')
+    UrlSafe
+}
 
-use std::vec;
-
-/// A trait for converting a value to base64 encoding.
-pub trait ToBase64 {
-    /// Converts the value of `self` to a base64 value, returning the owned
-    /// string
-    fn to_base64(&self) -> ~str;
+/// Contains configuration parameters for to_base64
+pub struct Config {
+    /// Character set to use
+    char_set: CharacterSet,
+    /// True to pad output with '=' characters
+    pad: bool,
+    /// Some(len) to wrap lines at len, None to disable line wrapping
+    line_length: Option<uint>
 }
 
-static CHARS: [char, ..64] = [
+/// Configuration for RFC 4648 standard base64 encoding
+pub static STANDARD: Config =
+    Config {char_set: Standard, pad: true, line_length: None};
+
+/// Configuration for RFC 4648 base64url encoding
+pub static URL_SAFE: Config =
+    Config {char_set: UrlSafe, pad: false, line_length: None};
+
+/// Configuration for RFC 2045 MIME base64 encoding
+pub static MIME: Config =
+    Config {char_set: Standard, pad: true, line_length: Some(76)};
+
+static STANDARD_CHARS: [char, ..64] = [
     'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
     'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
     'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
@@ -28,6 +48,21 @@ static CHARS: [char, ..64] = [
     '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'
 ];
 
+static URLSAFE_CHARS: [char, ..64] = [
+    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
+    'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
+    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
+    'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-', '_'
+];
+
+/// A trait for converting a value to base64 encoding.
+pub trait ToBase64 {
+    /// Converts the value of `self` to a base64 value following the specified
+    /// format configuration, returning the owned string.
+    fn to_base64(&self, config: Config) -> ~str;
+}
+
 impl<'self> ToBase64 for &'self [u8] {
     /**
      * Turn a vector of `u8` bytes into a base64 string.
@@ -36,55 +71,81 @@ impl<'self> ToBase64 for &'self [u8] {
      *
      * ~~~ {.rust}
      * extern mod extra;
-     * use extra::base64::ToBase64;
+     * use extra::base64::{ToBase64, standard};
      *
      * fn main () {
-     *     let str = [52,32].to_base64();
+     *     let str = [52,32].to_base64(standard);
      *     println(fmt!("%s", str));
      * }
      * ~~~
      */
-    fn to_base64(&self) -> ~str {
+    fn to_base64(&self, config: Config) -> ~str {
+        let chars = match config.char_set {
+            Standard => STANDARD_CHARS,
+            UrlSafe => URLSAFE_CHARS
+        };
+
         let mut s = ~"";
+        let mut i = 0;
+        let mut cur_length = 0;
         let len = self.len();
-        s.reserve(((len + 3u) / 4u) * 3u);
+        while i < len - (len % 3) {
+            match config.line_length {
+                Some(line_length) =>
+                    if cur_length >= line_length {
+                        s.push_str("\r\n");
+                        cur_length = 0;
+                    },
+                None => ()
+            }
 
-        let mut i = 0u;
-
-        while i < len - (len % 3u) {
-            let n = (self[i] as uint) << 16u |
-                    (self[i + 1u] as uint) << 8u |
-                    (self[i + 2u] as uint);
+            let n = (self[i] as u32) << 16 |
+                    (self[i + 1] as u32) << 8 |
+                    (self[i + 2] as u32);
 
             // This 24-bit number gets separated into four 6-bit numbers.
-            s.push_char(CHARS[(n >> 18u) & 63u]);
-            s.push_char(CHARS[(n >> 12u) & 63u]);
-            s.push_char(CHARS[(n >> 6u) & 63u]);
-            s.push_char(CHARS[n & 63u]);
+            s.push_char(chars[(n >> 18) & 63]);
+            s.push_char(chars[(n >> 12) & 63]);
+            s.push_char(chars[(n >> 6 ) & 63]);
+            s.push_char(chars[n & 63]);
+
+            cur_length += 4;
+            i += 3;
+        }
 
-            i += 3u;
+        if len % 3 != 0 {
+            match config.line_length {
+                Some(line_length) =>
+                    if cur_length >= line_length {
+                        s.push_str("\r\n");
+                    },
+                None => ()
+            }
         }
 
         // Heh, would be cool if we knew this was exhaustive
         // (the dream of bounded integer types)
         match len % 3 {
-          0 => (),
-          1 => {
-            let n = (self[i] as uint) << 16u;
-            s.push_char(CHARS[(n >> 18u) & 63u]);
-            s.push_char(CHARS[(n >> 12u) & 63u]);
-            s.push_char('=');
-            s.push_char('=');
-          }
-          2 => {
-            let n = (self[i] as uint) << 16u |
-                (self[i + 1u] as uint) << 8u;
-            s.push_char(CHARS[(n >> 18u) & 63u]);
-            s.push_char(CHARS[(n >> 12u) & 63u]);
-            s.push_char(CHARS[(n >> 6u) & 63u]);
-            s.push_char('=');
-          }
-          _ => fail!("Algebra is broken, please alert the math police")
+            0 => (),
+            1 => {
+                let n = (self[i] as u32) << 16;
+                s.push_char(chars[(n >> 18) & 63]);
+                s.push_char(chars[(n >> 12) & 63]);
+                if config.pad {
+                    s.push_str("==");
+                }
+            }
+            2 => {
+                let n = (self[i] as u32) << 16 |
+                    (self[i + 1u] as u32) << 8;
+                s.push_char(chars[(n >> 18) & 63]);
+                s.push_char(chars[(n >> 12) & 63]);
+                s.push_char(chars[(n >> 6 ) & 63]);
+                if config.pad {
+                    s.push_char('=');
+                }
+            }
+            _ => fail!("Algebra is broken, please alert the math police")
         }
         s
     }
@@ -99,23 +160,25 @@ impl<'self> ToBase64 for &'self str {
      *
      * ~~~ {.rust}
      * extern mod extra;
-     * use extra::base64::ToBase64;
+     * use extra::base64::{ToBase64, standard};
      *
      * fn main () {
-     *     let str = "Hello, World".to_base64();
+     *     let str = "Hello, World".to_base64(standard);
      *     println(fmt!("%s",str));
      * }
      * ~~~
      *
      */
-    fn to_base64(&self) -> ~str {
-        self.as_bytes().to_base64()
+    fn to_base64(&self, config: Config) -> ~str {
+        self.as_bytes().to_base64(config)
     }
 }
 
-#[allow(missing_doc)]
+/// A trait for converting from base64 encoded values.
 pub trait FromBase64 {
-    fn from_base64(&self) -> ~[u8];
+    /// Converts the value of `self`, interpreted as base64 encoded data, into
+    /// an owned vector of bytes, returning the vector.
+    fn from_base64(&self) -> Result<~[u8], ~str>;
 }
 
 impl<'self> FromBase64 for &'self [u8] {
@@ -127,69 +190,64 @@ impl<'self> FromBase64 for &'self [u8] {
      *
      * ~~~ {.rust}
      * extern mod extra;
-     * use extra::base64::ToBase64;
-     * use extra::base64::FromBase64;
+     * use extra::base64::{ToBase64, FromBase64, standard};
      *
      * fn main () {
-     *     let str = [52,32].to_base64();
+     *     let str = [52,32].to_base64(standard);
      *     println(fmt!("%s", str));
      *     let bytes = str.from_base64();
      *     println(fmt!("%?",bytes));
      * }
      * ~~~
      */
-    fn from_base64(&self) -> ~[u8] {
-        if self.len() % 4u != 0u { fail!("invalid base64 length"); }
+    fn from_base64(&self) -> Result<~[u8], ~str> {
+        let mut r = ~[];
+        let mut buf: u32 = 0;
+        let mut modulus = 0;
 
-        let len = self.len();
-        let mut padding = 0u;
+        let mut it = self.iter();
+        for it.advance |&byte| {
+            let ch = byte as char;
+            let val = byte as u32;
 
-        if len != 0u {
-            if self[len - 1u] == '=' as u8 { padding += 1u; }
-            if self[len - 2u] == '=' as u8 { padding += 1u; }
-        }
+            match ch {
+                'A'..'Z'  => buf |= val - 0x41,
+                'a'..'z'  => buf |= val - 0x47,
+                '0'..'9'  => buf |= val + 0x04,
+                '+'|'-'   => buf |= 0x3E,
+                '/'|'_'   => buf |= 0x3F,
+                '\r'|'\n' => loop,
+                '='       => break,
+                _         => return Err(~"Invalid Base64 character")
+            }
 
-        let mut r = vec::with_capacity((len / 4u) * 3u - padding);
-
-        let mut i = 0u;
-        while i < len {
-            let mut n = 0u;
-
-            for 4u.times {
-                let ch = self[i] as char;
-                n <<= 6u;
-
-                match ch {
-                    'A'..'Z' => n |= (ch as uint) - 0x41,
-                    'a'..'z' => n |= (ch as uint) - 0x47,
-                    '0'..'9' => n |= (ch as uint) + 0x04,
-                    '+'      => n |= 0x3E,
-                    '/'      => n |= 0x3F,
-                    '='      => {
-                        match len - i {
-                            1u => {
-                                r.push(((n >> 16u) & 0xFFu) as u8);
-                                r.push(((n >> 8u ) & 0xFFu) as u8);
-                                return copy r;
-                            }
-                            2u => {
-                                r.push(((n >> 10u) & 0xFFu) as u8);
-                                return copy r;
-                            }
-                            _ => fail!("invalid base64 padding")
-                        }
-                    }
-                    _ => fail!("invalid base64 character")
-                }
+            buf <<= 6;
+            modulus += 1;
+            if modulus == 4 {
+                modulus = 0;
+                r.push((buf >> 22) as u8);
+                r.push((buf >> 14) as u8);
+                r.push((buf >> 6 ) as u8);
+            }
+        }
 
-                i += 1u;
-            };
+        if !it.all(|&byte| {byte as char == '='}) {
+            return Err(~"Invalid Base64 character");
+        }
 
-            r.push(((n >> 16u) & 0xFFu) as u8);
-            r.push(((n >> 8u ) & 0xFFu) as u8);
-            r.push(((n       ) & 0xFFu) as u8);
+        match modulus {
+            2 => {
+                r.push((buf >> 10) as u8);
+            }
+            3 => {
+                r.push((buf >> 16) as u8);
+                r.push((buf >> 8 ) as u8);
+            }
+            0 => (),
+            _ => return Err(~"Invalid Base64 length")
         }
-        r
+
+        Ok(r)
     }
 }
 
@@ -199,7 +257,8 @@ impl<'self> FromBase64 for &'self str {
      * to the byte values it encodes.
      *
      * You can use the `from_bytes` function in `std::str`
-     * to turn a `[u8]` into a string with characters corresponding to those values.
+     * to turn a `[u8]` into a string with characters corresponding to those
+     * values.
      *
      * # Example
      *
@@ -207,12 +266,11 @@ impl<'self> FromBase64 for &'self str {
      *
      * ~~~ {.rust}
      * extern mod extra;
-     * use extra::base64::ToBase64;
-     * use extra::base64::FromBase64;
+     * use extra::base64::{ToBase64, FromBase64, standard};
      * use std::str;
      *
      * fn main () {
-     *     let hello_str = "Hello, World".to_base64();
+     *     let hello_str = "Hello, World".to_base64(standard);
      *     println(fmt!("%s",hello_str));
      *     let bytes = hello_str.from_base64();
      *     println(fmt!("%?",bytes));
@@ -221,32 +279,86 @@ impl<'self> FromBase64 for &'self str {
      * }
      * ~~~
      */
-    fn from_base64(&self) -> ~[u8] {
+    fn from_base64(&self) -> Result<~[u8], ~str> {
         self.as_bytes().from_base64()
     }
 }
 
-#[cfg(test)]
-mod tests {
-    #[test]
-    fn test_to_base64() {
-        assert_eq!("".to_base64(), ~"");
-        assert_eq!("f".to_base64(), ~"Zg==");
-        assert_eq!("fo".to_base64(), ~"Zm8=");
-        assert_eq!("foo".to_base64(), ~"Zm9v");
-        assert_eq!("foob".to_base64(), ~"Zm9vYg==");
-        assert_eq!("fooba".to_base64(), ~"Zm9vYmE=");
-        assert_eq!("foobar".to_base64(), ~"Zm9vYmFy");
-    }
+#[test]
+fn test_to_base64_basic() {
+    assert_eq!("".to_base64(STANDARD), ~"");
+    assert_eq!("f".to_base64(STANDARD), ~"Zg==");
+    assert_eq!("fo".to_base64(STANDARD), ~"Zm8=");
+    assert_eq!("foo".to_base64(STANDARD), ~"Zm9v");
+    assert_eq!("foob".to_base64(STANDARD), ~"Zm9vYg==");
+    assert_eq!("fooba".to_base64(STANDARD), ~"Zm9vYmE=");
+    assert_eq!("foobar".to_base64(STANDARD), ~"Zm9vYmFy");
+}
+
+#[test]
+fn test_to_base64_line_break() {
+    assert!(![0u8, 1000].to_base64(Config {line_length: None, ..STANDARD})
+        .contains("\r\n"));
+    assert_eq!("foobar".to_base64(Config {line_length: Some(4), ..STANDARD}),
+        ~"Zm9v\r\nYmFy");
+}
+
+#[test]
+fn test_to_base64_padding() {
+    assert_eq!("f".to_base64(Config {pad: false, ..STANDARD}), ~"Zg");
+    assert_eq!("fo".to_base64(Config {pad: false, ..STANDARD}), ~"Zm8");
+}
+
+#[test]
+fn test_to_base64_url_safe() {
+    assert_eq!([251, 255].to_base64(URL_SAFE), ~"-_8");
+    assert_eq!([251, 255].to_base64(STANDARD), ~"+/8=");
+}
+
+#[test]
+fn test_from_base64_basic() {
+    assert_eq!("".from_base64().get(), "".as_bytes().to_owned());
+    assert_eq!("Zg==".from_base64().get(), "f".as_bytes().to_owned());
+    assert_eq!("Zm8=".from_base64().get(), "fo".as_bytes().to_owned());
+    assert_eq!("Zm9v".from_base64().get(), "foo".as_bytes().to_owned());
+    assert_eq!("Zm9vYg==".from_base64().get(), "foob".as_bytes().to_owned());
+    assert_eq!("Zm9vYmE=".from_base64().get(), "fooba".as_bytes().to_owned());
+    assert_eq!("Zm9vYmFy".from_base64().get(), "foobar".as_bytes().to_owned());
+}
+
+#[test]
+fn test_from_base64_newlines() {
+    assert_eq!("Zm9v\r\nYmFy".from_base64().get(),
+        "foobar".as_bytes().to_owned());
+}
+
+#[test]
+fn test_from_base64_urlsafe() {
+    assert_eq!("-_8".from_base64().get(), "+/8=".from_base64().get());
+}
+
+#[test]
+fn test_from_base64_invalid_char() {
+    assert!("Zm$=".from_base64().is_err())
+    assert!("Zg==$".from_base64().is_err());
+}
+
+#[test]
+fn test_from_base64_invalid_padding() {
+    assert!("Z===".from_base64().is_err());
+}
+
+#[test]
+fn test_base64_random() {
+    use std::rand::{task_rng, random, RngUtil};
+    use std::vec;
 
-    #[test]
-    fn test_from_base64() {
-        assert_eq!("".from_base64(), "".as_bytes().to_owned());
-        assert_eq!("Zg==".from_base64(), "f".as_bytes().to_owned());
-        assert_eq!("Zm8=".from_base64(), "fo".as_bytes().to_owned());
-        assert_eq!("Zm9v".from_base64(), "foo".as_bytes().to_owned());
-        assert_eq!("Zm9vYg==".from_base64(), "foob".as_bytes().to_owned());
-        assert_eq!("Zm9vYmE=".from_base64(), "fooba".as_bytes().to_owned());
-        assert_eq!("Zm9vYmFy".from_base64(), "foobar".as_bytes().to_owned());
+    for 1000.times {
+        let v: ~[u8] = do vec::build |push| {
+            for task_rng().gen_uint_range(1, 100).times {
+                push(random());
+            }
+        };
+        assert_eq!(v.to_base64(STANDARD).from_base64().get(), v);
     }
 }