1use std::{
33 borrow::Cow,
34 path::Path,
35 sync::atomic::{AtomicUsize, Ordering},
36};
37
38use rusqlite::{
39 Connection, Error as RusqliteError, ErrorCode, OpenFlags, Transaction, TransactionBehavior,
40};
41use thiserror::Error;
42
43use crate::ConnExt;
44use crate::{debug, info, warn};
45
46#[derive(Error, Debug)]
47pub enum Error {
48 #[error("Incompatible database version: {0}")]
49 IncompatibleVersion(u32),
50 #[error("Database is corrupt")]
51 Corrupt,
52 #[error("Error executing SQL: {0}")]
53 SqlError(rusqlite::Error),
54 #[error("Failed to recover a corrupt database due to an error deleting the file: {0}")]
55 RecoveryError(std::io::Error),
56 #[error("In shutdown mode")]
57 Shutdown,
58}
59
60impl From<rusqlite::Error> for Error {
61 fn from(value: rusqlite::Error) -> Self {
62 match value {
63 RusqliteError::SqliteFailure(e, _)
64 if matches!(e.code, ErrorCode::DatabaseCorrupt | ErrorCode::NotADatabase) =>
65 {
66 Self::Corrupt
67 }
68 _ => Self::SqlError(value),
69 }
70 }
71}
72
73pub type Result<T> = std::result::Result<T, Error>;
74
75pub trait ConnectionInitializer {
76 const NAME: &'static str;
78
79 const END_VERSION: u32;
81
82 fn init(&self, tx: &Transaction<'_>) -> Result<()>;
85
86 fn upgrade_from(&self, conn: &Transaction<'_>, version: u32) -> Result<()>;
88
89 fn prepare(&self, _conn: &Connection, _db_empty: bool) -> Result<()> {
92 Ok(())
93 }
94
95 fn finish(&self, _conn: &Connection) -> Result<()> {
99 Ok(())
100 }
101}
102
103pub fn open_database<CI: ConnectionInitializer, P: AsRef<Path>>(
104 path: P,
105 connection_initializer: &CI,
106) -> Result<Connection> {
107 open_database_with_flags(path, OpenFlags::default(), connection_initializer)
108}
109
110pub fn open_memory_database<CI: ConnectionInitializer>(
111 conn_initializer: &CI,
112) -> Result<Connection> {
113 open_memory_database_with_flags(OpenFlags::default(), conn_initializer)
114}
115
116pub fn open_database_with_flags<CI: ConnectionInitializer, P: AsRef<Path>>(
117 path: P,
118 open_flags: OpenFlags,
119 connection_initializer: &CI,
120) -> Result<Connection> {
121 do_open_database_with_flags(&path, open_flags, connection_initializer).or_else(|e| {
122 try_handle_db_failure(&path, open_flags, connection_initializer, e)?;
124 do_open_database_with_flags(&path, open_flags, connection_initializer)
125 })
126}
127
128pub fn read_write_flags() -> OpenFlags {
130 OpenFlags::SQLITE_OPEN_URI
131 | OpenFlags::SQLITE_OPEN_NO_MUTEX
132 | OpenFlags::SQLITE_OPEN_CREATE
133 | OpenFlags::SQLITE_OPEN_READ_WRITE
134}
135
136pub fn read_only_flags() -> OpenFlags {
138 OpenFlags::SQLITE_OPEN_URI | OpenFlags::SQLITE_OPEN_NO_MUTEX | OpenFlags::SQLITE_OPEN_READ_ONLY
139}
140
141fn do_open_database_with_flags<CI: ConnectionInitializer, P: AsRef<Path>>(
142 path: P,
143 open_flags: OpenFlags,
144 connection_initializer: &CI,
145) -> Result<Connection> {
146 debug!("{}: opening database", CI::NAME);
148 let mut conn = Connection::open_with_flags(path, open_flags)?;
149 debug!("{}: checking if initialization is necessary", CI::NAME);
150 let db_empty = is_db_empty(&conn)?;
151
152 debug!("{}: preparing", CI::NAME);
153 connection_initializer.prepare(&conn, db_empty)?;
154
155 if open_flags.contains(OpenFlags::SQLITE_OPEN_READ_WRITE) {
156 let tx = conn.transaction_with_behavior(TransactionBehavior::Immediate)?;
157 if db_empty {
158 debug!("{}: initializing new database", CI::NAME);
159 connection_initializer.init(&tx)?;
160 } else {
161 let mut current_version = get_schema_version(&tx)?;
162 if current_version > CI::END_VERSION {
163 return Err(Error::IncompatibleVersion(current_version));
164 }
165 while current_version < CI::END_VERSION {
166 debug!(
167 "{}: upgrading database to {}",
168 CI::NAME,
169 current_version + 1
170 );
171 connection_initializer.upgrade_from(&tx, current_version)?;
172 current_version += 1;
173 }
174 }
175 debug!("{}: finishing writable database open", CI::NAME);
176 connection_initializer.finish(&tx)?;
177 set_schema_version(&tx, CI::END_VERSION)?;
178 tx.commit()?;
179 } else {
180 assert!(!db_empty, "existing writer must have initialized");
183 assert!(
184 get_schema_version(&conn)? == CI::END_VERSION,
185 "existing writer must have migrated"
186 );
187 debug!("{}: finishing readonly database open", CI::NAME);
188 connection_initializer.finish(&conn)?;
189 }
190 debug!("{}: database open successful", CI::NAME);
191 Ok(conn)
192}
193
194pub fn open_memory_database_with_flags<CI: ConnectionInitializer>(
195 flags: OpenFlags,
196 conn_initializer: &CI,
197) -> Result<Connection> {
198 open_database_with_flags(":memory:", flags, conn_initializer)
199}
200
201fn try_handle_db_failure<CI: ConnectionInitializer, P: AsRef<Path>>(
207 path: P,
208 open_flags: OpenFlags,
209 _connection_initializer: &CI,
210 err: Error,
211) -> Result<()> {
212 if !open_flags.contains(OpenFlags::SQLITE_OPEN_CREATE)
213 && matches!(err, Error::SqlError(rusqlite::Error::SqliteFailure(code, _)) if code.code == rusqlite::ErrorCode::CannotOpen)
214 {
215 info!(
216 "{}: database doesn't exist, but we weren't requested to create it",
217 CI::NAME
218 );
219 return Err(err);
220 }
221 warn!("{}: database operation failed: {}", CI::NAME, err);
222 if !open_flags.contains(OpenFlags::SQLITE_OPEN_READ_WRITE) {
223 warn!(
224 "{}: not attempting recovery as this is a read-only connection request",
225 CI::NAME
226 );
227 return Err(err);
228 }
229
230 let delete = matches!(err, Error::Corrupt);
231 if delete {
232 info!(
233 "{}: the database is fatally damaged; deleting and starting fresh",
234 CI::NAME
235 );
236 if let Err(io_err) = std::fs::remove_file(path) {
240 return Err(Error::RecoveryError(io_err));
241 }
242 Ok(())
243 } else {
244 Err(err)
245 }
246}
247
248fn is_db_empty(conn: &Connection) -> Result<bool> {
249 Ok(conn.query_one::<u32>("SELECT COUNT(*) FROM sqlite_master")? == 0)
250}
251
252fn get_schema_version(conn: &Connection) -> Result<u32> {
253 let version = conn.query_row_and_then("PRAGMA user_version", [], |row| row.get(0))?;
254 Ok(version)
255}
256
257fn set_schema_version(conn: &Connection, version: u32) -> Result<()> {
258 conn.set_pragma("user_version", version)?;
259 Ok(())
260}
261
/// Returns a fresh shared-cache in-memory database URI on every call.
///
/// A process-wide counter is embedded in the name, so two calls within the same process
/// never return the same path.
pub fn unique_in_memory_db_path() -> String {
    static COUNTER: AtomicUsize = AtomicUsize::new(0);
    let id = COUNTER.fetch_add(1, Ordering::Relaxed);
    format!("file:in-memory-db-{id}?mode=memory&cache=shared")
}
272
273pub mod test_utils {
276 use super::*;
277 use std::{
278 cell::RefCell,
279 collections::{HashMap, HashSet},
280 path::PathBuf,
281 };
282 use tempfile::TempDir;
283
284 pub struct TestConnectionInitializer {
285 pub calls: RefCell<Vec<&'static str>>,
286 pub buggy_v3_upgrade: bool,
287 }
288
289 impl Default for TestConnectionInitializer {
290 fn default() -> Self {
291 Self::new()
292 }
293 }
294
295 impl TestConnectionInitializer {
296 pub fn new() -> Self {
297 Self {
298 calls: RefCell::new(Vec::new()),
299 buggy_v3_upgrade: false,
300 }
301 }
302 pub fn new_with_buggy_logic() -> Self {
303 Self {
304 calls: RefCell::new(Vec::new()),
305 buggy_v3_upgrade: true,
306 }
307 }
308
309 pub fn clear_calls(&self) {
310 self.calls.borrow_mut().clear();
311 }
312
313 pub fn push_call(&self, call: &'static str) {
314 self.calls.borrow_mut().push(call);
315 }
316
317 pub fn check_calls(&self, expected: Vec<&'static str>) {
318 assert_eq!(*self.calls.borrow(), expected);
319 }
320 }
321
322 impl ConnectionInitializer for TestConnectionInitializer {
323 const NAME: &'static str = "test db";
324 const END_VERSION: u32 = 4;
325
326 fn prepare(&self, conn: &Connection, _: bool) -> Result<()> {
327 self.push_call("prep");
328 conn.execute_batch(
329 "
330 PRAGMA journal_mode = wal;
331 ",
332 )?;
333 Ok(())
334 }
335
336 fn init(&self, conn: &Transaction<'_>) -> Result<()> {
337 self.push_call("init");
338 conn.execute_batch(
339 "
340 CREATE TABLE prep_table(col);
341 INSERT INTO prep_table(col) VALUES ('correct-value');
342 CREATE TABLE my_table(col);
343 ",
344 )
345 .map_err(|e| e.into())
346 }
347
348 fn upgrade_from(&self, conn: &Transaction<'_>, version: u32) -> Result<()> {
349 match version {
350 1 => {
353 self.push_call("upgrade_from_v1");
354 Err(Error::Corrupt)
355 }
356 2 => {
357 self.push_call("upgrade_from_v2");
358 conn.execute_batch(
359 "
360 ALTER TABLE my_old_table_name RENAME TO my_table;
361 ",
362 )?;
363 Ok(())
364 }
365 3 => {
366 self.push_call("upgrade_from_v3");
367
368 if self.buggy_v3_upgrade {
369 conn.execute_batch("ILLEGAL_SQL_CODE")?;
370 }
371
372 conn.execute_batch(
373 "
374 ALTER TABLE my_table RENAME COLUMN old_col to col;
375 ",
376 )?;
377 Ok(())
378 }
379 _ => {
380 panic!("Unexpected version: {}", version);
381 }
382 }
383 }
384
385 fn finish(&self, conn: &Connection) -> Result<()> {
386 self.push_call("finish");
387 conn.execute_batch(
388 "
389 INSERT INTO my_table(col) SELECT col FROM prep_table;
390 ",
391 )?;
392 Ok(())
393 }
394 }
395
396 pub struct MigratedDatabaseFile<CI: ConnectionInitializer> {
402 _tempdir: TempDir,
405 pub connection_initializer: CI,
406 pub path: PathBuf,
407 }
408
409 impl<CI: ConnectionInitializer> MigratedDatabaseFile<CI> {
410 pub fn new(connection_initializer: CI, init_sql: &str) -> Self {
411 Self::new_with_flags(connection_initializer, init_sql, OpenFlags::default())
412 }
413
414 pub fn new_with_flags(
415 connection_initializer: CI,
416 init_sql: &str,
417 open_flags: OpenFlags,
418 ) -> Self {
419 let tempdir = tempfile::tempdir().unwrap();
420 let path = tempdir.path().join(Path::new("db.sql"));
421 let conn = Connection::open_with_flags(&path, open_flags).unwrap();
422 conn.execute_batch(init_sql).unwrap();
423 Self {
424 _tempdir: tempdir,
425 connection_initializer,
426 path,
427 }
428 }
429
430 pub fn upgrade_to(&self, version: u32) {
434 let mut conn = self.open();
435 let tx = conn.transaction().unwrap();
436 let mut current_version = get_schema_version(&tx).unwrap();
437 while current_version < version {
438 self.connection_initializer
439 .upgrade_from(&tx, current_version)
440 .unwrap();
441 current_version += 1;
442 }
443 set_schema_version(&tx, current_version).unwrap();
444 self.connection_initializer.finish(&tx).unwrap();
445 tx.commit().unwrap();
446 }
447
448 pub fn run_all_upgrades(&self) {
452 let current_version = get_schema_version(&self.open()).unwrap();
453 for version in current_version..CI::END_VERSION {
454 self.upgrade_to(version + 1);
455 }
456 }
457
458 pub fn assert_schema_matches_new_database(&self) {
459 let db = self.open();
460 let new_db = match open_memory_database(&self.connection_initializer) {
461 Ok(db) => db,
462 Err(e) => panic!("Creating new database failed:\n{e}"),
463 };
464
465 compare_sql_maps("table", get_sql(&db, "table"), get_sql(&new_db, "table"));
466 compare_sql_maps("index", get_sql(&db, "index"), get_sql(&new_db, "index"));
467 compare_sql_maps(
468 "trigger",
469 get_sql(&db, "trigger"),
470 get_sql(&new_db, "trigger"),
471 );
472 }
473
474 pub fn open(&self) -> Connection {
475 Connection::open(&self.path).unwrap()
476 }
477 }
478
479 fn get_sql(conn: &Connection, type_: &str) -> HashMap<String, Option<String>> {
480 conn.query_rows_and_then(
481 "SELECT name, sql FROM sqlite_master WHERE type=?",
482 (type_,),
483 |row| -> rusqlite::Result<(String, Option<String>)> { Ok((row.get(0)?, row.get(1)?)) },
484 )
485 .unwrap()
486 .into_iter()
487 .collect()
488 }
489
490 fn compare_sql_maps(
491 type_: &str,
492 old_items: HashMap<String, Option<String>>,
493 new_items: HashMap<String, Option<String>>,
494 ) {
495 let old_db_keys: HashSet<&String> = old_items.keys().collect();
496 let new_db_keys: HashSet<&String> = new_items.keys().collect();
497
498 let old_db_extra_keys = Vec::from_iter(old_db_keys.difference(&new_db_keys));
499 if !old_db_extra_keys.is_empty() {
500 panic!("Extra keys not present in new database for {type_}: {old_db_extra_keys:?}");
501 }
502 let new_db_extra_keys = Vec::from_iter(new_db_keys.difference(&old_db_keys));
503 if !new_db_extra_keys.is_empty() {
504 panic!("Extra keys only present in new database for {type_}: {new_db_extra_keys:?}");
505 }
506 for key in old_db_keys {
507 assert_eq!(
508 old_items.get(key).unwrap().as_deref().map(normalize),
509 new_items.get(key).unwrap().as_deref().map(normalize),
510 "sql differs for {type_} {key}"
511 );
512 }
513 }
514
515 fn normalize(sql: &str) -> String {
517 sql.split('\'')
518 .enumerate()
519 .map(|(i, part)| {
520 if (i % 2) == 0 {
525 Cow::Owned(part.split_whitespace().collect::<Vec<_>>().join(" "))
526 } else {
527 Cow::Borrowed(part)
528 }
529 })
530 .collect()
531 }
532}
533
#[cfg(test)]
mod test {
    use super::test_utils::{MigratedDatabaseFile, TestConnectionInitializer};
    use super::*;
    use std::io::Write;

    // Version 1 schema. `TestConnectionInitializer::upgrade_from` reports corruption for this
    // version, which makes the open code replace the database.
    static INIT_V1: &str = "
        CREATE TABLE prep_table(col);
        PRAGMA user_version=1;
    ";

    // Version 2 schema, with the table and column names that the later upgrades rename.
    static INIT_V2: &str = "
        CREATE TABLE prep_table(col);
        INSERT INTO prep_table(col) VALUES ('correct-value');
        CREATE TABLE my_old_table_name(old_col);
        PRAGMA user_version=2;
    ";

    type TestDbFile = MigratedDatabaseFile<TestConnectionInitializer>;

    // Checks that `finish` copied the expected value and that the schema is at version 4.
    fn check_final_data(conn: &Connection) {
        let stored: String = conn
            .query_row("SELECT col FROM my_table", [], |r| r.get(0))
            .unwrap();
        assert_eq!(stored, "correct-value");
        let schema_version = get_schema_version(conn).unwrap();
        assert_eq!(schema_version, 4);
    }

    // Adds a row to the old table that must still exist after a failed open.
    fn insert_row_into_old_table(db_file: &TestDbFile) {
        db_file
            .open()
            .execute(
                "INSERT INTO my_old_table_name(old_col) VALUES ('I should not be deleted')",
                [],
            )
            .unwrap();
    }

    // Counts the rows of the old table through a fresh connection.
    fn old_table_row_count(db_file: &TestDbFile) -> i32 {
        db_file
            .open()
            .query_one::<i32>("SELECT COUNT(*) FROM my_old_table_name")
            .unwrap()
    }

    #[test]
    fn test_init() {
        let initializer = TestConnectionInitializer::new();
        let conn = open_memory_database(&initializer).unwrap();
        check_final_data(&conn);
        initializer.check_calls(vec!["prep", "init", "finish"]);
    }

    #[test]
    fn test_upgrades() {
        let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
        let conn = open_database(&db_file.path, &db_file.connection_initializer).unwrap();
        check_final_data(&conn);
        let expected_calls = vec!["prep", "upgrade_from_v2", "upgrade_from_v3", "finish"];
        db_file.connection_initializer.check_calls(expected_calls);
    }

    #[test]
    fn test_open_current_version() {
        let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
        db_file.upgrade_to(4);
        // Only the calls made by `open_database` below are of interest.
        db_file.connection_initializer.clear_calls();
        let conn = open_database(&db_file.path, &db_file.connection_initializer).unwrap();
        check_final_data(&conn);
        let expected_calls = vec!["prep", "finish"];
        db_file.connection_initializer.check_calls(expected_calls);
    }

    #[test]
    fn test_pragmas() {
        let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
        let conn = open_database(&db_file.path, &db_file.connection_initializer).unwrap();
        let journal_mode = conn.query_one::<String>("PRAGMA journal_mode").unwrap();
        assert_eq!(journal_mode, "wal");
    }

    #[test]
    fn test_migration_error() {
        let db_file =
            MigratedDatabaseFile::new(TestConnectionInitializer::new_with_buggy_logic(), INIT_V2);
        insert_row_into_old_table(&db_file);

        open_database(&db_file.path, &db_file.connection_initializer).unwrap_err();
        // The failed migration must leave the existing data alone.
        assert_eq!(old_table_row_count(&db_file), 1);
    }

    #[test]
    fn test_version_too_new() {
        let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
        set_schema_version(&db_file.open(), 5).unwrap();
        insert_row_into_old_table(&db_file);

        let outcome = open_database(&db_file.path, &db_file.connection_initializer);
        assert!(matches!(outcome, Err(Error::IncompatibleVersion(5))));
        // A database that is too new must be left alone.
        assert_eq!(old_table_row_count(&db_file), 1);
    }

    #[test]
    fn test_corrupt_db() {
        let tempdir = tempfile::tempdir().unwrap();
        let path = tempdir.path().join("corrupt-db.sql");
        {
            let mut file = std::fs::File::create(&path).unwrap();
            file.write_all(b"not sql").unwrap();
            // Make sure the garbage really is on disk before the open is attempted.
            let size_before = std::fs::metadata(&path).unwrap().len();
            assert_eq!(size_before, 7);
        }
        open_database(&path, &TestConnectionInitializer::new()).unwrap();
        // A different size shows that the garbage file was replaced.
        let size_after = std::fs::metadata(&path).unwrap().len();
        assert_ne!(size_after, 7);
    }

    #[test]
    fn test_force_replace() {
        let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V1);
        let conn = open_database(&db_file.path, &db_file.connection_initializer).unwrap();
        check_final_data(&conn);
        // The first attempt fails in the v1 upgrade; the second starts from an empty file.
        let expected_calls = vec!["prep", "upgrade_from_v1", "prep", "init", "finish"];
        db_file.connection_initializer.check_calls(expected_calls);
    }
}