1 use smallvec::{smallvec, SmallVec};
2 use std::ffi::{CStr, CString, NulError};
3 
/// Similar to `std::ffi::CString`, but avoids heap allocating if the string is
/// small enough. Also guarantees its input is UTF-8 -- used for cases where we
/// need to pass a NUL-terminated string to SQLite, and we have a `&str`.
///
/// The bytes live in a `SmallVec<[u8; 16]>`, so strings of up to 15 bytes
/// (plus the NUL terminator) are stored inline; longer ones spill to the heap.
/// Invariant: the buffer always ends with exactly one NUL byte and contains no
/// other NULs. The derived comparisons operate on that NUL-terminated buffer.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct SmallCString(SmallVec<[u8; 16]>);
9 
10 impl SmallCString {
11     #[inline]
new(s: &str) -> Result<Self, NulError>12     pub fn new(s: &str) -> Result<Self, NulError> {
13         if s.as_bytes().contains(&0_u8) {
14             return Err(Self::fabricate_nul_error(s));
15         }
16         let mut buf = SmallVec::with_capacity(s.len() + 1);
17         buf.extend_from_slice(s.as_bytes());
18         buf.push(0);
19         let res = Self(buf);
20         res.debug_checks();
21         Ok(res)
22     }
23 
24     #[inline]
as_str(&self) -> &str25     pub fn as_str(&self) -> &str {
26         self.debug_checks();
27         // Constructor takes a &str so this is safe.
28         unsafe { std::str::from_utf8_unchecked(self.as_bytes_without_nul()) }
29     }
30 
31     /// Get the bytes not including the NUL terminator. E.g. the bytes which
32     /// make up our `str`:
33     /// - `SmallCString::new("foo").as_bytes_without_nul() == b"foo"`
34     /// - `SmallCString::new("foo").as_bytes_with_nul() == b"foo\0"`
35     #[inline]
as_bytes_without_nul(&self) -> &[u8]36     pub fn as_bytes_without_nul(&self) -> &[u8] {
37         self.debug_checks();
38         &self.0[..self.len()]
39     }
40 
41     /// Get the bytes behind this str *including* the NUL terminator. This
42     /// should never return an empty slice.
43     #[inline]
as_bytes_with_nul(&self) -> &[u8]44     pub fn as_bytes_with_nul(&self) -> &[u8] {
45         self.debug_checks();
46         &self.0
47     }
48 
49     #[inline]
50     #[cfg(debug_assertions)]
debug_checks(&self)51     fn debug_checks(&self) {
52         debug_assert_ne!(self.0.len(), 0);
53         debug_assert_eq!(self.0[self.0.len() - 1], 0);
54         let strbytes = &self.0[..(self.0.len() - 1)];
55         debug_assert!(!strbytes.contains(&0));
56         debug_assert!(std::str::from_utf8(strbytes).is_ok());
57     }
58 
59     #[inline]
60     #[cfg(not(debug_assertions))]
debug_checks(&self)61     fn debug_checks(&self) {}
62 
63     #[inline]
len(&self) -> usize64     pub fn len(&self) -> usize {
65         debug_assert_ne!(self.0.len(), 0);
66         self.0.len() - 1
67     }
68 
69     #[inline]
70     #[allow(unused)] // clippy wants this function.
is_empty(&self) -> bool71     pub fn is_empty(&self) -> bool {
72         self.len() == 0
73     }
74 
75     #[inline]
as_cstr(&self) -> &CStr76     pub fn as_cstr(&self) -> &CStr {
77         let bytes = self.as_bytes_with_nul();
78         debug_assert!(CStr::from_bytes_with_nul(bytes).is_ok());
79         unsafe { CStr::from_bytes_with_nul_unchecked(bytes) }
80     }
81 
82     #[cold]
fabricate_nul_error(b: &str) -> NulError83     fn fabricate_nul_error(b: &str) -> NulError {
84         CString::new(b).unwrap_err()
85     }
86 }
87 
88 impl Default for SmallCString {
89     #[inline]
default() -> Self90     fn default() -> Self {
91         Self(smallvec![0])
92     }
93 }
94 
95 impl std::fmt::Debug for SmallCString {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result96     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97         f.debug_tuple("SmallCString").field(&self.as_str()).finish()
98     }
99 }
100 
101 impl std::ops::Deref for SmallCString {
102     type Target = CStr;
103 
104     #[inline]
deref(&self) -> &CStr105     fn deref(&self) -> &CStr {
106         self.as_cstr()
107     }
108 }
109 
110 impl PartialEq<SmallCString> for str {
111     #[inline]
eq(&self, s: &SmallCString) -> bool112     fn eq(&self, s: &SmallCString) -> bool {
113         s.as_bytes_without_nul() == self.as_bytes()
114     }
115 }
116 
117 impl PartialEq<str> for SmallCString {
118     #[inline]
eq(&self, s: &str) -> bool119     fn eq(&self, s: &str) -> bool {
120         self.as_bytes_without_nul() == s.as_bytes()
121     }
122 }
123 
124 impl std::borrow::Borrow<str> for SmallCString {
125     #[inline]
borrow(&self) -> &str126     fn borrow(&self) -> &str {
127         self.as_str()
128     }
129 }
130 
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_small_cstring() {
        // We don't go through the normal machinery for default, so make sure
        // things work.
        assert_eq!(SmallCString::default().0, SmallCString::new("").unwrap().0);
        assert_eq!(SmallCString::new("foo").unwrap().len(), 3);
        assert_eq!(
            SmallCString::new("foo").unwrap().as_bytes_with_nul(),
            b"foo\0"
        );
        assert_eq!(
            SmallCString::new("foo").unwrap().as_bytes_without_nul(),
            b"foo",
        );

        // U+1F600 (grinning face) is four bytes in UTF-8. It is written as an
        // escape so the test does not depend on the source file's encoding.
        assert_eq!(SmallCString::new("\u{1F600}").unwrap().len(), 4);
        assert_eq!(
            SmallCString::new("\u{1F600}").unwrap().0.as_slice(),
            b"\xf0\x9f\x98\x80\0",
        );
        assert_eq!(
            SmallCString::new("\u{1F600}").unwrap().as_bytes_without_nul(),
            b"\xf0\x9f\x98\x80",
        );

        assert_eq!(SmallCString::new("").unwrap().len(), 0);
        assert!(SmallCString::new("").unwrap().is_empty());

        assert_eq!(SmallCString::new("").unwrap().0.as_slice(), b"\0");
        assert_eq!(SmallCString::new("").unwrap().as_bytes_without_nul(), b"");

        // str views, equality in both directions, and the CStr view.
        let foo = SmallCString::new("foo").unwrap();
        assert_eq!(foo.as_str(), "foo");
        assert!(foo == *"foo");
        assert!(*"foo" == foo);
        assert_eq!(foo.as_cstr().to_bytes(), b"foo");
        assert_eq!(foo.to_bytes_with_nul(), b"foo\0");

        // Ordering must agree with `str` because of the `Borrow<str>` impl.
        let a = SmallCString::new("a").unwrap();
        let ab = SmallCString::new("ab").unwrap();
        assert!(SmallCString::default() < a);
        assert!(a < ab);
        assert_eq!(a.cmp(&ab), "a".cmp("ab"));

        // Inputs containing NUL are rejected, reporting where the NUL is.
        assert_eq!(SmallCString::new("\0").unwrap_err().nul_position(), 0);
        assert_eq!(SmallCString::new("\0abc").unwrap_err().nul_position(), 0);
        assert_eq!(SmallCString::new("abc\0").unwrap_err().nul_position(), 3);
    }
}
171