1use crate::Error::BanditNotFound;
7use crate::{
8 interest::InterestVectorKind,
9 schema::RelevancyConnectionInitializer,
10 url_hash::{hash_url, UrlHash},
11 Interest, InterestVector, Result,
12};
13use interrupt_support::SqlInterruptScope;
14use rusqlite::{Connection, OpenFlags};
15use sql_support::{ConnExt, LazyDb};
16use std::path::Path;
17
18pub struct RelevancyDb {
20 reader: LazyDb<RelevancyConnectionInitializer>,
21 writer: LazyDb<RelevancyConnectionInitializer>,
22}
23
24#[derive(Debug, PartialEq, uniffi::Record)]
25pub struct BanditData {
26 pub bandit: String,
27 pub arm: String,
28 pub impressions: u64,
29 pub clicks: u64,
30 pub alpha: u64,
31 pub beta: u64,
32}
33
34impl RelevancyDb {
35 pub fn new(path: impl AsRef<Path>) -> Self {
36 let db_open_flags = OpenFlags::SQLITE_OPEN_URI
43 | OpenFlags::SQLITE_OPEN_NO_MUTEX
44 | OpenFlags::SQLITE_OPEN_CREATE
45 | OpenFlags::SQLITE_OPEN_READ_WRITE;
46 Self {
47 reader: LazyDb::new(path.as_ref(), db_open_flags, RelevancyConnectionInitializer),
48 writer: LazyDb::new(path.as_ref(), db_open_flags, RelevancyConnectionInitializer),
49 }
50 }
51
52 pub fn close(&self) {
53 self.reader.close(true);
54 self.writer.close(true);
55 }
56
57 pub fn interrupt(&self) {
58 self.reader.interrupt();
59 self.writer.interrupt();
60 }
61
62 #[cfg(test)]
63 pub fn new_for_test() -> Self {
64 use std::sync::atomic::{AtomicU32, Ordering};
65 static COUNTER: AtomicU32 = AtomicU32::new(0);
66 let count = COUNTER.fetch_add(1, Ordering::Relaxed);
67 Self::new(format!("file:test{count}.sqlite?mode=memory&cache=shared"))
68 }
69
70 pub fn read<T>(&self, op: impl FnOnce(&RelevancyDao) -> Result<T>) -> Result<T> {
72 let (mut conn, scope) = self.reader.lock()?;
73 let tx = conn.transaction()?;
74 let dao = RelevancyDao::new(&tx, scope);
75 op(&dao)
76 }
77
78 pub fn read_write<T>(&self, op: impl FnOnce(&mut RelevancyDao) -> Result<T>) -> Result<T> {
80 let (mut conn, scope) = self.writer.lock()?;
81 let tx = conn.transaction()?;
82 let mut dao = RelevancyDao::new(&tx, scope);
83 let result = op(&mut dao)?;
84 tx.commit()?;
85 Ok(result)
86 }
87}
88
89pub struct RelevancyDao<'a> {
95 pub conn: &'a Connection,
96 pub scope: SqlInterruptScope,
97}
98
99impl<'a> RelevancyDao<'a> {
100 fn new(conn: &'a Connection, scope: SqlInterruptScope) -> Self {
101 Self { conn, scope }
102 }
103
104 pub fn err_if_interrupted(&self) -> Result<()> {
106 Ok(self.scope.err_if_interrupted()?)
107 }
108
109 pub fn add_url_interest(&mut self, url_hash: UrlHash, interest: Interest) -> Result<()> {
111 let sql = "
112 INSERT OR REPLACE INTO url_interest(url_hash, interest_code)
113 VALUES (?, ?)
114 ";
115 self.conn.execute(sql, (url_hash, interest as u32))?;
116 Ok(())
117 }
118
119 pub fn get_url_interest_vector(&self, url: &str) -> Result<InterestVector> {
121 let hash = match hash_url(url) {
122 Some(u) => u,
123 None => return Ok(InterestVector::default()),
124 };
125 let mut stmt = self.conn.prepare_cached(
126 "
127 SELECT interest_code
128 FROM url_interest
129 WHERE url_hash=?
130 ",
131 )?;
132 let interests = stmt.query_and_then((hash,), |row| -> Result<Interest> {
133 row.get::<_, u32>(0)?.try_into()
134 })?;
135
136 let mut interest_vec = InterestVector::default();
137 for interest in interests {
138 interest_vec[interest?] += 1
139 }
140 Ok(interest_vec)
141 }
142
143 pub fn need_to_load_url_interests(&self) -> Result<bool> {
145 Ok(self
147 .conn
148 .query_one("SELECT NOT EXISTS (SELECT 1 FROM url_interest)")?)
149 }
150
151 pub fn update_frecency_user_interest_vector(&self, interests: &InterestVector) -> Result<()> {
156 let mut stmt = self.conn.prepare(
157 "
158 INSERT OR REPLACE INTO user_interest(kind, interest_code, count)
159 VALUES (?, ?, ?)
160 ",
161 )?;
162 for (interest, count) in interests.as_vec() {
163 stmt.execute((InterestVectorKind::Frecency, interest, count))?;
164 }
165
166 Ok(())
167 }
168
169 pub fn get_frecency_user_interest_vector(&self) -> Result<InterestVector> {
170 let mut stmt = self
171 .conn
172 .prepare("SELECT interest_code, count FROM user_interest WHERE kind = ?")?;
173 let mut interest_vec = InterestVector::default();
174 let rows = stmt.query_and_then((InterestVectorKind::Frecency,), |row| {
175 crate::Result::Ok((
176 Interest::try_from(row.get::<_, u32>(0)?)?,
177 row.get::<_, u32>(1)?,
178 ))
179 })?;
180 for row in rows {
181 let (interest_code, count) = row?;
182 interest_vec.set(interest_code, count);
183 }
184 Ok(interest_vec)
185 }
186
187 pub fn initialize_multi_armed_bandit(&mut self, bandit: &str, arm: &str) -> Result<()> {
194 let mut new_statement = self.conn.prepare(
195 "INSERT OR IGNORE INTO multi_armed_bandit (bandit, arm, alpha, beta, impressions, clicks) VALUES (?, ?, ?, ?, ?, ?)"
196 )?;
197 new_statement.execute((bandit, arm, 1, 1, 0, 0))?;
198
199 Ok(())
200 }
201
202 pub fn retrieve_bandit_arm_beta_distribution(
207 &self,
208 bandit: &str,
209 arm: &str,
210 ) -> Result<(u64, u64)> {
211 let mut stmt = self
212 .conn
213 .prepare("SELECT alpha, beta FROM multi_armed_bandit WHERE bandit=? AND arm=?")?;
214
215 let mut result = stmt.query((&bandit, &arm))?;
216
217 match result.next()? {
218 Some(row) => Ok((row.get(0)?, row.get(1)?)),
219 None => Err(BanditNotFound {
220 bandit: bandit.to_string(),
221 arm: arm.to_string(),
222 }),
223 }
224 }
225
226 pub fn retrieve_bandit_data(&self, bandit: &str, arm: &str) -> Result<BanditData> {
234 let mut stmt = self
235 .conn
236 .prepare("SELECT bandit, arm, impressions, clicks, alpha, beta FROM multi_armed_bandit WHERE bandit=? AND arm=?")?;
237
238 let mut result = stmt.query((&bandit, &arm))?;
239
240 match result.next()? {
241 Some(row) => {
242 let bandit = row.get::<_, String>(0)?;
243 let arm = row.get::<_, String>(1)?;
244 let impressions = row.get::<_, u64>(2)?;
245 let clicks = row.get::<_, u64>(3)?;
246 let alpha = row.get::<_, u64>(4)?;
247 let beta = row.get::<_, u64>(5)?;
248
249 Ok(BanditData {
250 bandit,
251 arm,
252 impressions,
253 clicks,
254 alpha,
255 beta,
256 })
257 }
258 None => Err(BanditNotFound {
259 bandit: bandit.to_string(),
260 arm: arm.to_string(),
261 }),
262 }
263 }
264
265 pub fn update_bandit_arm_data(&self, bandit: &str, arm: &str, selected: bool) -> Result<()> {
273 let mut stmt = if selected {
274 self
275 .conn
276 .prepare("UPDATE multi_armed_bandit SET alpha=alpha+1, clicks=clicks+1, impressions=impressions+1 WHERE bandit=? AND arm=?")?
277 } else {
278 self
279 .conn
280 .prepare("UPDATE multi_armed_bandit SET beta=beta+1, impressions=impressions+1 WHERE bandit=? AND arm=?")?
281 };
282
283 let result = stmt.execute((&bandit, &arm))?;
284
285 if result == 0 {
286 return Err(BanditNotFound {
287 bandit: bandit.to_string(),
288 arm: arm.to_string(),
289 });
290 }
291
292 Ok(())
293 }
294}
295
296#[cfg(test)]
297mod test {
298 use super::*;
299 use rusqlite::params;
300
301 #[test]
302 fn test_store_frecency_user_interest_vector() {
303 let db = RelevancyDb::new_for_test();
304 assert_eq!(
306 db.read_write(|dao| dao.get_frecency_user_interest_vector())
307 .unwrap(),
308 InterestVector::default()
309 );
310
311 let interest_vec = InterestVector {
312 animals: 2,
313 autos: 1,
314 news: 5,
315 ..InterestVector::default()
316 };
317 db.read_write(|dao| dao.update_frecency_user_interest_vector(&interest_vec))
318 .unwrap();
319 assert_eq!(
320 db.read_write(|dao| dao.get_frecency_user_interest_vector())
321 .unwrap(),
322 interest_vec,
323 );
324 }
325
326 #[test]
327 fn test_update_frecency_user_interest_vector() {
328 let db = RelevancyDb::new_for_test();
329 let interest_vec1 = InterestVector {
330 animals: 2,
331 autos: 1,
332 news: 5,
333 ..InterestVector::default()
334 };
335 let interest_vec2 = InterestVector {
336 animals: 1,
337 career: 3,
338 ..InterestVector::default()
339 };
340 db.read_write(|dao| dao.update_frecency_user_interest_vector(&interest_vec1))
342 .unwrap();
343 db.read_write(|dao| dao.update_frecency_user_interest_vector(&interest_vec2))
344 .unwrap();
345 assert_eq!(
347 db.read_write(|dao| dao.get_frecency_user_interest_vector())
348 .unwrap(),
349 interest_vec2,
350 );
351 }
352
353 #[test]
354 fn test_initialize_multi_armed_bandit() -> Result<()> {
355 let db = RelevancyDb::new_for_test();
356
357 let bandit = "provider".to_string();
358 let arm = "weather".to_string();
359
360 db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;
361
362 let result = db.read(|dao| {
363 let mut stmt = dao.conn.prepare("SELECT alpha, beta, impressions, clicks FROM multi_armed_bandit WHERE bandit=? AND arm=?")?;
364
365 stmt.query_row(params![&bandit, &arm], |row| {
366 let alpha: usize = row.get(0)?;
367 let beta: usize = row.get(1)?;
368 let impressions: usize = row.get(2)?;
369 let clicks: usize = row.get(3)?;
370
371 Ok((alpha, beta, impressions, clicks))
372 }).map_err(|e| e.into())
373 })?;
374
375 assert_eq!(result.0, 1); assert_eq!(result.1, 1); assert_eq!(result.2, 0); assert_eq!(result.3, 0); Ok(())
381 }
382
383 #[test]
384 fn test_initialize_multi_armed_bandit_existing_data() -> Result<()> {
385 let db = RelevancyDb::new_for_test();
386
387 let bandit = "provider".to_string();
388 let arm = "weather".to_string();
389
390 db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;
391
392 let result = db.read(|dao| {
393 let mut stmt = dao.conn.prepare("SELECT alpha, beta, impressions, clicks FROM multi_armed_bandit WHERE bandit=? AND arm=?")?;
394
395 stmt.query_row(params![&bandit, &arm], |row| {
396 let alpha: usize = row.get(0)?;
397 let beta: usize = row.get(1)?;
398 let impressions: usize = row.get(2)?;
399 let clicks: usize = row.get(3)?;
400
401 Ok((alpha, beta, impressions, clicks))
402 }).map_err(|e| e.into())
403 })?;
404
405 assert_eq!(result.0, 1); assert_eq!(result.1, 1); assert_eq!(result.2, 0); assert_eq!(result.3, 0); db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, true))?;
411
412 let (alpha, beta) =
413 db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm))?;
414
415 assert_eq!(alpha, 2);
416 assert_eq!(beta, 1);
417
418 db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;
420
421 let (alpha, beta) =
422 db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm))?;
423
424 assert_eq!(alpha, 2);
426 assert_eq!(beta, 1);
427
428 Ok(())
429 }
430
431 #[test]
432 fn test_retrieve_bandit_arm_beta_distribution() -> Result<()> {
433 let db = RelevancyDb::new_for_test();
434 let bandit = "provider".to_string();
435 let arm = "weather".to_string();
436
437 db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;
438
439 db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, true))?;
440
441 db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, false))?;
442
443 db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, false))?;
444
445 let (alpha, beta) =
446 db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm))?;
447
448 assert_eq!(alpha, 2);
449 assert_eq!(beta, 3);
450
451 Ok(())
452 }
453
454 #[test]
455 fn test_retrieve_bandit_arm_beta_distribution_not_found() -> Result<()> {
456 let db = RelevancyDb::new_for_test();
457 let bandit = "provider".to_string();
458 let arm = "weather".to_string();
459
460 let result = db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm));
461
462 match result {
463 Ok((alpha, beta)) => panic!(
464 "Expected BanditNotFound error, but got Ok result with alpha: {} and beta: {}",
465 alpha, beta
466 ),
467 Err(BanditNotFound { bandit: b, arm: a }) => {
468 assert_eq!(b, bandit);
469 assert_eq!(a, arm);
470 }
471 _ => {}
472 }
473
474 Ok(())
475 }
476
477 #[test]
478 fn test_update_bandit_arm_data_selected() -> Result<()> {
479 let db = RelevancyDb::new_for_test();
480 let bandit = "provider".to_string();
481 let arm = "weather".to_string();
482
483 db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;
484
485 let result = db.read(|dao| {
486 let mut stmt = dao.conn.prepare("SELECT alpha, beta, impressions, clicks FROM multi_armed_bandit WHERE bandit=? AND arm=?")?;
487
488 stmt.query_row(params![&bandit, &arm], |row| {
489 let alpha: usize = row.get(0)?;
490 let beta: usize = row.get(1)?;
491 let impressions: usize = row.get(2)?;
492 let clicks: usize = row.get(3)?;
493
494 Ok((alpha, beta, impressions, clicks))
495 }).map_err(|e| e.into())
496 })?;
497
498 assert_eq!(result.0, 1);
499 assert_eq!(result.1, 1);
500 assert_eq!(result.2, 0);
501 assert_eq!(result.3, 0);
502
503 db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, true))?;
504
505 let (alpha, beta) =
506 db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm))?;
507
508 assert_eq!(alpha, 2);
509 assert_eq!(beta, 1);
510
511 Ok(())
512 }
513
514 #[test]
515 fn test_update_bandit_arm_data_not_selected() -> Result<()> {
516 let db = RelevancyDb::new_for_test();
517 let bandit = "provider".to_string();
518 let arm = "weather".to_string();
519
520 db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;
521
522 let result = db.read(|dao| {
523 let mut stmt = dao.conn.prepare("SELECT alpha, beta, impressions, clicks FROM multi_armed_bandit WHERE bandit=? AND arm=?")?;
524
525 stmt.query_row(params![&bandit, &arm], |row| {
526 let alpha: usize = row.get(0)?;
527 let beta: usize = row.get(1)?;
528 let impressions: usize = row.get(2)?;
529 let clicks: usize = row.get(3)?;
530
531 Ok((alpha, beta, impressions, clicks))
532 }).map_err(|e| e.into())
533 })?;
534
535 assert_eq!(result.0, 1);
536 assert_eq!(result.1, 1);
537 assert_eq!(result.2, 0);
538 assert_eq!(result.3, 0);
539
540 db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, false))?;
541
542 let (alpha, beta) =
543 db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm))?;
544
545 assert_eq!(alpha, 1);
546 assert_eq!(beta, 2);
547
548 Ok(())
549 }
550
551 #[test]
552 fn test_update_bandit_arm_data_not_found() -> Result<()> {
553 let db = RelevancyDb::new_for_test();
554 let bandit = "provider".to_string();
555 let arm = "weather".to_string();
556
557 let result = db.read(|dao| dao.update_bandit_arm_data(&bandit, &arm, false));
558
559 match result {
560 Ok(()) => panic!("Expected BanditNotFound error, but got Ok result"),
561 Err(BanditNotFound { bandit: b, arm: a }) => {
562 assert_eq!(b, bandit);
563 assert_eq!(a, arm);
564 }
565 _ => {}
566 }
567
568 Ok(())
569 }
570
571 #[test]
572 fn test_retrieve_bandit_data() -> Result<()> {
573 let db = RelevancyDb::new_for_test();
574 let bandit = "provider".to_string();
575 let arm = "weather".to_string();
576
577 db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;
578
579 db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, true))?;
581 db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, false))?;
582 db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, false))?;
583
584 let bandit_data = db.read(|dao| dao.retrieve_bandit_data(&bandit, &arm))?;
585
586 let expected_bandit_data = BanditData {
587 bandit: bandit.clone(),
588 arm: arm.clone(),
589 impressions: 3, clicks: 1, alpha: 2,
592 beta: 3,
593 };
594
595 assert_eq!(bandit_data, expected_bandit_data);
596
597 Ok(())
598 }
599
600 #[test]
601 fn test_retrieve_bandit_data_not_found() -> Result<()> {
602 let db = RelevancyDb::new_for_test();
603 let bandit = "provider".to_string();
604 let arm = "weather".to_string();
605
606 let result = db.read(|dao| dao.retrieve_bandit_data(&bandit, &arm));
607
608 match result {
609 Ok(bandit_data) => panic!(
610 "Expected BanditNotFound error, but got Ok result with alpha: {}, beta: {}, impressions: {}, clicks: {}, bandit: {}, arm: {}",
611 bandit_data.alpha, bandit_data.beta, bandit_data.impressions, bandit_data.clicks, bandit_data.arm, bandit_data.arm
612 ),
613 Err(BanditNotFound { bandit: b, arm: a }) => {
614 assert_eq!(b, bandit);
615 assert_eq!(a, arm);
616 }
617 _ => {}
618 }
619
620 Ok(())
621 }
622}