interrupt_support/
sql.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4
5use crate::{in_shutdown, Interrupted, Interruptee};
6use rusqlite::{Connection, InterruptHandle};
7use std::fmt;
8use std::sync::{
9    atomic::{AtomicUsize, Ordering},
10    Arc,
11};
12
13/// Interrupt operations that use SQL
14///
15/// Typical usage of this type:
16///   - Components typically create a wrapper class around an `rusqlite::Connection`
17///     (`PlacesConnection`, `LoginStore`, etc.)
18///   - The wrapper stores an `Arc<SqlInterruptHandle>`
19///   - The wrapper has a method that clones and returns that `Arc`.  This allows passing the interrupt
20///     handle to a different thread in order to interrupt a particular operation.
21///   - The wrapper calls `begin_interrupt_scope()` at the start of each operation.  The code that
22///     performs the operation periodically calls `err_if_interrupted()`.
23///   - Finally, the wrapper class implements `AsRef<SqlInterruptHandle>` and calls
24///     `register_interrupt()`.  This causes all operations to be interrupted when we enter
25///     shutdown mode.
26pub struct SqlInterruptHandle {
27    db_handle: InterruptHandle,
28    // Counter that we increment on each interrupt() call.
29    // We use Ordering::Relaxed to read/write to this variable.  This is safe because we're
30    // basically using it as a flag and don't need stronger synchronization guarantees.
31    interrupt_counter: Arc<AtomicUsize>,
32}
33
34impl SqlInterruptHandle {
35    #[inline]
36    pub fn new(conn: &Connection) -> Self {
37        Self {
38            db_handle: conn.get_interrupt_handle(),
39            interrupt_counter: Arc::new(AtomicUsize::new(0)),
40        }
41    }
42
43    /// Begin an interrupt scope that will be interrupted by this handle
44    ///
45    /// Returns Err(Interrupted) if we're in shutdown mode
46    #[inline]
47    pub fn begin_interrupt_scope(&self) -> Result<SqlInterruptScope, Interrupted> {
48        if in_shutdown() {
49            Err(Interrupted)
50        } else {
51            Ok(SqlInterruptScope::new(Arc::clone(&self.interrupt_counter)))
52        }
53    }
54
55    /// Interrupt all interrupt scopes created by this handle
56    #[inline]
57    pub fn interrupt(&self) {
58        self.interrupt_counter.fetch_add(1, Ordering::Relaxed);
59        self.db_handle.interrupt();
60    }
61}
62
63impl fmt::Debug for SqlInterruptHandle {
64    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        f.debug_struct("SqlInterruptHandle")
66            .field(
67                "interrupt_counter",
68                &self.interrupt_counter.load(Ordering::Relaxed),
69            )
70            .finish()
71    }
72}
73
/// Tracks whether an operation has been interrupted since the scope was created.
///
/// Rust code uses this to decide whether an operation should fail because it was interrupted.
/// It covers interrupts that arrive while no SQL query is executing.
#[derive(Debug)]
pub struct SqlInterruptScope {
    /// Counter value observed when the scope was created.
    start_value: usize,
    /// Counter shared with the handle that created this scope (a private one for `dummy()`).
    interrupt_counter: Arc<AtomicUsize>,
}
83
84impl SqlInterruptScope {
85    fn new(interrupt_counter: Arc<AtomicUsize>) -> Self {
86        let start_value = interrupt_counter.load(Ordering::Relaxed);
87        Self {
88            start_value,
89            interrupt_counter,
90        }
91    }
92
93    // Create an `SqlInterruptScope` that's never interrupted.
94    //
95    // This should only be used for testing purposes.
96    pub fn dummy() -> Self {
97        Self::new(Arc::new(AtomicUsize::new(0)))
98    }
99
100    /// Check if scope has been interrupted
101    #[inline]
102    pub fn was_interrupted(&self) -> bool {
103        self.interrupt_counter.load(Ordering::Relaxed) != self.start_value
104    }
105
106    /// Return Err(Interrupted) if we were interrupted
107    #[inline]
108    pub fn err_if_interrupted(&self) -> Result<(), Interrupted> {
109        if self.was_interrupted() {
110            Err(Interrupted)
111        } else {
112            Ok(())
113        }
114    }
115}
116
117impl Interruptee for SqlInterruptScope {
118    #[inline]
119    fn was_interrupted(&self) -> bool {
120        self.was_interrupted()
121    }
122}
123
124// Needed to allow Weak<SqlInterruptHandle> to be passed to `interrupt::register_interrupt`
125impl AsRef<SqlInterruptHandle> for SqlInterruptHandle {
126    fn as_ref(&self) -> &SqlInterruptHandle {
127        self
128    }
129}