relevancy/
db.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
4 */
5
6use crate::Error::BanditNotFound;
7use crate::{
8    interest::InterestVectorKind,
9    schema::RelevancyConnectionInitializer,
10    url_hash::{hash_url, UrlHash},
11    Interest, InterestVector, Result,
12};
13use interrupt_support::SqlInterruptScope;
14use rusqlite::{Connection, OpenFlags};
15use sql_support::{ConnExt, LazyDb};
16use std::path::Path;
17
/// A thread-safe wrapper around SQLite connections to the Relevancy database
///
/// Holds two `LazyDb` handles opened on the same database path: `reader` backs
/// [RelevancyDb::read] and `writer` backs [RelevancyDb::read_write].
pub struct RelevancyDb {
    /// Connection handle used by [RelevancyDb::read].
    reader: LazyDb<RelevancyConnectionInitializer>,
    /// Connection handle used by [RelevancyDb::read_write].
    writer: LazyDb<RelevancyConnectionInitializer>,
}
23
/// One row of the `multi_armed_bandit` table: the usage counters and Beta distribution
/// parameters for a single arm of a single bandit.
///
/// Exported as a UniFFI record, so the field order is part of the generated bindings.
#[derive(Debug, PartialEq, uniffi::Record)]
pub struct BanditData {
    /// Name of the bandit (the `bandit` column).
    pub bandit: String,
    /// Name of the arm within the bandit (the `arm` column).
    pub arm: String,
    /// Times this arm was shown; starts at 0 and is incremented on every update,
    /// whether or not the arm was selected.
    pub impressions: u64,
    /// Times this arm was selected; starts at 0.
    pub clicks: u64,
    /// Beta distribution `alpha` (success) parameter; starts at 1 and is incremented
    /// when the arm is selected.
    pub alpha: u64,
    /// Beta distribution `beta` (failure) parameter; starts at 1 and is incremented
    /// when the arm is not selected.
    pub beta: u64,
}
33
34impl RelevancyDb {
35    pub fn new(path: impl AsRef<Path>) -> Self {
36        // Note: use `SQLITE_OPEN_READ_WRITE` for both read and write connections.
37        // Even if we're opening a read connection, we may need to do a write as part of the
38        // initialization process.
39        //
40        // The read-only nature of the connection is enforced by the fact that [RelevancyDb::read] uses a
41        // shared ref to the `RelevancyDao`.
42        let db_open_flags = OpenFlags::SQLITE_OPEN_URI
43            | OpenFlags::SQLITE_OPEN_NO_MUTEX
44            | OpenFlags::SQLITE_OPEN_CREATE
45            | OpenFlags::SQLITE_OPEN_READ_WRITE;
46        Self {
47            reader: LazyDb::new(path.as_ref(), db_open_flags, RelevancyConnectionInitializer),
48            writer: LazyDb::new(path.as_ref(), db_open_flags, RelevancyConnectionInitializer),
49        }
50    }
51
52    pub fn close(&self) {
53        self.reader.close(true);
54        self.writer.close(true);
55    }
56
57    pub fn interrupt(&self) {
58        self.reader.interrupt();
59        self.writer.interrupt();
60    }
61
62    #[cfg(test)]
63    pub fn new_for_test() -> Self {
64        use std::sync::atomic::{AtomicU32, Ordering};
65        static COUNTER: AtomicU32 = AtomicU32::new(0);
66        let count = COUNTER.fetch_add(1, Ordering::Relaxed);
67        Self::new(format!("file:test{count}.sqlite?mode=memory&cache=shared"))
68    }
69
70    /// Accesses the Suggest database in a transaction for reading.
71    pub fn read<T>(&self, op: impl FnOnce(&RelevancyDao) -> Result<T>) -> Result<T> {
72        let (mut conn, scope) = self.reader.lock()?;
73        let tx = conn.transaction()?;
74        let dao = RelevancyDao::new(&tx, scope);
75        op(&dao)
76    }
77
78    /// Accesses the Suggest database in a transaction for reading and writing.
79    pub fn read_write<T>(&self, op: impl FnOnce(&mut RelevancyDao) -> Result<T>) -> Result<T> {
80        let (mut conn, scope) = self.writer.lock()?;
81        let tx = conn.transaction()?;
82        let mut dao = RelevancyDao::new(&tx, scope);
83        let result = op(&mut dao)?;
84        tx.commit()?;
85        Ok(result)
86    }
87}
88
/// A data access object (DAO) that wraps a connection to the Relevancy database
///
/// Instances are created by [RelevancyDb::read] and [RelevancyDb::read_write], which pass in
/// an open transaction as the connection.
///
/// Methods that only read from the database take an immutable reference to
/// `self` (`&self`), and most methods that write to the database take a mutable
/// reference (`&mut self`). NOTE: a few writers (`update_frecency_user_interest_vector`,
/// `update_bandit_arm_data`) currently take `&self`. As a result, the shared reference handed
/// out by [RelevancyDb::read] does not by itself guarantee read-only access.
pub struct RelevancyDao<'a> {
    /// The connection (in practice a transaction) that all queries run against.
    pub conn: &'a Connection,
    /// Interrupt scope checked by [RelevancyDao::err_if_interrupted].
    pub scope: SqlInterruptScope,
}
98
99impl<'a> RelevancyDao<'a> {
100    fn new(conn: &'a Connection, scope: SqlInterruptScope) -> Self {
101        Self { conn, scope }
102    }
103
104    /// Return Err(Interrupted) if we were interrupted
105    pub fn err_if_interrupted(&self) -> Result<()> {
106        Ok(self.scope.err_if_interrupted()?)
107    }
108
109    /// Associate a URL with an interest
110    pub fn add_url_interest(&mut self, url_hash: UrlHash, interest: Interest) -> Result<()> {
111        let sql = "
112            INSERT OR REPLACE INTO url_interest(url_hash, interest_code)
113            VALUES (?, ?)
114        ";
115        self.conn.execute(sql, (url_hash, interest as u32))?;
116        Ok(())
117    }
118
119    /// Get an interest vector for a URL
120    pub fn get_url_interest_vector(&self, url: &str) -> Result<InterestVector> {
121        let hash = match hash_url(url) {
122            Some(u) => u,
123            None => return Ok(InterestVector::default()),
124        };
125        let mut stmt = self.conn.prepare_cached(
126            "
127            SELECT interest_code
128            FROM url_interest
129            WHERE url_hash=?
130        ",
131        )?;
132        let interests = stmt.query_and_then((hash,), |row| -> Result<Interest> {
133            row.get::<_, u32>(0)?.try_into()
134        })?;
135
136        let mut interest_vec = InterestVector::default();
137        for interest in interests {
138            interest_vec[interest?] += 1
139        }
140        Ok(interest_vec)
141    }
142
143    /// Do we need to load the interest data?
144    pub fn need_to_load_url_interests(&self) -> Result<bool> {
145        // TODO: we probably will need a better check than this.
146        Ok(self
147            .conn
148            .query_one("SELECT NOT EXISTS (SELECT 1 FROM url_interest)")?)
149    }
150
151    /// Update the frecency user interest vector based on a new measurement.
152    ///
153    /// Right now this completely replaces the interest vector with the new data.  At some point,
154    /// we may switch to incrementally updating it instead.
155    pub fn update_frecency_user_interest_vector(&self, interests: &InterestVector) -> Result<()> {
156        let mut stmt = self.conn.prepare(
157            "
158            INSERT OR REPLACE INTO user_interest(kind, interest_code, count)
159            VALUES (?, ?, ?)
160            ",
161        )?;
162        for (interest, count) in interests.as_vec() {
163            stmt.execute((InterestVectorKind::Frecency, interest, count))?;
164        }
165
166        Ok(())
167    }
168
169    pub fn get_frecency_user_interest_vector(&self) -> Result<InterestVector> {
170        let mut stmt = self
171            .conn
172            .prepare("SELECT interest_code, count FROM user_interest WHERE kind = ?")?;
173        let mut interest_vec = InterestVector::default();
174        let rows = stmt.query_and_then((InterestVectorKind::Frecency,), |row| {
175            crate::Result::Ok((
176                Interest::try_from(row.get::<_, u32>(0)?)?,
177                row.get::<_, u32>(1)?,
178            ))
179        })?;
180        for row in rows {
181            let (interest_code, count) = row?;
182            interest_vec.set(interest_code, count);
183        }
184        Ok(interest_vec)
185    }
186
187    /// Initializes a multi-armed bandit record in the database for a specific bandit and arm.
188    ///
189    /// This method inserts a new record into the `multi_armed_bandit` table with default probability
190    /// distribution parameters (`alpha` and `beta` set to 1) and usage counters (`impressions` and
191    /// `clicks` set to 0) for the specified `bandit` and `arm`. If a record for this bandit-arm pair
192    /// already exists, the insertion is ignored, preserving the existing data.
193    pub fn initialize_multi_armed_bandit(&mut self, bandit: &str, arm: &str) -> Result<()> {
194        let mut new_statement = self.conn.prepare(
195            "INSERT OR IGNORE INTO multi_armed_bandit (bandit, arm, alpha, beta, impressions, clicks) VALUES (?, ?, ?, ?, ?, ?)"
196        )?;
197        new_statement.execute((bandit, arm, 1, 1, 0, 0))?;
198
199        Ok(())
200    }
201
202    /// Retrieves the Beta distribution parameters (`alpha` and `beta`) for a specific arm in a bandit model.
203    ///
204    /// If the specified `bandit` and `arm` do not exist in the table, an error is returned indicating
205    /// that the record was not found.
206    pub fn retrieve_bandit_arm_beta_distribution(
207        &self,
208        bandit: &str,
209        arm: &str,
210    ) -> Result<(u64, u64)> {
211        let mut stmt = self
212            .conn
213            .prepare("SELECT alpha, beta FROM multi_armed_bandit WHERE bandit=? AND arm=?")?;
214
215        let mut result = stmt.query((&bandit, &arm))?;
216
217        match result.next()? {
218            Some(row) => Ok((row.get(0)?, row.get(1)?)),
219            None => Err(BanditNotFound {
220                bandit: bandit.to_string(),
221                arm: arm.to_string(),
222            }),
223        }
224    }
225
226    /// Retrieves the data for a specific bandit and arm combination from the database.
227    ///
228    /// This method queries the `multi_armed_bandit` table to find a row matching the given
229    /// `bandit` and `arm` values. If a matching row is found, it extracts the corresponding
230    /// fields (`bandit`, `arm`, `impressions`, `clicks`, `alpha`, `beta`) and returns them
231    /// as a `BanditData` struct. If no matching row is found, it returns a `BanditNotFound`
232    /// error.
233    pub fn retrieve_bandit_data(&self, bandit: &str, arm: &str) -> Result<BanditData> {
234        let mut stmt = self
235            .conn
236            .prepare("SELECT bandit, arm, impressions, clicks, alpha, beta FROM multi_armed_bandit WHERE bandit=? AND arm=?")?;
237
238        let mut result = stmt.query((&bandit, &arm))?;
239
240        match result.next()? {
241            Some(row) => {
242                let bandit = row.get::<_, String>(0)?;
243                let arm = row.get::<_, String>(1)?;
244                let impressions = row.get::<_, u64>(2)?;
245                let clicks = row.get::<_, u64>(3)?;
246                let alpha = row.get::<_, u64>(4)?;
247                let beta = row.get::<_, u64>(5)?;
248
249                Ok(BanditData {
250                    bandit,
251                    arm,
252                    impressions,
253                    clicks,
254                    alpha,
255                    beta,
256                })
257            }
258            None => Err(BanditNotFound {
259                bandit: bandit.to_string(),
260                arm: arm.to_string(),
261            }),
262        }
263    }
264
265    /// Updates the Beta distribution parameters and counters for a specific arm in a bandit model based on user interaction.
266    ///
267    /// This method updates the `alpha` or `beta` parameters in the `multi_armed_bandit` table for the specified
268    /// `bandit` and `arm` based on whether the arm was selected by the user. If `selected` is true, it increments
269    /// both the `alpha` (indicating success) and the `clicks` and `impressions` counters. If `selected` is false,
270    /// it increments `beta` (indicating failure) and only the `impressions` counter. This approach adjusts the
271    /// distribution parameters to reflect the arm's performance over time.
272    pub fn update_bandit_arm_data(&self, bandit: &str, arm: &str, selected: bool) -> Result<()> {
273        let mut stmt = if selected {
274            self
275                .conn
276                .prepare("UPDATE multi_armed_bandit SET alpha=alpha+1, clicks=clicks+1, impressions=impressions+1 WHERE bandit=? AND arm=?")?
277        } else {
278            self
279                .conn
280                .prepare("UPDATE multi_armed_bandit SET beta=beta+1, impressions=impressions+1 WHERE bandit=? AND arm=?")?
281        };
282
283        let result = stmt.execute((&bandit, &arm))?;
284
285        if result == 0 {
286            return Err(BanditNotFound {
287                bandit: bandit.to_string(),
288                arm: arm.to_string(),
289            });
290        }
291
292        Ok(())
293    }
294}
295
#[cfg(test)]
mod test {
    use super::*;
    use rusqlite::params;

    /// Reads `(alpha, beta, impressions, clicks)` for `bandit`/`arm` with raw SQL, so the DAO
    /// accessors under test are not used to verify themselves.
    fn read_raw_bandit_row(
        db: &RelevancyDb,
        bandit: &str,
        arm: &str,
    ) -> Result<(usize, usize, usize, usize)> {
        db.read(|dao| {
            let mut stmt = dao.conn.prepare(
                "SELECT alpha, beta, impressions, clicks FROM multi_armed_bandit WHERE bandit=? AND arm=?",
            )?;
            let values: (usize, usize, usize, usize) =
                stmt.query_row(params![bandit, arm], |r| {
                    Ok((r.get(0)?, r.get(1)?, r.get(2)?, r.get(3)?))
                })?;
            Ok(values)
        })
    }

    /// Asserts that `result` is a `BanditNotFound` error naming exactly `bandit` and `arm`.
    ///
    /// Any other outcome, whether a success or a different error variant, fails the test.
    fn assert_bandit_not_found<T: std::fmt::Debug>(result: Result<T>, bandit: &str, arm: &str) {
        match result {
            Ok(value) => panic!("Expected BanditNotFound error, but got Ok result: {value:?}"),
            Err(BanditNotFound { bandit: b, arm: a }) => {
                assert_eq!(b, bandit);
                assert_eq!(a, arm);
            }
            Err(other) => panic!("Expected BanditNotFound error, but got: {other:?}"),
        }
    }

    #[test]
    fn test_store_frecency_user_interest_vector() {
        let db = RelevancyDb::new_for_test();
        // Initially the interest vector should be blank
        assert_eq!(
            db.read_write(|dao| dao.get_frecency_user_interest_vector())
                .unwrap(),
            InterestVector::default()
        );

        let interest_vec = InterestVector {
            animals: 2,
            autos: 1,
            news: 5,
            ..InterestVector::default()
        };
        db.read_write(|dao| dao.update_frecency_user_interest_vector(&interest_vec))
            .unwrap();
        assert_eq!(
            db.read_write(|dao| dao.get_frecency_user_interest_vector())
                .unwrap(),
            interest_vec,
        );
    }

    #[test]
    fn test_update_frecency_user_interest_vector() {
        let db = RelevancyDb::new_for_test();
        let interest_vec1 = InterestVector {
            animals: 2,
            autos: 1,
            news: 5,
            ..InterestVector::default()
        };
        let interest_vec2 = InterestVector {
            animals: 1,
            career: 3,
            ..InterestVector::default()
        };
        // Update the first interest vec, then the second one
        db.read_write(|dao| dao.update_frecency_user_interest_vector(&interest_vec1))
            .unwrap();
        db.read_write(|dao| dao.update_frecency_user_interest_vector(&interest_vec2))
            .unwrap();
        // The current behavior is the second one should replace the first
        assert_eq!(
            db.read_write(|dao| dao.get_frecency_user_interest_vector())
                .unwrap(),
            interest_vec2,
        );
    }

    #[test]
    fn test_initialize_multi_armed_bandit() -> Result<()> {
        let db = RelevancyDb::new_for_test();

        let bandit = "provider".to_string();
        let arm = "weather".to_string();

        db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;

        // Defaults: alpha = 1, beta = 1, impressions = 0, clicks = 0
        assert_eq!(read_raw_bandit_row(&db, &bandit, &arm)?, (1, 1, 0, 0));

        Ok(())
    }

    #[test]
    fn test_initialize_multi_armed_bandit_existing_data() -> Result<()> {
        let db = RelevancyDb::new_for_test();

        let bandit = "provider".to_string();
        let arm = "weather".to_string();

        db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;

        // Defaults: alpha = 1, beta = 1, impressions = 0, clicks = 0
        assert_eq!(read_raw_bandit_row(&db, &bandit, &arm)?, (1, 1, 0, 0));

        db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, true))?;

        let (alpha, beta) =
            db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm))?;

        assert_eq!(alpha, 2);
        assert_eq!(beta, 1);

        // this should be a no-op since the same bandit-arm has already been initialized
        db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;

        let (alpha, beta) =
            db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm))?;

        // alpha & beta values for the bandit-arm should remain unchanged
        assert_eq!(alpha, 2);
        assert_eq!(beta, 1);

        Ok(())
    }

    #[test]
    fn test_retrieve_bandit_arm_beta_distribution() -> Result<()> {
        let db = RelevancyDb::new_for_test();
        let bandit = "provider".to_string();
        let arm = "weather".to_string();

        db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;

        db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, true))?;

        db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, false))?;

        db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, false))?;

        let (alpha, beta) =
            db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm))?;

        assert_eq!(alpha, 2);
        assert_eq!(beta, 3);

        Ok(())
    }

    #[test]
    fn test_retrieve_bandit_arm_beta_distribution_not_found() -> Result<()> {
        let db = RelevancyDb::new_for_test();
        let bandit = "provider".to_string();
        let arm = "weather".to_string();

        let result = db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm));
        assert_bandit_not_found(result, &bandit, &arm);

        Ok(())
    }

    #[test]
    fn test_update_bandit_arm_data_selected() -> Result<()> {
        let db = RelevancyDb::new_for_test();
        let bandit = "provider".to_string();
        let arm = "weather".to_string();

        db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;

        assert_eq!(read_raw_bandit_row(&db, &bandit, &arm)?, (1, 1, 0, 0));

        db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, true))?;

        let (alpha, beta) =
            db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm))?;

        assert_eq!(alpha, 2);
        assert_eq!(beta, 1);

        Ok(())
    }

    #[test]
    fn test_update_bandit_arm_data_not_selected() -> Result<()> {
        let db = RelevancyDb::new_for_test();
        let bandit = "provider".to_string();
        let arm = "weather".to_string();

        db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;

        assert_eq!(read_raw_bandit_row(&db, &bandit, &arm)?, (1, 1, 0, 0));

        db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, false))?;

        let (alpha, beta) =
            db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm))?;

        assert_eq!(alpha, 1);
        assert_eq!(beta, 2);

        Ok(())
    }

    #[test]
    fn test_update_bandit_arm_data_not_found() -> Result<()> {
        let db = RelevancyDb::new_for_test();
        let bandit = "provider".to_string();
        let arm = "weather".to_string();

        // This is a write, so run it in a read-write transaction, even though the DAO method
        // only needs `&self`.
        let result = db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, false));
        assert_bandit_not_found(result, &bandit, &arm);

        Ok(())
    }

    #[test]
    fn test_retrieve_bandit_data() -> Result<()> {
        let db = RelevancyDb::new_for_test();
        let bandit = "provider".to_string();
        let arm = "weather".to_string();

        db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;

        // Update the bandit arm data (simulate interactions)
        db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, true))?;
        db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, false))?;
        db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, false))?;

        let bandit_data = db.read(|dao| dao.retrieve_bandit_data(&bandit, &arm))?;

        let expected_bandit_data = BanditData {
            bandit: bandit.clone(),
            arm: arm.clone(),
            impressions: 3, // 3 updates (true + false + false)
            clicks: 1,      // 1 `true` interaction
            alpha: 2,
            beta: 3,
        };

        assert_eq!(bandit_data, expected_bandit_data);

        Ok(())
    }

    #[test]
    fn test_retrieve_bandit_data_not_found() -> Result<()> {
        let db = RelevancyDb::new_for_test();
        let bandit = "provider".to_string();
        let arm = "weather".to_string();

        let result = db.read(|dao| dao.retrieve_bandit_data(&bandit, &arm));
        assert_bandit_not_found(result, &bandit, &arm);

        Ok(())
    }
}