1 //! Add, remove, or modify a collation
2 use std::cmp::Ordering;
3 use std::os::raw::{c_char, c_int, c_void};
4 use std::panic::{catch_unwind, UnwindSafe};
5 use std::ptr;
6 use std::slice;
7 
8 use crate::ffi;
9 use crate::{str_to_cstring, Connection, InnerConnection, Result};
10 
11 // FIXME copy/paste from function.rs
/// `xDestroy` trampoline handed to SQLite: frees a value previously leaked
/// with `Box::<T>::into_raw`.
///
/// # Safety
///
/// `p` must have been produced by `Box::<T>::into_raw` for this exact `T` and
/// must not have been freed already; it is invalid after this call.
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
    // Reclaim ownership so the `Box` destructor drops the value and releases
    // its heap allocation.
    let owned: Box<T> = Box::from_raw(p.cast());
    drop(owned);
}
15 
16 impl Connection {
17     /// Add or modify a collation.
18     #[inline]
create_collation<C>(&self, collation_name: &str, x_compare: C) -> Result<()> where C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,19     pub fn create_collation<C>(&self, collation_name: &str, x_compare: C) -> Result<()>
20     where
21         C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,
22     {
23         self.db
24             .borrow_mut()
25             .create_collation(collation_name, x_compare)
26     }
27 
28     /// Collation needed callback
29     #[inline]
collation_needed( &self, x_coll_needed: fn(&Connection, &str) -> Result<()>, ) -> Result<()>30     pub fn collation_needed(
31         &self,
32         x_coll_needed: fn(&Connection, &str) -> Result<()>,
33     ) -> Result<()> {
34         self.db.borrow_mut().collation_needed(x_coll_needed)
35     }
36 
37     /// Remove collation.
38     #[inline]
remove_collation(&self, collation_name: &str) -> Result<()>39     pub fn remove_collation(&self, collation_name: &str) -> Result<()> {
40         self.db.borrow_mut().remove_collation(collation_name)
41     }
42 }
43 
44 impl InnerConnection {
create_collation<C>(&mut self, collation_name: &str, x_compare: C) -> Result<()> where C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,45     fn create_collation<C>(&mut self, collation_name: &str, x_compare: C) -> Result<()>
46     where
47         C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,
48     {
49         unsafe extern "C" fn call_boxed_closure<C>(
50             arg1: *mut c_void,
51             arg2: c_int,
52             arg3: *const c_void,
53             arg4: c_int,
54             arg5: *const c_void,
55         ) -> c_int
56         where
57             C: Fn(&str, &str) -> Ordering,
58         {
59             let r = catch_unwind(|| {
60                 let boxed_f: *mut C = arg1.cast::<C>();
61                 assert!(!boxed_f.is_null(), "Internal error - null function pointer");
62                 let s1 = {
63                     let c_slice = slice::from_raw_parts(arg3.cast::<u8>(), arg2 as usize);
64                     String::from_utf8_lossy(c_slice)
65                 };
66                 let s2 = {
67                     let c_slice = slice::from_raw_parts(arg5.cast::<u8>(), arg4 as usize);
68                     String::from_utf8_lossy(c_slice)
69                 };
70                 (*boxed_f)(s1.as_ref(), s2.as_ref())
71             });
72             let t = match r {
73                 Err(_) => {
74                     return -1; // FIXME How ?
75                 }
76                 Ok(r) => r,
77             };
78 
79             match t {
80                 Ordering::Less => -1,
81                 Ordering::Equal => 0,
82                 Ordering::Greater => 1,
83             }
84         }
85 
86         let boxed_f: *mut C = Box::into_raw(Box::new(x_compare));
87         let c_name = str_to_cstring(collation_name)?;
88         let flags = ffi::SQLITE_UTF8;
89         let r = unsafe {
90             ffi::sqlite3_create_collation_v2(
91                 self.db(),
92                 c_name.as_ptr(),
93                 flags,
94                 boxed_f.cast::<c_void>(),
95                 Some(call_boxed_closure::<C>),
96                 Some(free_boxed_value::<C>),
97             )
98         };
99         let res = self.decode_result(r);
100         // The xDestroy callback is not called if the sqlite3_create_collation_v2()
101         // function fails.
102         if res.is_err() {
103             drop(unsafe { Box::from_raw(boxed_f) });
104         }
105         res
106     }
107 
collation_needed( &mut self, x_coll_needed: fn(&Connection, &str) -> Result<()>, ) -> Result<()>108     fn collation_needed(
109         &mut self,
110         x_coll_needed: fn(&Connection, &str) -> Result<()>,
111     ) -> Result<()> {
112         use std::mem;
113         #[allow(clippy::needless_return)]
114         unsafe extern "C" fn collation_needed_callback(
115             arg1: *mut c_void,
116             arg2: *mut ffi::sqlite3,
117             e_text_rep: c_int,
118             arg3: *const c_char,
119         ) {
120             use std::ffi::CStr;
121             use std::str;
122 
123             if e_text_rep != ffi::SQLITE_UTF8 {
124                 // TODO: validate
125                 return;
126             }
127 
128             let callback: fn(&Connection, &str) -> Result<()> = mem::transmute(arg1);
129             let res = catch_unwind(|| {
130                 let conn = Connection::from_handle(arg2).unwrap();
131                 let collation_name = {
132                     let c_slice = CStr::from_ptr(arg3).to_bytes();
133                     str::from_utf8(c_slice).expect("illegal collation sequence name")
134                 };
135                 callback(&conn, collation_name)
136             });
137             if res.is_err() {
138                 return; // FIXME How ?
139             }
140         }
141 
142         let r = unsafe {
143             ffi::sqlite3_collation_needed(
144                 self.db(),
145                 x_coll_needed as *mut c_void,
146                 Some(collation_needed_callback),
147             )
148         };
149         self.decode_result(r)
150     }
151 
152     #[inline]
remove_collation(&mut self, collation_name: &str) -> Result<()>153     fn remove_collation(&mut self, collation_name: &str) -> Result<()> {
154         let c_name = str_to_cstring(collation_name)?;
155         let r = unsafe {
156             ffi::sqlite3_create_collation_v2(
157                 self.db(),
158                 c_name.as_ptr(),
159                 ffi::SQLITE_UTF8,
160                 ptr::null_mut(),
161                 None,
162                 None,
163             )
164         };
165         self.decode_result(r)
166     }
167 }
168 
#[cfg(test)]
mod test {
    use crate::{Connection, Result};
    use fallible_streaming_iterator::FallibleStreamingIterator;
    use std::cmp::Ordering;
    use unicase::UniCase;

    // Case-insensitive comparator built on `UniCase` ordering.
    fn unicase_compare(s1: &str, s2: &str) -> Ordering {
        UniCase::new(s1).cmp(&UniCase::new(s2))
    }

    #[test]
    fn test_unicase() -> Result<()> {
        let db = Connection::open_in_memory()?;

        db.create_collation("unicase", unicase_compare)?;

        collate(db)
    }

    // Asserts that 'Maße' and 'MASSE' collapse into a single DISTINCT row under
    // the `unicase` collation, however that collation got registered.
    fn collate(db: Connection) -> Result<()> {
        db.execute_batch(
            "CREATE TABLE foo (bar);
             INSERT INTO foo (bar) VALUES ('Maße');
             INSERT INTO foo (bar) VALUES ('MASSE');",
        )?;
        let mut stmt = db.prepare("SELECT DISTINCT bar COLLATE unicase FROM foo ORDER BY 1")?;
        let rows = stmt.query([])?;
        assert_eq!(rows.count()?, 1);
        Ok(())
    }

    // "Collation needed" hook that lazily registers `unicase` and ignores
    // every other name.
    fn collation_needed(db: &Connection, collation_name: &str) -> Result<()> {
        if "unicase" == collation_name {
            db.create_collation(collation_name, unicase_compare)
        } else {
            Ok(())
        }
    }

    #[test]
    fn test_collation_needed() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.collation_needed(collation_needed)?;
        collate(db)
    }

    #[test]
    fn test_remove_collation() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_collation("unicase", unicase_compare)?;
        let equal: bool = db.query_row("SELECT 'a' = 'A' COLLATE unicase", [], |r| r.get(0))?;
        assert!(equal);

        db.remove_collation("unicase")?;
        // With the comparator removed, the comparison can no longer be compiled.
        assert!(db.prepare("SELECT 'a' = 'A' COLLATE unicase").is_err());
        Ok(())
    }

    #[test]
    fn test_create_collation_rejects_nul_in_name() {
        let db = Connection::open_in_memory().unwrap();
        // The name cannot become a C string; this must fail cleanly
        // (and without leaking the boxed comparator).
        assert!(db.create_collation("uni\0case", unicase_compare).is_err());
    }
}
216