1use std::{ops::Deref, path::Path};
5
6use rusqlite::Connection;
7use sql_support::{open_database, ConnExt};
8
9use crate::error::{debug, PushError, Result};
10
11use super::{record::PushRecord, schema};
12
13pub trait Storage: Sized {
14 fn open<P: AsRef<Path>>(path: P) -> Result<Self>;
15
16 fn get_record(&self, chid: &str) -> Result<Option<PushRecord>>;
17
18 fn get_record_by_scope(&self, scope: &str) -> Result<Option<PushRecord>>;
19
20 fn put_record(&self, record: &PushRecord) -> Result<bool>;
21
22 fn delete_record(&self, chid: &str) -> Result<bool>;
23
24 fn delete_all_records(&self) -> Result<()>;
25
26 fn get_channel_list(&self) -> Result<Vec<String>>;
27
28 #[allow(dead_code)]
29 fn update_endpoint(&self, channel_id: &str, endpoint: &str) -> Result<bool>;
30
31 fn get_uaid(&self) -> Result<Option<String>>;
33 fn set_uaid(&self, uaid: &str) -> Result<()>;
34
35 fn get_auth(&self) -> Result<Option<String>>;
36 fn set_auth(&self, auth: &str) -> Result<()>;
37
38 fn get_registration_id(&self) -> Result<Option<String>>;
39 fn set_registration_id(&self, native_id: &str) -> Result<()>;
40
41 fn get_meta(&self, key: &str) -> Result<Option<String>>;
43 fn set_meta(&self, key: &str, value: &str) -> Result<()>;
44}
45
/// Push storage backed by a single `rusqlite` [`Connection`].
pub struct PushDb {
    /// The underlying database connection all queries run against.
    pub db: Connection,
}
49
50impl PushDb {
51 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
52 let path = path.as_ref();
53 let initializer = schema::PushConnectionInitializer {};
56 let db = open_database::open_database(path, &initializer).map_err(|orig| {
57 PushError::StorageError(format!(
58 "Could not open database file {:?} - {}",
59 &path.as_os_str(),
60 orig,
61 ))
62 })?;
63 Ok(Self { db })
64 }
65
66 #[cfg(test)]
67 pub fn open_in_memory() -> Result<Self> {
68 error_support::init_for_tests();
70
71 let initializer = schema::PushConnectionInitializer {};
72 let db = open_database::open_memory_database(&initializer)?;
73 Ok(Self { db })
74 }
75
76 pub fn normalize_uuid(uuid: &str) -> String {
80 uuid.replace('-', "").to_lowercase()
81 }
82}
83
84impl Deref for PushDb {
85 type Target = Connection;
86 fn deref(&self) -> &Connection {
87 &self.db
88 }
89}
90
91impl ConnExt for PushDb {
92 fn conn(&self) -> &Connection {
93 &self.db
94 }
95}
96
97impl Storage for PushDb {
98 fn get_record(&self, chid: &str) -> Result<Option<PushRecord>> {
99 let query = format!(
100 "SELECT {common_cols}
101 FROM push_record WHERE channel_id = :chid",
102 common_cols = schema::COMMON_COLS,
103 );
104 self.try_query_row(
105 &query,
106 &[(":chid", &Self::normalize_uuid(chid))],
107 PushRecord::from_row,
108 false,
109 )
110 }
111
112 fn get_record_by_scope(&self, scope: &str) -> Result<Option<PushRecord>> {
113 let query = format!(
114 "SELECT {common_cols}
115 FROM push_record WHERE scope = :scope",
116 common_cols = schema::COMMON_COLS,
117 );
118 self.try_query_row(&query, &[(":scope", scope)], PushRecord::from_row, false)
119 }
120
121 fn put_record(&self, record: &PushRecord) -> Result<bool> {
122 debug!(
123 "adding push subscription for scope '{}', channel '{}', endpoint '{}'",
124 record.scope, record.channel_id, record.endpoint
125 );
126 let query = format!(
127 "INSERT OR REPLACE INTO push_record
128 ({common_cols})
129 VALUES
130 (:channel_id, :endpoint, :scope, :key, :ctime, :app_server_key)",
131 common_cols = schema::COMMON_COLS,
132 );
133 let affected_rows = self.execute(
134 &query,
135 &[
136 (
137 ":channel_id",
138 &Self::normalize_uuid(&record.channel_id) as &dyn rusqlite::ToSql,
139 ),
140 (":endpoint", &record.endpoint),
141 (":scope", &record.scope),
142 (":key", &record.key),
143 (":ctime", &record.ctime),
144 (":app_server_key", &record.app_server_key),
145 ],
146 )?;
147 Ok(affected_rows == 1)
148 }
149
150 fn delete_record(&self, chid: &str) -> Result<bool> {
151 debug!("deleting push subscription: {}", chid);
152 let affected_rows = self.execute(
153 "DELETE FROM push_record
154 WHERE channel_id = :chid",
155 &[(":chid", &Self::normalize_uuid(chid))],
156 )?;
157 Ok(affected_rows == 1)
158 }
159
160 fn delete_all_records(&self) -> Result<()> {
161 debug!("deleting all push subscriptions and some metadata");
162 self.execute("DELETE FROM push_record", [])?;
163 self.execute_batch(
168 "DELETE FROM meta_data WHERE key='uaid';
169 DELETE FROM meta_data WHERE key='auth';
170 ",
171 )?;
172 Ok(())
173 }
174
175 fn get_channel_list(&self) -> Result<Vec<String>> {
176 self.query_rows_and_then(
177 "SELECT channel_id FROM push_record",
178 [],
179 |row| -> Result<String> { Ok(row.get(0)?) },
180 )
181 }
182
183 fn update_endpoint(&self, channel_id: &str, endpoint: &str) -> Result<bool> {
184 debug!("updating endpoint for '{}' to '{}'", channel_id, endpoint);
185 let affected_rows = self.execute(
186 "UPDATE push_record set endpoint = :endpoint
187 WHERE channel_id = :channel_id",
188 &[
189 (":endpoint", &endpoint as &dyn rusqlite::ToSql),
190 (":channel_id", &Self::normalize_uuid(channel_id)),
191 ],
192 )?;
193 Ok(affected_rows == 1)
194 }
195
196 fn get_uaid(&self) -> Result<Option<String>> {
198 self.get_meta("uaid")
199 }
200
201 fn set_uaid(&self, uaid: &str) -> Result<()> {
202 self.set_meta("uaid", uaid)
203 }
204
205 fn get_auth(&self) -> Result<Option<String>> {
206 self.get_meta("auth")
207 }
208
209 fn set_auth(&self, auth: &str) -> Result<()> {
210 self.set_meta("auth", auth)
211 }
212
213 fn get_registration_id(&self) -> Result<Option<String>> {
214 self.get_meta("registration_id")
215 }
216
217 fn set_registration_id(&self, registration_id: &str) -> Result<()> {
218 self.set_meta("registration_id", registration_id)
219 }
220
221 fn get_meta(&self, key: &str) -> Result<Option<String>> {
222 self.try_query_one(
225 "SELECT value FROM meta_data where key = :key limit 1",
226 &[(":key", &key)],
227 true,
228 )
229 .map_err(PushError::StorageSqlError)
230 }
231
232 fn set_meta(&self, key: &str, value: &str) -> Result<()> {
233 let query = "INSERT or REPLACE into meta_data (key, value) values (:k, :v)";
234 self.execute_cached(query, &[(":k", &key), (":v", &value)])?;
235 Ok(())
236 }
237
238 #[cfg(not(test))]
239 fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
240 PushDb::open(path)
241 }
242
243 #[cfg(test)]
244 fn open<P: AsRef<Path>>(_path: P) -> Result<Self> {
245 PushDb::open_in_memory()
246 }
247}
248
#[cfg(test)]
mod test {
    use crate::error::Result;
    use crate::internal::crypto::{Crypto, Cryptography};

    use super::PushDb;
    use crate::internal::crypto::get_random_bytes;
    use crate::internal::storage::{db::Storage, record::PushRecord};
    use nss::ensure_initialized;

    const DUMMY_UAID: &str = "abad1dea00000000aabbccdd00000000";

    /// Opens a fresh in-memory database for a single test.
    fn get_db() -> Result<PushDb> {
        error_support::init_for_tests();
        PushDb::open_in_memory()
    }

    /// Renders 16 random bytes as a 32-character lowercase hex string, which
    /// serves as a channel ID in these tests.
    fn get_uuid() -> Result<String> {
        Ok(get_random_bytes(16)?
            .iter()
            .map(|b| format!("{:02x}", b))
            .collect::<String>())
    }

    /// Builds a record for `chid` with a fixed scope and a freshly generated
    /// key; panics if either the key or the record cannot be created.
    fn prec(chid: &str) -> PushRecord {
        PushRecord::new(
            chid,
            &format!("https://example.com/update/{}", chid),
            "https://example.com/",
            Crypto::generate_key().expect("Couldn't generate_key"),
        )
        .expect("Couldn't build PushRecord")
    }

    #[test]
    fn basic() -> Result<()> {
        ensure_initialized();

        let db = get_db()?;
        let chid = &get_uuid()?;
        let rec = prec(chid);

        assert!(db.get_record(chid)?.is_none());
        db.put_record(&rec)?;
        assert!(db.get_record(chid)?.is_some());
        // Storing the same record a second time must leave it unchanged.
        db.put_record(&rec)?;
        assert_eq!(db.get_record(chid)?, Some(rec.clone()));

        // Same channel, new endpoint: the stored record must be replaced.
        let mut rec2 = rec.clone();
        rec2.endpoint = format!("https://example.com/update2/{}", chid);
        db.put_record(&rec2)?;
        let result = db.get_record(chid)?.unwrap();
        assert_ne!(result, rec);
        assert_eq!(result, rec2);

        let result = db.get_record_by_scope("https://example.com/")?.unwrap();
        assert_eq!(result, rec2);

        Ok(())
    }

    #[test]
    fn delete() -> Result<()> {
        ensure_initialized();

        let db = get_db()?;
        let chid = &get_uuid()?;
        let rec = prec(chid);

        assert!(db.put_record(&rec)?);
        assert!(db.get_record(chid)?.is_some());
        assert!(db.delete_record(chid)?);
        assert!(db.get_record(chid)?.is_none());
        Ok(())
    }

    #[test]
    fn delete_all_records() -> Result<()> {
        ensure_initialized();

        let db = get_db()?;
        let chid = &get_uuid()?;
        let rec = prec(chid);
        let mut rec2 = rec.clone();
        rec2.channel_id = get_uuid()?;
        rec2.endpoint = format!("https://example.com/update/{}", &rec2.channel_id);

        assert!(db.put_record(&rec)?);
        // `rec2` has a different channel and endpoint but the same scope as
        // `rec`. The assertions below expect the `INSERT OR REPLACE` to have
        // displaced `rec` (presumably through a uniqueness constraint on
        // scope in the schema — confirm there).
        assert!(db.put_record(&rec2)?);
        assert!(db.get_record(&rec.channel_id)?.is_none());
        assert!(db.get_record(&rec2.channel_id)?.is_some());
        db.delete_all_records()?;
        // Both channels must be gone; the second check previously repeated
        // `rec` and so never verified that `rec2` was removed.
        assert!(db.get_record(&rec.channel_id)?.is_none());
        assert!(db.get_record(&rec2.channel_id)?.is_none());
        assert!(db.get_uaid()?.is_none());
        assert!(db.get_auth()?.is_none());
        Ok(())
    }

    #[test]
    fn meta() -> Result<()> {
        ensure_initialized();

        let db = get_db()?;
        let no_rec = db.get_uaid()?;
        assert_eq!(no_rec, None);
        db.set_uaid(DUMMY_UAID)?;
        // Writing the same key twice must keep only the latest value.
        db.set_meta("fruit", "apple")?;
        db.set_meta("fruit", "banana")?;
        assert_eq!(db.get_uaid()?, Some(DUMMY_UAID.to_owned()));
        assert_eq!(db.get_meta("fruit")?, Some("banana".to_owned()));
        Ok(())
    }

    #[test]
    fn dash() -> Result<()> {
        ensure_initialized();

        let db = get_db()?;
        // A dashed channel ID must round-trip through put/get/delete.
        let chid = "deadbeef-0000-0000-0000-decafbad12345678";

        let rec = prec(chid);

        assert!(db.put_record(&rec)?);
        assert!(db.get_record(chid)?.is_some());
        assert!(db.delete_record(chid)?);
        Ok(())
    }
}