1 //! Add, remove, or modify a collation
2 use std::cmp::Ordering;
3 use std::os::raw::{c_char, c_int, c_void};
4 use std::panic::{catch_unwind, UnwindSafe};
5 use std::ptr;
6 use std::slice;
7
8 use crate::ffi;
9 use crate::{str_to_cstring, Connection, InnerConnection, Result};
10
11 // FIXME copy/paste from function.rs
/// Reclaims and drops a `Box<T>` whose ownership was handed to SQLite as
/// user data; passed as the destructor argument of
/// `sqlite3_create_collation_v2`.
///
/// # Safety
///
/// `p` must come from `Box::<T>::into_raw` and must not have been freed yet.
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
    let owned: Box<T> = Box::from_raw(p as *mut T);
    drop(owned);
}
15
16 impl Connection {
17 /// Add or modify a collation.
18 #[inline]
create_collation<C>(&self, collation_name: &str, x_compare: C) -> Result<()> where C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,19 pub fn create_collation<C>(&self, collation_name: &str, x_compare: C) -> Result<()>
20 where
21 C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,
22 {
23 self.db
24 .borrow_mut()
25 .create_collation(collation_name, x_compare)
26 }
27
28 /// Collation needed callback
29 #[inline]
collation_needed( &self, x_coll_needed: fn(&Connection, &str) -> Result<()>, ) -> Result<()>30 pub fn collation_needed(
31 &self,
32 x_coll_needed: fn(&Connection, &str) -> Result<()>,
33 ) -> Result<()> {
34 self.db.borrow_mut().collation_needed(x_coll_needed)
35 }
36
37 /// Remove collation.
38 #[inline]
remove_collation(&self, collation_name: &str) -> Result<()>39 pub fn remove_collation(&self, collation_name: &str) -> Result<()> {
40 self.db.borrow_mut().remove_collation(collation_name)
41 }
42 }
43
44 impl InnerConnection {
create_collation<C>(&mut self, collation_name: &str, x_compare: C) -> Result<()> where C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,45 fn create_collation<C>(&mut self, collation_name: &str, x_compare: C) -> Result<()>
46 where
47 C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,
48 {
49 unsafe extern "C" fn call_boxed_closure<C>(
50 arg1: *mut c_void,
51 arg2: c_int,
52 arg3: *const c_void,
53 arg4: c_int,
54 arg5: *const c_void,
55 ) -> c_int
56 where
57 C: Fn(&str, &str) -> Ordering,
58 {
59 let r = catch_unwind(|| {
60 let boxed_f: *mut C = arg1.cast::<C>();
61 assert!(!boxed_f.is_null(), "Internal error - null function pointer");
62 let s1 = {
63 let c_slice = slice::from_raw_parts(arg3.cast::<u8>(), arg2 as usize);
64 String::from_utf8_lossy(c_slice)
65 };
66 let s2 = {
67 let c_slice = slice::from_raw_parts(arg5.cast::<u8>(), arg4 as usize);
68 String::from_utf8_lossy(c_slice)
69 };
70 (*boxed_f)(s1.as_ref(), s2.as_ref())
71 });
72 let t = match r {
73 Err(_) => {
74 return -1; // FIXME How ?
75 }
76 Ok(r) => r,
77 };
78
79 match t {
80 Ordering::Less => -1,
81 Ordering::Equal => 0,
82 Ordering::Greater => 1,
83 }
84 }
85
86 let boxed_f: *mut C = Box::into_raw(Box::new(x_compare));
87 let c_name = str_to_cstring(collation_name)?;
88 let flags = ffi::SQLITE_UTF8;
89 let r = unsafe {
90 ffi::sqlite3_create_collation_v2(
91 self.db(),
92 c_name.as_ptr(),
93 flags,
94 boxed_f.cast::<c_void>(),
95 Some(call_boxed_closure::<C>),
96 Some(free_boxed_value::<C>),
97 )
98 };
99 let res = self.decode_result(r);
100 // The xDestroy callback is not called if the sqlite3_create_collation_v2()
101 // function fails.
102 if res.is_err() {
103 drop(unsafe { Box::from_raw(boxed_f) });
104 }
105 res
106 }
107
collation_needed( &mut self, x_coll_needed: fn(&Connection, &str) -> Result<()>, ) -> Result<()>108 fn collation_needed(
109 &mut self,
110 x_coll_needed: fn(&Connection, &str) -> Result<()>,
111 ) -> Result<()> {
112 use std::mem;
113 #[allow(clippy::needless_return)]
114 unsafe extern "C" fn collation_needed_callback(
115 arg1: *mut c_void,
116 arg2: *mut ffi::sqlite3,
117 e_text_rep: c_int,
118 arg3: *const c_char,
119 ) {
120 use std::ffi::CStr;
121 use std::str;
122
123 if e_text_rep != ffi::SQLITE_UTF8 {
124 // TODO: validate
125 return;
126 }
127
128 let callback: fn(&Connection, &str) -> Result<()> = mem::transmute(arg1);
129 let res = catch_unwind(|| {
130 let conn = Connection::from_handle(arg2).unwrap();
131 let collation_name = {
132 let c_slice = CStr::from_ptr(arg3).to_bytes();
133 str::from_utf8(c_slice).expect("illegal collation sequence name")
134 };
135 callback(&conn, collation_name)
136 });
137 if res.is_err() {
138 return; // FIXME How ?
139 }
140 }
141
142 let r = unsafe {
143 ffi::sqlite3_collation_needed(
144 self.db(),
145 x_coll_needed as *mut c_void,
146 Some(collation_needed_callback),
147 )
148 };
149 self.decode_result(r)
150 }
151
152 #[inline]
remove_collation(&mut self, collation_name: &str) -> Result<()>153 fn remove_collation(&mut self, collation_name: &str) -> Result<()> {
154 let c_name = str_to_cstring(collation_name)?;
155 let r = unsafe {
156 ffi::sqlite3_create_collation_v2(
157 self.db(),
158 c_name.as_ptr(),
159 ffi::SQLITE_UTF8,
160 ptr::null_mut(),
161 None,
162 None,
163 )
164 };
165 self.decode_result(r)
166 }
167 }
168
#[cfg(test)]
mod test {
    use crate::{Connection, Result};
    use fallible_streaming_iterator::FallibleStreamingIterator;
    use std::cmp::Ordering;
    use unicase::UniCase;

    /// Case-insensitive comparison via `UniCase`, used as the test collation.
    fn unicase_compare(s1: &str, s2: &str) -> Ordering {
        let (lhs, rhs) = (UniCase::new(s1), UniCase::new(s2));
        lhs.cmp(&rhs)
    }

    #[test]
    fn test_unicase() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.create_collation("unicase", unicase_compare)?;
        collate(conn)
    }

    /// Inserts two spellings that `unicase` should treat as equal and checks
    /// that `SELECT DISTINCT ... COLLATE unicase` yields a single row.
    fn collate(db: Connection) -> Result<()> {
        db.execute_batch(concat!(
            "CREATE TABLE foo (bar);\n",
            "INSERT INTO foo (bar) VALUES ('Maße');\n",
            "INSERT INTO foo (bar) VALUES ('MASSE');",
        ))?;
        let mut stmt = db.prepare("SELECT DISTINCT bar COLLATE unicase FROM foo ORDER BY 1")?;
        let distinct_rows = stmt.query([])?.count()?;
        assert_eq!(distinct_rows, 1);
        Ok(())
    }

    /// On-demand registration hook: only the `unicase` collation is provided.
    fn collation_needed(db: &Connection, collation_name: &str) -> Result<()> {
        match collation_name {
            "unicase" => db.create_collation(collation_name, unicase_compare),
            _ => Ok(()),
        }
    }

    #[test]
    fn test_collation_needed() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.collation_needed(collation_needed)?;
        collate(conn)
    }
}
216