1use crate::bookmark_sync::BookmarksSyncEngine;
6use crate::db::db::{PlacesDb, SharedPlacesDb};
7use crate::error::*;
8use crate::history_sync::HistorySyncEngine;
9use crate::util::normalize_path;
10use error_support::handle_error;
11use interrupt_support::register_interrupt;
12use lazy_static::lazy_static;
13use parking_lot::Mutex;
14use rusqlite::OpenFlags;
15use std::collections::HashMap;
16use std::path::{Path, PathBuf};
17use std::sync::{
18 atomic::{AtomicUsize, Ordering},
19 Arc, Weak,
20};
21use sync15::engine::{SyncEngine, SyncEngineId};
22
23lazy_static::lazy_static! {
25 static ref PLACES_API_FOR_SYNC_MANAGER: Mutex<Weak<PlacesApi>> = Mutex::new(Weak::new());
30}
31
32pub fn get_registered_sync_engine(engine_id: &SyncEngineId) -> Option<Box<dyn SyncEngine>> {
35 match PLACES_API_FOR_SYNC_MANAGER.lock().upgrade() {
36 None => {
37 warn!("places: get_registered_sync_engine: no PlacesApi registered");
38 None
39 }
40 Some(places_api) => match create_sync_engine(&places_api, engine_id) {
41 Ok(engine) => Some(engine),
42 Err(e) => {
43 if !matches!(e, Error::OpenDatabaseError(_)) {
48 error_support::report_error!(
49 "places-no-registered-sync-engine",
50 "places: get_registered_sync_engine: {}",
51 e
52 );
53 }
54 None
55 }
56 },
57 }
58}
59
60fn create_sync_engine(
61 places_api: &PlacesApi,
62 engine_id: &SyncEngineId,
63) -> Result<Box<dyn SyncEngine>> {
64 let conn = places_api.get_sync_connection()?;
65 match engine_id {
66 SyncEngineId::Bookmarks => Ok(Box::new(BookmarksSyncEngine::new(conn)?)),
67 SyncEngineId::History => Ok(Box::new(HistorySyncEngine::new(conn)?)),
68 _ => unreachable!("can't provide unknown engine: {}", engine_id),
69 }
70}
71
/// The kind of connection a `PlacesDb` is opened as.
///
/// The enum is `#[repr(u8)]` with explicit discriminants (1, 2, 3), and
/// [`ConnectionType::from_primitive`] maps those values back to variants.
#[repr(u8)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ConnectionType {
    ReadOnly = 1,
    ReadWrite = 2,
    Sync = 3,
}

impl ConnectionType {
    /// Converts a raw discriminant back into a `ConnectionType`.
    ///
    /// Returns `None` for any value that is not a variant's discriminant.
    pub fn from_primitive(p: u8) -> Option<Self> {
        const ALL: [ConnectionType; 3] = [
            ConnectionType::ReadOnly,
            ConnectionType::ReadWrite,
            ConnectionType::Sync,
        ];
        ALL.iter().copied().find(|&candidate| candidate as u8 == p)
    }
}
90
91impl ConnectionType {
92 pub fn rusqlite_flags(self) -> OpenFlags {
93 let common_flags = OpenFlags::SQLITE_OPEN_NO_MUTEX | OpenFlags::SQLITE_OPEN_URI;
94 match self {
95 ConnectionType::ReadOnly => common_flags | OpenFlags::SQLITE_OPEN_READ_ONLY,
96 ConnectionType::ReadWrite => {
97 common_flags | OpenFlags::SQLITE_OPEN_CREATE | OpenFlags::SQLITE_OPEN_READ_WRITE
98 }
99 ConnectionType::Sync => common_flags | OpenFlags::SQLITE_OPEN_READ_WRITE,
100 }
101 }
102}
103
104lazy_static! {
106 static ref APIS: Mutex<HashMap<PathBuf, Weak<PlacesApi>>> = Mutex::new(HashMap::new());
107}
108
109static ID_COUNTER: AtomicUsize = AtomicUsize::new(0);
110
111#[handle_error(crate::Error)]
115pub fn places_api_new(db_name: impl AsRef<Path>) -> ApiResult<Arc<PlacesApi>> {
116 PlacesApi::new(db_name)
117}
118
119pub struct PlacesApi {
123 db_name: PathBuf,
124 write_connection: Mutex<Option<PlacesDb>>,
125 coop_tx_lock: Arc<Mutex<()>>,
126 sync_connection: Mutex<Weak<SharedPlacesDb>>,
135 id: usize,
136}
137
138impl PlacesApi {
139 pub fn new(db_name: impl AsRef<Path>) -> Result<Arc<Self>> {
141 let db_name = normalize_path(db_name)?;
142 Self::new_or_existing(db_name)
143 }
144
145 pub fn new_memory(db_name: &str) -> Result<Arc<Self>> {
149 let name = PathBuf::from(format!("file:{}?mode=memory&cache=shared", db_name));
150 Self::new_or_existing(name)
151 }
152 fn new_or_existing_into(
153 target: &mut HashMap<PathBuf, Weak<PlacesApi>>,
154 db_name: PathBuf,
155 ) -> Result<Arc<Self>> {
156 let id = ID_COUNTER.fetch_add(1, Ordering::SeqCst);
157 match target.get(&db_name).and_then(Weak::upgrade) {
158 Some(existing) => Ok(existing),
159 None => {
160 let coop_tx_lock = Arc::new(Mutex::new(()));
163 let connection = PlacesDb::open(
164 &db_name,
165 ConnectionType::ReadWrite,
166 id,
167 coop_tx_lock.clone(),
168 )?;
169 let new = PlacesApi {
170 db_name: db_name.clone(),
171 write_connection: Mutex::new(Some(connection)),
172 sync_connection: Mutex::new(Weak::new()),
173 id,
174 coop_tx_lock,
175 };
176 let arc = Arc::new(new);
177 target.insert(db_name, Arc::downgrade(&arc));
178 Ok(arc)
179 }
180 }
181 }
182
183 fn new_or_existing(db_name: PathBuf) -> Result<Arc<Self>> {
184 let mut guard = APIS.lock();
185 Self::new_or_existing_into(&mut guard, db_name)
186 }
187
188 pub fn open_connection(&self, conn_type: ConnectionType) -> Result<PlacesDb> {
190 match conn_type {
191 ConnectionType::ReadOnly => {
192 PlacesDb::open(
194 self.db_name.clone(),
195 ConnectionType::ReadOnly,
196 self.id,
197 self.coop_tx_lock.clone(),
198 )
199 }
200 ConnectionType::ReadWrite => {
201 let mut guard = self.write_connection.lock();
203 match guard.take() {
204 None => Err(Error::ConnectionAlreadyOpen),
205 Some(db) => Ok(db),
206 }
207 }
208 ConnectionType::Sync => {
209 panic!("Use `get_sync_connection` to open a sync connection");
210 }
211 }
212 }
213
214 pub fn get_sync_connection(&self) -> Result<Arc<SharedPlacesDb>> {
222 let mut conn = self.sync_connection.lock();
224 match conn.upgrade() {
225 Some(db) => Ok(db),
227 None => {
229 let db = Arc::new(SharedPlacesDb::new(PlacesDb::open(
230 self.db_name.clone(),
231 ConnectionType::Sync,
232 self.id,
233 self.coop_tx_lock.clone(),
234 )?));
235 register_interrupt(Arc::<SharedPlacesDb>::downgrade(&db));
236 *conn = Arc::downgrade(&db);
238 Ok(db)
239 }
240 }
241 }
242
243 pub fn close_connection(&self, connection: PlacesDb) -> Result<()> {
246 if connection.api_id() != self.id {
247 return Err(Error::WrongApiForClose);
248 }
249 if connection.conn_type() == ConnectionType::ReadWrite {
250 let mut guard = self.write_connection.lock();
252 assert!((*guard).is_none());
253 *guard = Some(connection);
254 }
255 Ok(())
256 }
257
258 pub fn register_with_sync_manager(self: Arc<Self>) {
264 *PLACES_API_FOR_SYNC_MANAGER.lock() = Arc::downgrade(&self);
265 }
266}
267
#[cfg(test)]
pub mod test {
    //! Helpers that build in-memory `PlacesApi` instances and connections for tests.
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Suffix counter so each `new_mem_api` call names a distinct in-memory DB.
    static ATOMIC_COUNTER: AtomicUsize = AtomicUsize::new(0);

    /// Returns a `PlacesApi` backed by a freshly named in-memory database.
    pub fn new_mem_api() -> Arc<PlacesApi> {
        error_support::init_for_tests();

        let suffix = ATOMIC_COUNTER.fetch_add(1, Ordering::Relaxed);
        let db_name = format!("test-api-{}", suffix);
        PlacesApi::new_memory(&db_name).expect("should get an API")
    }

    /// Returns the read-write connection of a fresh in-memory `PlacesApi`.
    pub fn new_mem_connection() -> PlacesDb {
        let api = new_mem_api();
        api.open_connection(ConnectionType::ReadWrite)
            .expect("should get a connection")
    }

    /// A read connection, the write connection, and the API that owns them.
    pub struct MemConnections {
        pub read: PlacesDb,
        pub write: PlacesDb,
        pub api: Arc<PlacesApi>,
    }

    /// Opens a reader and then the writer on one fresh in-memory `PlacesApi`.
    pub fn new_mem_connections() -> MemConnections {
        let api = new_mem_api();
        let read = api
            .open_connection(ConnectionType::ReadOnly)
            .expect("should get a read connection");
        let write = api
            .open_connection(ConnectionType::ReadWrite)
            .expect("should get a write connection");
        MemConnections { api, read, write }
    }
}
308
#[cfg(test)]
mod tests {
    use super::test::*;
    use super::*;
    use sql_support::ConnExt;

    #[test]
    fn test_multi_writers_fails() {
        let api = new_mem_api();
        let first_writer = api
            .open_connection(ConnectionType::ReadWrite)
            .expect("should get writer");
        // Only one writer can be checked out at a time.
        api.open_connection(ConnectionType::ReadWrite)
            .expect_err("should fail to get second writer");
        // Handing it back makes the writer available again.
        api.close_connection(first_writer)
            .expect("should be able to close");
        api.open_connection(ConnectionType::ReadWrite)
            .expect("should get a writer after closing the other");
    }

    #[test]
    fn test_shared_memory() {
        let api = new_mem_api();
        let writer = api
            .open_connection(ConnectionType::ReadWrite)
            .expect("should get writer");
        writer
            .execute_batch(
                "CREATE TABLE test_table (test_value INTEGER);
                 INSERT INTO test_table VALUES (999)",
            )
            .expect("should insert");
        // A reader opened afterwards sees the writer's data.
        let reader = api
            .open_connection(ConnectionType::ReadOnly)
            .expect("should get reader");
        let stored: i64 = reader
            .conn_ext_query_one("SELECT test_value FROM test_table")
            .expect("should get value");
        assert_eq!(stored, 999);
    }

    #[test]
    fn test_reader_before_writer() {
        let api = new_mem_api();
        // The reader exists before any data is written.
        let reader = api
            .open_connection(ConnectionType::ReadOnly)
            .expect("should get reader");
        let writer = api
            .open_connection(ConnectionType::ReadWrite)
            .expect("should get writer");
        writer
            .execute_batch(
                "CREATE TABLE test_table (test_value INTEGER);
                 INSERT INTO test_table VALUES (999)",
            )
            .expect("should insert");
        let stored: i64 = reader
            .conn_ext_query_one("SELECT test_value FROM test_table")
            .expect("should get value");
        assert_eq!(stored, 999);
    }

    #[test]
    fn test_wrong_writer_close() {
        let api = new_mem_api();
        // Keep the real writer checked out for the whole test.
        let _writer = api
            .open_connection(ConnectionType::ReadWrite)
            .expect("should get writer");

        let other_api = new_mem_api();
        let foreign_writer = other_api
            .open_connection(ConnectionType::ReadWrite)
            .expect("should get writer 2");

        let err = api.close_connection(foreign_writer).unwrap_err();
        assert!(matches!(err, Error::WrongApiForClose));
    }

    #[test]
    fn test_valid_writer_close() {
        let api = new_mem_api();
        let writer = api
            .open_connection(ConnectionType::ReadWrite)
            .expect("should get writer");

        api.close_connection(writer)
            .expect("Should allow closing own connection");

        let reopened = api.open_connection(ConnectionType::ReadWrite);
        assert!(reopened.is_ok());
    }
}