push/internal/storage/
db.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4use std::{ops::Deref, path::Path};
5
6use rusqlite::Connection;
7use sql_support::{open_database, ConnExt};
8
9use crate::error::{debug, PushError, Result};
10
11use super::{record::PushRecord, schema};
12
/// Persistence interface for push subscription records plus a small set of
/// string key/value "meta" entries.
///
/// `PushDb` (below in this file) is the SQLite-backed implementation; the notes on
/// individual methods that mention `PushDb` describe what that implementation does.
pub trait Storage: Sized {
    /// Opens the store located at `path`.
    ///
    /// Note that `PushDb`'s implementation ignores `path` under `cfg(test)` and opens
    /// an in-memory database instead.
    fn open<P: AsRef<Path>>(path: P) -> Result<Self>;

    /// Looks up the record for channel `chid`. `PushDb` normalizes `chid` before the
    /// lookup, so dashed and/or upper-case IDs match the stored form.
    fn get_record(&self, chid: &str) -> Result<Option<PushRecord>>;

    /// Looks up the record whose scope equals `scope`.
    fn get_record_by_scope(&self, scope: &str) -> Result<Option<PushRecord>>;

    /// Stores `record`. `PushDb` uses `INSERT OR REPLACE` and returns `true` when
    /// exactly one row was reported as affected.
    fn put_record(&self, record: &PushRecord) -> Result<bool>;

    /// Deletes the record for channel `chid`. `PushDb` returns `true` when exactly
    /// one row was deleted.
    fn delete_record(&self, chid: &str) -> Result<bool>;

    /// Deletes every record. `PushDb` additionally removes the `uaid` and `auth`
    /// meta entries but deliberately keeps `registration_id`.
    fn delete_all_records(&self) -> Result<()>;

    /// Returns the channel IDs of all stored records.
    fn get_channel_list(&self) -> Result<Vec<String>>;

    /// Replaces the endpoint of the record for `channel_id`. `PushDb` returns `true`
    /// when exactly one row was updated.
    // NOTE(review): no caller is visible in this file, which is presumably why the
    // lint is silenced - confirm the `allow` is still needed.
    #[allow(dead_code)]
    fn update_endpoint(&self, channel_id: &str, endpoint: &str) -> Result<bool>;

    // Some of our "meta" keys are more important than others, so they get special helpers.
    /// Reads the `uaid` meta entry.
    fn get_uaid(&self) -> Result<Option<String>>;
    /// Writes the `uaid` meta entry.
    fn set_uaid(&self, uaid: &str) -> Result<()>;

    /// Reads the `auth` meta entry.
    fn get_auth(&self) -> Result<Option<String>>;
    /// Writes the `auth` meta entry.
    fn set_auth(&self, auth: &str) -> Result<()>;

    /// Reads the `registration_id` meta entry.
    fn get_registration_id(&self) -> Result<Option<String>>;
    /// Writes the `registration_id` meta entry.
    fn set_registration_id(&self, native_id: &str) -> Result<()>;

    // And general purpose meta with hard-coded key names spread everywhere.
    /// Reads the meta entry stored under `key`, if any.
    fn get_meta(&self, key: &str) -> Result<Option<String>>;
    /// Writes `value` under `key`; `PushDb` replaces any existing entry for that key.
    fn set_meta(&self, key: &str, value: &str) -> Result<()>;
}
45
/// SQLite-backed push store; a thin wrapper around the underlying connection.
pub struct PushDb {
    /// The connection every query in this file runs against.
    pub db: Connection,
}
49
50impl PushDb {
51    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
52        let path = path.as_ref();
53        // By default, file open errors are StorageSqlErrors and aren't super helpful.
54        // Instead, remap to StorageError and provide the path to the file that couldn't be opened.
55        let initializer = schema::PushConnectionInitializer {};
56        let db = open_database::open_database(path, &initializer).map_err(|orig| {
57            PushError::StorageError(format!(
58                "Could not open database file {:?} - {}",
59                &path.as_os_str(),
60                orig,
61            ))
62        })?;
63        Ok(Self { db })
64    }
65
66    #[cfg(test)]
67    pub fn open_in_memory() -> Result<Self> {
68        // A nod to our tests which use this.
69        error_support::init_for_tests();
70
71        let initializer = schema::PushConnectionInitializer {};
72        let db = open_database::open_memory_database(&initializer)?;
73        Ok(Self { db })
74    }
75
76    /// Normalize UUID values to undashed, lowercase.
77    // The server mangles ChannelID UUIDs to undashed lowercase values. We should force those
78    // so that key lookups continue to work.
79    pub fn normalize_uuid(uuid: &str) -> String {
80        uuid.replace('-', "").to_lowercase()
81    }
82}
83
84impl Deref for PushDb {
85    type Target = Connection;
86    fn deref(&self) -> &Connection {
87        &self.db
88    }
89}
90
// Hands the wrapped connection to `sql_support::ConnExt`; presumably this is what
// makes that trait's query helpers callable on `PushDb` - confirm against `ConnExt`.
impl ConnExt for PushDb {
    /// Returns the underlying connection.
    fn conn(&self) -> &Connection {
        &self.db
    }
}
96
97impl Storage for PushDb {
98    fn get_record(&self, chid: &str) -> Result<Option<PushRecord>> {
99        let query = format!(
100            "SELECT {common_cols}
101             FROM push_record WHERE channel_id = :chid",
102            common_cols = schema::COMMON_COLS,
103        );
104        self.try_query_row(
105            &query,
106            &[(":chid", &Self::normalize_uuid(chid))],
107            PushRecord::from_row,
108            false,
109        )
110    }
111
112    fn get_record_by_scope(&self, scope: &str) -> Result<Option<PushRecord>> {
113        let query = format!(
114            "SELECT {common_cols}
115             FROM push_record WHERE scope = :scope",
116            common_cols = schema::COMMON_COLS,
117        );
118        self.try_query_row(&query, &[(":scope", scope)], PushRecord::from_row, false)
119    }
120
121    fn put_record(&self, record: &PushRecord) -> Result<bool> {
122        debug!(
123            "adding push subscription for scope '{}', channel '{}', endpoint '{}'",
124            record.scope, record.channel_id, record.endpoint
125        );
126        let query = format!(
127            "INSERT OR REPLACE INTO push_record
128                 ({common_cols})
129             VALUES
130                 (:channel_id, :endpoint, :scope, :key, :ctime, :app_server_key)",
131            common_cols = schema::COMMON_COLS,
132        );
133        let affected_rows = self.execute(
134            &query,
135            &[
136                (
137                    ":channel_id",
138                    &Self::normalize_uuid(&record.channel_id) as &dyn rusqlite::ToSql,
139                ),
140                (":endpoint", &record.endpoint),
141                (":scope", &record.scope),
142                (":key", &record.key),
143                (":ctime", &record.ctime),
144                (":app_server_key", &record.app_server_key),
145            ],
146        )?;
147        Ok(affected_rows == 1)
148    }
149
150    fn delete_record(&self, chid: &str) -> Result<bool> {
151        debug!("deleting push subscription: {}", chid);
152        let affected_rows = self.execute(
153            "DELETE FROM push_record
154             WHERE channel_id = :chid",
155            &[(":chid", &Self::normalize_uuid(chid))],
156        )?;
157        Ok(affected_rows == 1)
158    }
159
160    fn delete_all_records(&self) -> Result<()> {
161        debug!("deleting all push subscriptions and some metadata");
162        self.execute("DELETE FROM push_record", [])?;
163        // Clean up the meta data records as well, since we probably want to reset the
164        // UAID and get a new secret.
165        // Note we *do not* delete the registration_id - it's possible we are deleting all
166        // subscriptions because we just provided a different registration_id.
167        self.execute_batch(
168            "DELETE FROM meta_data WHERE key='uaid';
169             DELETE FROM meta_data WHERE key='auth';
170             ",
171        )?;
172        Ok(())
173    }
174
175    fn get_channel_list(&self) -> Result<Vec<String>> {
176        self.query_rows_and_then(
177            "SELECT channel_id FROM push_record",
178            [],
179            |row| -> Result<String> { Ok(row.get(0)?) },
180        )
181    }
182
183    fn update_endpoint(&self, channel_id: &str, endpoint: &str) -> Result<bool> {
184        debug!("updating endpoint for '{}' to '{}'", channel_id, endpoint);
185        let affected_rows = self.execute(
186            "UPDATE push_record set endpoint = :endpoint
187             WHERE channel_id = :channel_id",
188            &[
189                (":endpoint", &endpoint as &dyn rusqlite::ToSql),
190                (":channel_id", &Self::normalize_uuid(channel_id)),
191            ],
192        )?;
193        Ok(affected_rows == 1)
194    }
195
196    // A couple of helpers to get/set "well known" meta keys.
197    fn get_uaid(&self) -> Result<Option<String>> {
198        self.get_meta("uaid")
199    }
200
201    fn set_uaid(&self, uaid: &str) -> Result<()> {
202        self.set_meta("uaid", uaid)
203    }
204
205    fn get_auth(&self) -> Result<Option<String>> {
206        self.get_meta("auth")
207    }
208
209    fn set_auth(&self, auth: &str) -> Result<()> {
210        self.set_meta("auth", auth)
211    }
212
213    fn get_registration_id(&self) -> Result<Option<String>> {
214        self.get_meta("registration_id")
215    }
216
217    fn set_registration_id(&self, registration_id: &str) -> Result<()> {
218        self.set_meta("registration_id", registration_id)
219    }
220
221    fn get_meta(&self, key: &str) -> Result<Option<String>> {
222        // Get the most recent UAID (which should be the same value across all records,
223        // but paranoia)
224        self.try_query_one(
225            "SELECT value FROM meta_data where key = :key limit 1",
226            &[(":key", &key)],
227            true,
228        )
229        .map_err(PushError::StorageSqlError)
230    }
231
232    fn set_meta(&self, key: &str, value: &str) -> Result<()> {
233        let query = "INSERT or REPLACE into meta_data (key, value) values (:k, :v)";
234        self.execute_cached(query, &[(":k", &key), (":v", &value)])?;
235        Ok(())
236    }
237
238    #[cfg(not(test))]
239    fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
240        PushDb::open(path)
241    }
242
243    #[cfg(test)]
244    fn open<P: AsRef<Path>>(_path: P) -> Result<Self> {
245        PushDb::open_in_memory()
246    }
247}
248
// Unit tests for `PushDb`; every test runs against a fresh in-memory database.
#[cfg(test)]
mod test {
    use crate::error::Result;
    use crate::internal::crypto::{Crypto, Cryptography};

    use super::PushDb;
    use crate::internal::crypto::get_random_bytes;
    use crate::internal::storage::{db::Storage, record::PushRecord};
    use nss::ensure_initialized;

    const DUMMY_UAID: &str = "abad1dea00000000aabbccdd00000000";
    const DUMMY_AUTH: &str = "dummy-auth-secret";
    const DUMMY_REGISTRATION_ID: &str = "dummy-registration-id";

    /// Opens a fresh in-memory database.
    fn get_db() -> Result<PushDb> {
        error_support::init_for_tests();
        // NOTE: In Memory tests can sometimes produce false positives. Use the following
        // for debugging
        // PushDb::open("/tmp/push.sqlite");
        PushDb::open_in_memory()
    }

    /// Returns 16 random bytes rendered as 32 lowercase hex digits (no dashes).
    fn get_uuid() -> Result<String> {
        Ok(get_random_bytes(16)?
            .iter()
            .map(|b| format!("{:02x}", b))
            .collect::<String>())
    }

    /// Builds a record for `chid` with a fixed scope and a freshly generated key.
    fn prec(chid: &str) -> PushRecord {
        PushRecord::new(
            chid,
            &format!("https://example.com/update/{}", chid),
            "https://example.com/",
            Crypto::generate_key().expect("Couldn't generate_key"),
        )
        .unwrap()
    }

    #[test]
    fn basic() -> Result<()> {
        ensure_initialized();

        let db = get_db()?;
        let chid = &get_uuid()?;
        let rec = prec(chid);

        assert!(db.get_record(chid)?.is_none());
        db.put_record(&rec)?;
        assert!(db.get_record(chid)?.is_some());
        // don't fail if you've already added this record.
        db.put_record(&rec)?;
        // make sure that fetching the same uaid & chid returns the same record.
        assert_eq!(db.get_record(chid)?, Some(rec.clone()));

        let mut rec2 = rec.clone();
        rec2.endpoint = format!("https://example.com/update2/{}", chid);
        db.put_record(&rec2)?;
        let result = db.get_record(chid)?.unwrap();
        assert_ne!(result, rec);
        assert_eq!(result, rec2);

        let result = db.get_record_by_scope("https://example.com/")?.unwrap();
        assert_eq!(result, rec2);

        Ok(())
    }

    #[test]
    fn delete() -> Result<()> {
        ensure_initialized();

        let db = get_db()?;
        let chid = &get_uuid()?;
        let rec = prec(chid);

        assert!(db.put_record(&rec)?);
        assert!(db.get_record(chid)?.is_some());
        assert!(db.delete_record(chid)?);
        assert!(db.get_record(chid)?.is_none());
        Ok(())
    }

    #[test]
    fn delete_all_records() -> Result<()> {
        ensure_initialized();

        let db = get_db()?;
        let chid = &get_uuid()?;
        let rec = prec(chid);
        let mut rec2 = rec.clone();
        rec2.channel_id = get_uuid()?;
        rec2.endpoint = format!("https://example.com/update/{}", &rec2.channel_id);

        // Seed the meta entries so the "is gone" / "is kept" checks below are meaningful.
        db.set_uaid(DUMMY_UAID)?;
        db.set_auth(DUMMY_AUTH)?;
        db.set_registration_id(DUMMY_REGISTRATION_ID)?;

        assert!(db.put_record(&rec)?);
        // save a record with different channel and endpoint, but same scope - it should overwrite
        // the first because scopes are unique.
        assert!(db.put_record(&rec2)?);
        assert!(db.get_record(&rec.channel_id)?.is_none());
        assert!(db.get_record(&rec2.channel_id)?.is_some());
        db.delete_all_records()?;
        assert!(db.get_record(&rec.channel_id)?.is_none());
        assert!(db.get_record(&rec2.channel_id)?.is_none());
        assert!(db.get_channel_list()?.is_empty());
        assert!(db.get_uaid()?.is_none());
        assert!(db.get_auth()?.is_none());
        // The registration id must survive a wipe.
        assert_eq!(
            db.get_registration_id()?,
            Some(DUMMY_REGISTRATION_ID.to_owned())
        );
        Ok(())
    }

    #[test]
    fn meta() -> Result<()> {
        ensure_initialized();

        let db = get_db()?;
        let no_rec = db.get_uaid()?;
        assert_eq!(no_rec, None);
        db.set_uaid(DUMMY_UAID)?;
        db.set_meta("fruit", "apple")?;
        db.set_meta("fruit", "banana")?;
        assert_eq!(db.get_uaid()?, Some(DUMMY_UAID.to_owned()));
        assert_eq!(db.get_meta("fruit")?, Some("banana".to_owned()));
        Ok(())
    }

    #[test]
    fn dash() -> Result<()> {
        ensure_initialized();

        let db = get_db()?;
        let chid = "deadbeef-0000-0000-0000-decafbad12345678";

        let rec = prec(chid);

        assert!(db.put_record(&rec)?);
        assert!(db.get_record(chid)?.is_some());
        // The stored key is the undashed form, regardless of what the caller passed in.
        assert_eq!(db.get_channel_list()?, vec![PushDb::normalize_uuid(chid)]);
        assert!(db.delete_record(chid)?);
        assert!(db.get_record(chid)?.is_none());
        Ok(())
    }
}