1use std::{
33 borrow::Cow,
34 path::Path,
35 sync::atomic::{AtomicUsize, Ordering},
36};
37
38use rusqlite::{
39 Connection, Error as RusqliteError, ErrorCode, OpenFlags, Transaction, TransactionBehavior,
40};
41use thiserror::Error;
42
43use crate::ConnExt;
44use crate::{debug, info, warn};
45
/// Errors returned while opening, initializing or migrating a database.
#[derive(Error, Debug)]
pub enum Error {
    /// The stored `user_version` is newer than this code's `END_VERSION`.
    #[error("Incompatible database version: {0}")]
    IncompatibleVersion(u32),
    /// SQLite reported the file as corrupt or as not being a database. For read-write opens
    /// this makes the recovery path delete the file and retry.
    #[error("Database is corrupt")]
    Corrupt,
    /// Any other SQLite error.
    #[error("Error executing SQL: {0}")]
    SqlError(rusqlite::Error),
    /// Deleting the corrupt database file during recovery failed.
    #[error("Failed to recover a corrupt database due to an error deleting the file: {0}")]
    RecoveryError(std::io::Error),
    /// NOTE(review): never constructed in this section; presumably used by callers that refuse
    /// work while shutting down — confirm.
    #[error("In shutdown mode")]
    Shutdown,
}
59
60impl From<rusqlite::Error> for Error {
61 fn from(value: rusqlite::Error) -> Self {
62 match value {
63 RusqliteError::SqliteFailure(e, _)
64 if matches!(e.code, ErrorCode::DatabaseCorrupt | ErrorCode::NotADatabase) =>
65 {
66 Self::Corrupt
67 }
68 _ => Self::SqlError(value),
69 }
70 }
71}
72
/// Result type used throughout this module; the error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
74
/// Hooks that [`open_database`] and related functions use to create, migrate and configure a
/// database.
pub trait ConnectionInitializer {
    /// Short human-readable name of the database; used as the prefix of every log message.
    const NAME: &'static str;

    /// The schema version this code expects. After a successful read-write open, the database's
    /// `PRAGMA user_version` equals this value.
    const END_VERSION: u32;

    /// Creates the schema for a brand-new database.
    ///
    /// Called inside an IMMEDIATE transaction, and only when `sqlite_master` is empty. After
    /// `finish` runs, the caller stores `END_VERSION` as the `user_version`.
    fn init(&self, tx: &Transaction<'_>) -> Result<()>;

    /// Migrates the schema from `version` to `version + 1`.
    ///
    /// Called once per step, inside the same IMMEDIATE transaction, for every version from the
    /// stored `user_version` up to `END_VERSION - 1`. An error aborts the open and rolls the
    /// transaction back. When opening through [`open_database_with_flags`] with read-write flags,
    /// an [`Error::Corrupt`] additionally causes the database file to be deleted and the open to
    /// be retried once.
    fn upgrade_from(&self, conn: &Transaction<'_>, version: u32) -> Result<()>;

    /// Runs right after the connection is opened, before any transaction is started.
    ///
    /// `db_empty` is true when the database has no schema objects yet. This is a place for
    /// per-connection setup such as pragmas. The default does nothing.
    fn prepare(&self, _conn: &Connection, _db_empty: bool) -> Result<()> {
        Ok(())
    }

    /// Runs at the end of every successful open attempt.
    ///
    /// For read-write connections it runs inside the setup transaction, after `init` or
    /// `upgrade_from`. For read-only connections it runs on the plain connection. The default
    /// does nothing.
    fn finish(&self, _conn: &Connection) -> Result<()> {
        Ok(())
    }
}
102
103pub fn open_database<CI: ConnectionInitializer, P: AsRef<Path>>(
104 path: P,
105 connection_initializer: &CI,
106) -> Result<Connection> {
107 open_database_with_flags(path, OpenFlags::default(), connection_initializer)
108}
109
110pub fn open_memory_database<CI: ConnectionInitializer>(
111 conn_initializer: &CI,
112) -> Result<Connection> {
113 open_memory_database_with_flags(OpenFlags::default(), conn_initializer)
114}
115
116pub fn open_database_with_flags<CI: ConnectionInitializer, P: AsRef<Path>>(
117 path: P,
118 open_flags: OpenFlags,
119 connection_initializer: &CI,
120) -> Result<Connection> {
121 do_open_database_with_flags(&path, open_flags, connection_initializer).or_else(|e| {
122 try_handle_db_failure(&path, open_flags, connection_initializer, e)?;
124 do_open_database_with_flags(&path, open_flags, connection_initializer)
125 })
126}
127
128pub fn read_write_flags() -> OpenFlags {
130 OpenFlags::SQLITE_OPEN_URI
131 | OpenFlags::SQLITE_OPEN_NO_MUTEX
132 | OpenFlags::SQLITE_OPEN_CREATE
133 | OpenFlags::SQLITE_OPEN_READ_WRITE
134}
135
136pub fn read_only_flags() -> OpenFlags {
138 OpenFlags::SQLITE_OPEN_URI | OpenFlags::SQLITE_OPEN_NO_MUTEX | OpenFlags::SQLITE_OPEN_READ_ONLY
139}
140
141fn do_open_database_with_flags<CI: ConnectionInitializer, P: AsRef<Path>>(
142 path: P,
143 open_flags: OpenFlags,
144 connection_initializer: &CI,
145) -> Result<Connection> {
146 debug!("{}: opening database", CI::NAME);
148 let mut conn = Connection::open_with_flags(path, open_flags)?;
149 debug!("{}: checking if initialization is necessary", CI::NAME);
150 let db_empty = is_db_empty(&conn)?;
151
152 debug!("{}: preparing", CI::NAME);
153 connection_initializer.prepare(&conn, db_empty)?;
154
155 if open_flags.contains(OpenFlags::SQLITE_OPEN_READ_WRITE) {
156 let mut write_schema_version = true;
157 let tx = conn.transaction_with_behavior(TransactionBehavior::Immediate)?;
158 if db_empty {
159 debug!("{}: initializing new database", CI::NAME);
160 connection_initializer.init(&tx)?;
161 } else {
162 let mut current_version = get_schema_version(&tx)?;
163 if current_version > CI::END_VERSION {
164 return Err(Error::IncompatibleVersion(current_version));
165 } else if current_version == CI::END_VERSION {
166 write_schema_version = false;
167 } else {
168 while current_version < CI::END_VERSION {
169 debug!(
170 "{}: upgrading database to {}",
171 CI::NAME,
172 current_version + 1
173 );
174 connection_initializer.upgrade_from(&tx, current_version)?;
175 current_version += 1;
176 }
177 }
178 }
179 debug!("{}: finishing writable database open", CI::NAME);
180 connection_initializer.finish(&tx)?;
181 if write_schema_version {
182 set_schema_version(&tx, CI::END_VERSION)?;
183 }
184 tx.commit()?;
185 } else {
186 assert!(!db_empty, "existing writer must have initialized");
189 assert!(
190 get_schema_version(&conn)? == CI::END_VERSION,
191 "existing writer must have migrated"
192 );
193 debug!("{}: finishing readonly database open", CI::NAME);
194 connection_initializer.finish(&conn)?;
195 }
196 debug!("{}: database open successful", CI::NAME);
197 Ok(conn)
198}
199
200pub fn open_memory_database_with_flags<CI: ConnectionInitializer>(
201 flags: OpenFlags,
202 conn_initializer: &CI,
203) -> Result<Connection> {
204 open_database_with_flags(":memory:", flags, conn_initializer)
205}
206
207fn try_handle_db_failure<CI: ConnectionInitializer, P: AsRef<Path>>(
213 path: P,
214 open_flags: OpenFlags,
215 _connection_initializer: &CI,
216 err: Error,
217) -> Result<()> {
218 if !open_flags.contains(OpenFlags::SQLITE_OPEN_CREATE)
219 && matches!(err, Error::SqlError(rusqlite::Error::SqliteFailure(code, _)) if code.code == rusqlite::ErrorCode::CannotOpen)
220 {
221 info!(
222 "{}: database doesn't exist, but we weren't requested to create it",
223 CI::NAME
224 );
225 return Err(err);
226 }
227 warn!("{}: database operation failed: {}", CI::NAME, err);
228 if !open_flags.contains(OpenFlags::SQLITE_OPEN_READ_WRITE) {
229 warn!(
230 "{}: not attempting recovery as this is a read-only connection request",
231 CI::NAME
232 );
233 return Err(err);
234 }
235
236 let delete = matches!(err, Error::Corrupt);
237 if delete {
238 info!(
239 "{}: the database is fatally damaged; deleting and starting fresh",
240 CI::NAME
241 );
242 if let Err(io_err) = std::fs::remove_file(path) {
246 return Err(Error::RecoveryError(io_err));
247 }
248 Ok(())
249 } else {
250 Err(err)
251 }
252}
253
254fn is_db_empty(conn: &Connection) -> Result<bool> {
255 Ok(conn.conn_ext_query_one::<u32>("SELECT COUNT(*) FROM sqlite_master")? == 0)
256}
257
258fn get_schema_version(conn: &Connection) -> Result<u32> {
259 let version = conn.query_row_and_then("PRAGMA user_version", [], |row| row.get(0))?;
260 Ok(version)
261}
262
263fn set_schema_version(conn: &Connection, version: u32) -> Result<()> {
264 conn.set_pragma("user_version", version)?;
265 Ok(())
266}
267
/// Returns a new `file:` URI naming a shared-cache, in-memory SQLite database.
///
/// A process-wide counter makes every call's name distinct. Opening the same returned URI
/// several times (with URI filenames enabled) therefore reaches one shared database, while
/// separate calls yield isolated databases.
pub fn unique_in_memory_db_path() -> String {
    static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
    let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
    format!("file:in-memory-db-{}?mode=memory&cache=shared", id)
}
278
279pub mod test_utils {
282 use super::*;
283 use std::{
284 cell::RefCell,
285 collections::{HashMap, HashSet},
286 path::PathBuf,
287 };
288 use tempfile::TempDir;
289
/// A [`ConnectionInitializer`] for tests. It records the name of each hook it runs, so a test
/// can assert on the exact sequence of hooks.
pub struct TestConnectionInitializer {
    /// Hook names recorded so far, in call order (e.g. `"prep"`, `"init"`, `"finish"`).
    pub calls: RefCell<Vec<&'static str>>,
    /// When true, the upgrade from version 3 executes invalid SQL and therefore fails.
    pub buggy_v3_upgrade: bool,
}
294
295 impl Default for TestConnectionInitializer {
296 fn default() -> Self {
297 Self::new()
298 }
299 }
300
301 impl TestConnectionInitializer {
302 pub fn new() -> Self {
303 Self {
304 calls: RefCell::new(Vec::new()),
305 buggy_v3_upgrade: false,
306 }
307 }
308 pub fn new_with_buggy_logic() -> Self {
309 Self {
310 calls: RefCell::new(Vec::new()),
311 buggy_v3_upgrade: true,
312 }
313 }
314
315 pub fn clear_calls(&self) {
316 self.calls.borrow_mut().clear();
317 }
318
319 pub fn push_call(&self, call: &'static str) {
320 self.calls.borrow_mut().push(call);
321 }
322
323 pub fn check_calls(&self, expected: Vec<&'static str>) {
324 assert_eq!(*self.calls.borrow(), expected);
325 }
326 }
327
328 impl ConnectionInitializer for TestConnectionInitializer {
329 const NAME: &'static str = "test db";
330 const END_VERSION: u32 = 4;
331
332 fn prepare(&self, conn: &Connection, _: bool) -> Result<()> {
333 self.push_call("prep");
334 conn.execute_batch(
335 "
336 PRAGMA journal_mode = wal;
337 ",
338 )?;
339 Ok(())
340 }
341
342 fn init(&self, conn: &Transaction<'_>) -> Result<()> {
343 self.push_call("init");
344 conn.execute_batch(
345 "
346 CREATE TABLE prep_table(col);
347 INSERT INTO prep_table(col) VALUES ('correct-value');
348 CREATE TABLE my_table(col);
349 ",
350 )
351 .map_err(|e| e.into())
352 }
353
354 fn upgrade_from(&self, conn: &Transaction<'_>, version: u32) -> Result<()> {
355 match version {
356 1 => {
359 self.push_call("upgrade_from_v1");
360 Err(Error::Corrupt)
361 }
362 2 => {
363 self.push_call("upgrade_from_v2");
364 conn.execute_batch(
365 "
366 ALTER TABLE my_old_table_name RENAME TO my_table;
367 ",
368 )?;
369 Ok(())
370 }
371 3 => {
372 self.push_call("upgrade_from_v3");
373
374 if self.buggy_v3_upgrade {
375 conn.execute_batch("ILLEGAL_SQL_CODE")?;
376 }
377
378 conn.execute_batch(
379 "
380 ALTER TABLE my_table RENAME COLUMN old_col to col;
381 ",
382 )?;
383 Ok(())
384 }
385 _ => {
386 panic!("Unexpected version: {}", version);
387 }
388 }
389 }
390
391 fn finish(&self, conn: &Connection) -> Result<()> {
392 self.push_call("finish");
393 conn.execute_batch(
394 "
395 INSERT INTO my_table(col) SELECT col FROM prep_table;
396 ",
397 )?;
398 Ok(())
399 }
400 }
401
/// A database file in a fresh temporary directory, pre-populated by caller-supplied SQL. Used to
/// test opening and migrating databases that start at older schema versions.
pub struct MigratedDatabaseFile<CI: ConnectionInitializer> {
    // Held only so the temporary directory, and the database file inside it, lives as long as
    // this struct.
    _tempdir: TempDir,
    /// The initializer whose hooks and migrations are exercised.
    pub connection_initializer: CI,
    /// Path of the database file (`db.sql` inside the temporary directory).
    pub path: PathBuf,
}
414
415 impl<CI: ConnectionInitializer> MigratedDatabaseFile<CI> {
416 pub fn new(connection_initializer: CI, init_sql: &str) -> Self {
417 Self::new_with_flags(connection_initializer, init_sql, OpenFlags::default())
418 }
419
420 pub fn new_with_flags(
421 connection_initializer: CI,
422 init_sql: &str,
423 open_flags: OpenFlags,
424 ) -> Self {
425 let tempdir = tempfile::tempdir().unwrap();
426 let path = tempdir.path().join(Path::new("db.sql"));
427 let conn = Connection::open_with_flags(&path, open_flags).unwrap();
428 conn.execute_batch(init_sql).unwrap();
429 Self {
430 _tempdir: tempdir,
431 connection_initializer,
432 path,
433 }
434 }
435
436 pub fn upgrade_to(&self, version: u32) {
440 let mut conn = self.open();
441 let tx = conn.transaction().unwrap();
442 let mut current_version = get_schema_version(&tx).unwrap();
443 while current_version < version {
444 self.connection_initializer
445 .upgrade_from(&tx, current_version)
446 .unwrap();
447 current_version += 1;
448 }
449 set_schema_version(&tx, current_version).unwrap();
450 self.connection_initializer.finish(&tx).unwrap();
451 tx.commit().unwrap();
452 }
453
454 pub fn run_all_upgrades(&self) {
458 let current_version = get_schema_version(&self.open()).unwrap();
459 for version in current_version..CI::END_VERSION {
460 self.upgrade_to(version + 1);
461 }
462 }
463
464 pub fn assert_schema_matches_new_database(&self) {
465 let db = self.open();
466 let new_db = match open_memory_database(&self.connection_initializer) {
467 Ok(db) => db,
468 Err(e) => panic!("Creating new database failed:\n{e}"),
469 };
470
471 compare_sql_maps("table", get_sql(&db, "table"), get_sql(&new_db, "table"));
472 compare_sql_maps("index", get_sql(&db, "index"), get_sql(&new_db, "index"));
473 compare_sql_maps(
474 "trigger",
475 get_sql(&db, "trigger"),
476 get_sql(&new_db, "trigger"),
477 );
478 }
479
480 pub fn open(&self) -> Connection {
481 Connection::open(&self.path).unwrap()
482 }
483 }
484
485 fn get_sql(conn: &Connection, type_: &str) -> HashMap<String, Option<String>> {
486 conn.query_rows_and_then(
487 "SELECT name, sql FROM sqlite_master WHERE type=?",
488 (type_,),
489 |row| -> rusqlite::Result<(String, Option<String>)> { Ok((row.get(0)?, row.get(1)?)) },
490 )
491 .unwrap()
492 .into_iter()
493 .collect()
494 }
495
496 fn compare_sql_maps(
497 type_: &str,
498 old_items: HashMap<String, Option<String>>,
499 new_items: HashMap<String, Option<String>>,
500 ) {
501 let old_db_keys: HashSet<&String> = old_items.keys().collect();
502 let new_db_keys: HashSet<&String> = new_items.keys().collect();
503
504 let old_db_extra_keys = Vec::from_iter(old_db_keys.difference(&new_db_keys));
505 if !old_db_extra_keys.is_empty() {
506 panic!("Extra keys not present in new database for {type_}: {old_db_extra_keys:?}");
507 }
508 let new_db_extra_keys = Vec::from_iter(new_db_keys.difference(&old_db_keys));
509 if !new_db_extra_keys.is_empty() {
510 panic!("Extra keys only present in new database for {type_}: {new_db_extra_keys:?}");
511 }
512 for key in old_db_keys {
513 assert_eq!(
514 old_items.get(key).unwrap().as_deref().map(normalize),
515 new_items.get(key).unwrap().as_deref().map(normalize),
516 "sql differs for {type_} {key}"
517 );
518 }
519 }
520
/// Normalizes SQL text for schema comparison by collapsing whitespace outside string literals.
///
/// The text is split on `'`:
/// - Even-numbered pieces lie outside string literals. Their whitespace runs are collapsed to
///   single spaces and trimmed.
/// - Odd-numbered pieces are literal contents and are kept verbatim.
///
/// The pieces are re-joined with `'`, so the quote delimiters survive. Without this,
/// `x='a'` (a string) and `x=a` (a column reference) would normalize identically. An escaped
/// quote (`''`) yields an empty outside piece and therefore round-trips unchanged.
///
/// This is deliberately approximate: whitespace next to a quote is dropped, so `a 'b'` and
/// `a'b'` compare equal. Whitespace around other punctuation is not treated specially.
fn normalize(sql: &str) -> String {
    sql.split('\'')
        .enumerate()
        .map(|(i, part)| {
            if i % 2 == 0 {
                // Outside a string literal: normalize the whitespace.
                Cow::Owned(part.split_whitespace().collect::<Vec<_>>().join(" "))
            } else {
                // Inside a string literal: pass it through untouched.
                Cow::Borrowed(part)
            }
        })
        .collect::<Vec<_>>()
        .join("'")
}
538}
539
540#[cfg(test)]
541mod test {
542 use super::test_utils::{MigratedDatabaseFile, TestConnectionInitializer};
543 use super::*;
544 use std::io::Write;
545
// Schema at version 1. The test initializer's upgrade from v1 reports `Error::Corrupt`, so
// opening this database exercises the delete-and-recreate path.
static INIT_V1: &str = "
    CREATE TABLE prep_table(col);
    PRAGMA user_version=1;
    ";

// Schema at version 2. The expected data lives in `prep_table`, and the target table still has
// its old name (`my_old_table_name.old_col`). The upgrades from v2 and v3 rename it to
// `my_table.col`.
static INIT_V2: &str = "
    CREATE TABLE prep_table(col);
    INSERT INTO prep_table(col) VALUES ('correct-value');
    CREATE TABLE my_old_table_name(old_col);
    PRAGMA user_version=2;
    ";
560
561 fn check_final_data(conn: &Connection) {
562 let value: String = conn
563 .query_row("SELECT col FROM my_table", [], |r| r.get(0))
564 .unwrap();
565 assert_eq!(value, "correct-value");
566 assert_eq!(get_schema_version(conn).unwrap(), 4);
567 }
568
#[test]
fn test_init() {
    // A brand-new database is created by `init`, with no upgrades.
    let initializer = TestConnectionInitializer::new();
    let conn = open_memory_database(&initializer).unwrap();
    check_final_data(&conn);
    initializer.check_calls(vec!["prep", "init", "finish"]);
}
576
#[test]
fn test_upgrades() {
    // A version-2 file must go through the v2 and v3 upgrades, in order.
    let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
    let initializer = &db_file.connection_initializer;
    let conn = open_database(&db_file.path, initializer).unwrap();
    check_final_data(&conn);
    initializer.check_calls(vec![
        "prep",
        "upgrade_from_v2",
        "upgrade_from_v3",
        "finish",
    ]);
}
589
#[test]
fn test_open_current_version() {
    // Migrate up front, so that opening the file needs neither `init` nor any upgrade.
    let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
    db_file.upgrade_to(4);
    let initializer = &db_file.connection_initializer;
    initializer.clear_calls();

    let conn = open_database(&db_file.path, initializer).unwrap();
    check_final_data(&conn);
    initializer.check_calls(vec!["prep", "finish"]);
}
601
#[test]
fn test_pragmas() {
    // `prepare` switches the journal to WAL; the opened connection must report it.
    let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
    let conn = open_database(&db_file.path, &db_file.connection_initializer).unwrap();
    let journal_mode = conn
        .conn_ext_query_one::<String>("PRAGMA journal_mode")
        .unwrap();
    assert_eq!(journal_mode, "wal");
}
612
#[test]
fn test_migration_error() {
    let db_file =
        MigratedDatabaseFile::new(TestConnectionInitializer::new_with_buggy_logic(), INIT_V2);
    db_file
        .open()
        .execute(
            "INSERT INTO my_old_table_name(old_col) VALUES ('I should not be deleted')",
            [],
        )
        .unwrap();

    assert!(open_database(&db_file.path, &db_file.connection_initializer).is_err());

    // A failing (non-corruption) migration must roll back, not wipe the user's data.
    let surviving_rows = db_file
        .open()
        .conn_ext_query_one::<i32>("SELECT COUNT(*) FROM my_old_table_name")
        .unwrap();
    assert_eq!(surviving_rows, 1);
}
636
#[test]
fn test_version_too_new() {
    // Pretend newer code already migrated this file past the test END_VERSION (4).
    let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
    set_schema_version(&db_file.open(), 5).unwrap();

    db_file
        .open()
        .execute(
            "INSERT INTO my_old_table_name(old_col) VALUES ('I should not be deleted')",
            [],
        )
        .unwrap();

    assert!(matches!(
        open_database(&db_file.path, &db_file.connection_initializer),
        Err(Error::IncompatibleVersion(5))
    ));

    // Refusing the newer version must leave the existing data untouched.
    let surviving_rows = db_file
        .open()
        .conn_ext_query_one::<i32>("SELECT COUNT(*) FROM my_old_table_name")
        .unwrap();
    assert_eq!(surviving_rows, 1);
}
664
#[test]
fn test_corrupt_db() {
    let tempdir = tempfile::tempdir().unwrap();
    let path = tempdir.path().join("corrupt-db.sql");
    let file_len = |p: &Path| std::fs::metadata(p).unwrap().len();

    // A file that isn't an SQLite database at all.
    let mut file = std::fs::File::create(&path).unwrap();
    file.write_all(b"not sql").unwrap();
    assert_eq!(file_len(&path), 7);
    // Close the handle before the open path tries to delete the file.
    drop(file);

    open_database(&path, &TestConnectionInitializer::new()).unwrap();
    // The garbage was replaced by a real database.
    assert_ne!(file_len(&path), 7);
}
683
#[test]
fn test_force_replace() {
    // The v1 upgrade reports corruption, so the file is deleted and rebuilt from scratch. The
    // second attempt runs prep/init/finish on an empty database.
    let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V1);
    let conn = open_database(&db_file.path, &db_file.connection_initializer).unwrap();
    check_final_data(&conn);
    let expected = vec!["prep", "upgrade_from_v1", "prep", "init", "finish"];
    db_file.connection_initializer.check_calls(expected);
}
697}