1use crate::error::*;
6use crate::schema;
7use interrupt_support::{SqlInterruptHandle, SqlInterruptScope};
8use parking_lot::Mutex;
9use rusqlite::types::{FromSql, ToSql};
10use rusqlite::Connection;
11use rusqlite::OpenFlags;
12use sql_support::open_database::open_database_with_flags;
13use sql_support::ConnExt;
14use std::ops::Deref;
15use std::path::{Path, PathBuf};
16use std::sync::Arc;
17use url::Url;
18
19pub enum WebExtStorageDb {
27 Open(Connection),
28 Closed,
29}
30
31pub struct StorageDb {
32 pub writer: WebExtStorageDb,
33 interrupt_handle: Arc<SqlInterruptHandle>,
34}
35
36impl StorageDb {
37 pub fn new(db_path: impl AsRef<Path>) -> Result<Self> {
39 let db_path = normalize_path(db_path)?;
40 Self::new_named(db_path)
41 }
42
43 #[cfg(test)]
47 pub fn new_memory(db_path: &str) -> Result<Self> {
48 let name = PathBuf::from(format!("file:{}?mode=memory&cache=shared", db_path));
49 Self::new_named(name)
50 }
51
52 fn new_named(db_path: PathBuf) -> Result<Self> {
53 let flags = OpenFlags::SQLITE_OPEN_NO_MUTEX
56 | OpenFlags::SQLITE_OPEN_URI
57 | OpenFlags::SQLITE_OPEN_CREATE
58 | OpenFlags::SQLITE_OPEN_READ_WRITE;
59
60 let conn = open_database_with_flags(db_path, flags, &schema::WebExtMigrationLogin)?;
61 Ok(Self {
62 interrupt_handle: Arc::new(SqlInterruptHandle::new(&conn)),
63 writer: WebExtStorageDb::Open(conn),
64 })
65 }
66
67 pub fn interrupt_handle(&self) -> Arc<SqlInterruptHandle> {
68 Arc::clone(&self.interrupt_handle)
69 }
70
71 #[allow(dead_code)]
72 pub fn begin_interrupt_scope(&self) -> Result<SqlInterruptScope> {
73 Ok(self.interrupt_handle.begin_interrupt_scope()?)
74 }
75
76 pub fn close(&mut self) -> Result<()> {
83 let conn = match std::mem::replace(&mut self.writer, WebExtStorageDb::Closed) {
84 WebExtStorageDb::Open(conn) => conn,
85 WebExtStorageDb::Closed => return Ok(()),
86 };
87 conn.close().map_err(|(_, y)| Error::SqlError(y))
88 }
89
90 pub(crate) fn get_connection(&self) -> Result<&Connection> {
91 let db = &self.writer;
92 match db {
93 WebExtStorageDb::Open(y) => Ok(y),
94 WebExtStorageDb::Closed => Err(Error::DatabaseConnectionClosed),
95 }
96 }
97}
98
99pub struct ThreadSafeStorageDb {
101 db: Mutex<StorageDb>,
102 interrupt_handle: Arc<SqlInterruptHandle>,
106}
107
108impl ThreadSafeStorageDb {
109 pub fn new(db: StorageDb) -> Self {
110 Self {
111 interrupt_handle: db.interrupt_handle(),
112 db: Mutex::new(db),
113 }
114 }
115
116 pub fn interrupt_handle(&self) -> Arc<SqlInterruptHandle> {
117 Arc::clone(&self.interrupt_handle)
118 }
119
120 #[allow(dead_code)]
121 pub fn begin_interrupt_scope(&self) -> Result<SqlInterruptScope> {
122 Ok(self.interrupt_handle.begin_interrupt_scope()?)
123 }
124}
125
126impl Deref for ThreadSafeStorageDb {
128 type Target = Mutex<StorageDb>;
129
130 #[inline]
131 fn deref(&self) -> &Mutex<StorageDb> {
132 &self.db
133 }
134}
135
136impl AsRef<SqlInterruptHandle> for ThreadSafeStorageDb {
138 fn as_ref(&self) -> &SqlInterruptHandle {
139 &self.interrupt_handle
140 }
141}
142
143pub(crate) mod sql_fns {
144 use rusqlite::{functions::Context, Result};
145 use sync_guid::Guid as SyncGuid;
146
147 #[inline(never)]
148 pub fn generate_guid(_ctx: &Context<'_>) -> Result<SyncGuid> {
149 Ok(SyncGuid::random())
150 }
151}
152
153pub fn put_meta(db: &Connection, key: &str, value: &dyn ToSql) -> Result<()> {
155 db.conn().execute_cached(
156 "REPLACE INTO meta (key, value) VALUES (:key, :value)",
157 rusqlite::named_params! { ":key": key, ":value": value },
158 )?;
159 Ok(())
160}
161
162pub fn get_meta<T: FromSql>(db: &Connection, key: &str) -> Result<Option<T>> {
163 let res = db.conn().try_query_one(
164 "SELECT value FROM meta WHERE key = :key",
165 &[(":key", &key)],
166 true,
167 )?;
168 Ok(res)
169}
170
171pub fn delete_meta(db: &Connection, key: &str) -> Result<()> {
172 db.conn()
173 .execute_cached("DELETE FROM meta WHERE key = :key", &[(":key", &key)])?;
174 Ok(())
175}
176
177fn unurl_path(p: impl AsRef<Path>) -> PathBuf {
188 p.as_ref()
189 .to_str()
190 .and_then(|s| Url::parse(s).ok())
191 .and_then(|u| {
192 if u.scheme() == "file" {
193 u.to_file_path().ok()
194 } else {
195 None
196 }
197 })
198 .unwrap_or_else(|| p.as_ref().to_owned())
199}
200
201#[allow(dead_code)]
206pub fn ensure_url_path(p: impl AsRef<Path>) -> Result<Url> {
207 if let Some(u) = p.as_ref().to_str().and_then(|s| Url::parse(s).ok()) {
208 if u.scheme() == "file" {
209 Ok(u)
210 } else {
211 Err(Error::IllegalDatabasePath(p.as_ref().to_owned()))
212 }
213 } else {
214 let p = p.as_ref();
215 let u = Url::from_file_path(p).map_err(|_| Error::IllegalDatabasePath(p.to_owned()))?;
216 Ok(u)
217 }
218}
219
220fn normalize_path(p: impl AsRef<Path>) -> Result<PathBuf> {
225 let path = unurl_path(p);
226 if let Ok(canonical) = path.canonicalize() {
227 return Ok(canonical);
228 }
229 let file_name = path
239 .file_name()
240 .ok_or_else(|| Error::IllegalDatabasePath(path.clone()))?;
241
242 let parent = path
243 .parent()
244 .ok_or_else(|| Error::IllegalDatabasePath(path.clone()))?;
245
246 let mut canonical = parent.canonicalize()?;
247 canonical.push(file_name);
248 Ok(canonical)
249}
250
/// Helpers for building throwaway databases in tests (public so other
/// modules' tests can use them).
#[cfg(test)]
pub mod test {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Source of unique suffixes. Each in-memory database gets its own name,
    /// so with `cache=shared` tests don't end up sharing one database.
    static ATOMIC_COUNTER: AtomicUsize = AtomicUsize::new(0);

    /// Opens a fresh, uniquely named in-memory `StorageDb`, panicking on failure.
    pub fn new_mem_db() -> StorageDb {
        error_support::init_for_tests();
        let id = ATOMIC_COUNTER.fetch_add(1, Ordering::Relaxed);
        let name = format!("test-api-{}", id);
        StorageDb::new_memory(&name).expect("should get an API")
    }

    /// Like `new_mem_db`, but wrapped as a shared `ThreadSafeStorageDb`.
    pub fn new_mem_thread_safe_storage_db() -> Arc<ThreadSafeStorageDb> {
        let db = new_mem_db();
        Arc::new(ThreadSafeStorageDb::new(db))
    }
}
270
#[cfg(test)]
mod tests {
    use super::test::*;
    use super::*;

    // Opening a fresh in-memory database must succeed (new_mem_db panics otherwise).
    #[test]
    fn test_open() {
        new_mem_db();
    }

    // Round-trips one key through the `meta` table: absent, stored, then deleted.
    #[test]
    fn test_meta() -> Result<()> {
        let db = new_mem_db();
        let conn = db.get_connection()?;

        let before: Option<String> = get_meta(conn, "foo")?;
        assert_eq!(before, None);

        put_meta(conn, "foo", &"bar".to_string())?;
        let stored: Option<String> = get_meta(conn, "foo")?;
        assert_eq!(stored, Some("bar".to_string()));

        delete_meta(conn, "foo")?;
        let after: Option<String> = get_meta(conn, "foo")?;
        assert_eq!(after, None);
        Ok(())
    }
}