// relevancy/lib.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4
5//! Proposed API for the relevancy component (validation phase)
6//!
7//! The goal here is to allow us to validate that we can reliably detect user interests from
8//! history data, without spending too much time building the API out.  There's some hand-waving
9//! towards how we would use this data to rank search results, but we don't need to come to a final
10//! decision on that yet.
11
12mod db;
13mod error;
14mod ingest;
15mod interest;
16mod ranker;
17mod rs;
18mod schema;
19pub mod url_hash;
20
21use rand_distr::{Beta, Distribution};
22
23use std::{collections::HashMap, sync::Arc};
24
25use parking_lot::Mutex;
26use remote_settings::{RemoteSettingsClient, RemoteSettingsService};
27
28pub use db::RelevancyDb;
29pub use error::{ApiResult, Error, RelevancyApiError, Result};
30// reexport logging helpers.
31pub use error_support::{debug, error, info, trace, warn};
32
33pub use interest::{Interest, InterestVector};
34pub use ranker::score;
35
36use error_support::handle_error;
37
38use db::BanditData;
39
40uniffi::setup_scaffolding!();
41
/// Handle exposed over UniFFI for the Relevancy component.
///
/// Thin wrapper that pins the generic [RelevancyStoreInner] to the concrete
/// `Arc<RemoteSettingsClient>` remote-settings client used in production.
#[derive(uniffi::Object)]
pub struct RelevancyStore {
    // All real work is delegated to this inner store; see the
    // `#[uniffi::export]` impl below for the public API surface.
    inner: RelevancyStoreInner<Arc<RemoteSettingsClient>>,
}
46
47/// Top-level API for the Relevancy component
48// Impl block to be exported via `UniFFI`.
49#[uniffi::export]
50impl RelevancyStore {
51    /// Construct a new RelevancyStore
52    ///
53    /// This is non-blocking since databases and other resources are lazily opened.
54    #[uniffi::constructor]
55    pub fn new(db_path: String, remote_settings: Arc<RemoteSettingsService>) -> Self {
56        Self {
57            inner: RelevancyStoreInner::new(
58                db_path,
59                remote_settings.make_client(rs::REMOTE_SETTINGS_COLLECTION.to_string()),
60            ),
61        }
62    }
63
64    /// Close any open resources (for example databases)
65    ///
66    /// Calling `close` will interrupt any in-progress queries on other threads.
67    pub fn close(&self) {
68        self.inner.close()
69    }
70
71    /// Interrupt any current database queries
72    pub fn interrupt(&self) {
73        self.inner.interrupt()
74    }
75
76    /// Ingest top URLs to build the user's interest vector.
77    ///
78    /// Consumer should pass a list of the user's top URLs by frecency to this method.  It will
79    /// then:
80    ///
81    ///  - Download the URL interest data from remote settings.  Eventually this should be cached /
82    ///    stored in the database, but for now it would be fine to download fresh data each time.
83    ///  - Match the user's top URls against the interest data to build up their interest vector.
84    ///  - Store the user's interest vector in the database.
85    ///
86    ///  This method may execute for a long time and should only be called from a worker thread.
87    #[handle_error(Error)]
88    pub fn ingest(&self, top_urls_by_frecency: Vec<String>) -> ApiResult<InterestVector> {
89        self.inner.ingest(top_urls_by_frecency)
90    }
91
92    /// Get the user's interest vector directly.
93    ///
94    /// This runs after [Self::ingest].  It returns the interest vector directly so that the
95    /// consumer can show it in an `about:` page.
96    #[handle_error(Error)]
97    pub fn user_interest_vector(&self) -> ApiResult<InterestVector> {
98        self.inner.user_interest_vector()
99    }
100
101    /// Initializes probability distributions for any uninitialized items (arms) within a bandit model.
102    ///
103    /// This method takes a `bandit` identifier and a list of `arms` (items) and ensures that each arm
104    /// in the list has an initialized probability distribution in the database. For each arm, if the
105    /// probability distribution does not already exist, it will be created, using Beta(1,1) as default,
106    /// which represents uniform distribution.
107    #[handle_error(Error)]
108    pub fn bandit_init(&self, bandit: String, arms: &[String]) -> ApiResult<()> {
109        self.inner.bandit_init(bandit, arms)
110    }
111
112    /// Selects the optimal item (arm) to display to the user based on a multi-armed bandit model.
113    ///
114    /// This method takes in a `bandit` identifier and a list of possible `arms` (items) and uses a
115    /// Thompson sampling approach to select the arm with the highest probability of success.
116    /// For each arm, it retrieves the Beta distribution parameters (alpha and beta) from the
117    /// database, creates a Beta distribution, and samples from it to estimate the arm's probability
118    /// of success. The arm with the highest sampled probability is selected and returned.
119    #[handle_error(Error)]
120    pub fn bandit_select(&self, bandit: String, arms: &[String]) -> ApiResult<String> {
121        self.inner.bandit_select(bandit, arms)
122    }
123
124    /// Updates the bandit model's arm data based on user interaction (selection or non-selection).
125    ///
126    /// This method takes in a `bandit` identifier, an `arm` identifier, and a `selected` flag.
127    /// If `selected` is true, it updates the model to reflect a successful selection of the arm,
128    /// reinforcing its positive reward probability. If `selected` is false, it updates the
129    /// beta (failure) distribution of the arm, reflecting a lack of selection and reinforcing
130    /// its likelihood of a negative outcome.
131    #[handle_error(Error)]
132    pub fn bandit_update(&self, bandit: String, arm: String, selected: bool) -> ApiResult<()> {
133        self.inner.bandit_update(bandit, arm, selected)
134    }
135
136    /// Retrieves the data for a specific bandit and arm.
137    #[handle_error(Error)]
138    pub fn get_bandit_data(&self, bandit: String, arm: String) -> ApiResult<BanditData> {
139        self.inner.get_bandit_data(bandit, arm)
140    }
141
142    /// Download the interest data from remote settings if needed
143    #[handle_error(Error)]
144    pub fn ensure_interest_data_populated(&self) -> ApiResult<()> {
145        self.inner.ensure_interest_data_populated()
146    }
147}
148
/// Internal implementation of the Relevancy component, generic over the
/// remote settings client `C`.
///
/// The test module substitutes a stub client here, while [RelevancyStore]
/// wraps this with the concrete production client.
pub(crate) struct RelevancyStoreInner<C> {
    // Relevancy database handle (opened lazily; see `RelevancyStore::new`).
    db: RelevancyDb,
    // Cache of Beta distribution parameters for bandit arms, guarded by a
    // mutex since the store is shared across threads.
    cache: Mutex<BanditCache>,
    // Remote settings client used to download interest data.
    client: C,
}
154
155/// Top-level API for the Relevancy component
156// Impl block to be exported via `UniFFI`.
157impl<C: rs::RelevancyRemoteSettingsClient> RelevancyStoreInner<C> {
158    pub fn new(db_path: String, client: C) -> Self {
159        Self {
160            db: RelevancyDb::new(db_path),
161            cache: Mutex::new(BanditCache::new()),
162            client,
163        }
164    }
165
166    /// Close any open resources (for example databases)
167    ///
168    /// Calling `close` will interrupt any in-progress queries on other threads.
169    pub fn close(&self) {
170        self.db.close();
171        self.client.close();
172    }
173
174    /// Interrupt any current database queries
175    pub fn interrupt(&self) {
176        self.db.interrupt()
177    }
178
179    /// Ingest top URLs to build the user's interest vector.
180    ///
181    /// Consumer should pass a list of the user's top URLs by frecency to this method.  It will
182    /// then:
183    ///
184    ///  - Download the URL interest data from remote settings.  Eventually this should be cached /
185    ///    stored in the database, but for now it would be fine to download fresh data each time.
186    ///  - Match the user's top URls against the interest data to build up their interest vector.
187    ///  - Store the user's interest vector in the database.
188    ///
189    ///  This method may execute for a long time and should only be called from a worker thread.
190    pub fn ingest(&self, top_urls_by_frecency: Vec<String>) -> Result<InterestVector> {
191        let interest_vec = self.classify(top_urls_by_frecency)?;
192        self.db
193            .read_write(|dao| dao.update_frecency_user_interest_vector(&interest_vec))?;
194        Ok(interest_vec)
195    }
196
197    pub fn classify(&self, top_urls_by_frecency: Vec<String>) -> Result<InterestVector> {
198        let mut interest_vector = InterestVector::default();
199        for url in top_urls_by_frecency {
200            let interest_count = self.db.read(|dao| dao.get_url_interest_vector(&url))?;
201            crate::trace!("classified: {url} {}", interest_count.summary());
202            interest_vector = interest_vector + interest_count;
203        }
204        Ok(interest_vector)
205    }
206
207    /// Get the user's interest vector directly.
208    ///
209    /// This runs after [Self::ingest].  It returns the interest vector directly so that the
210    /// consumer can show it in an `about:` page.
211    pub fn user_interest_vector(&self) -> Result<InterestVector> {
212        self.db.read(|dao| dao.get_frecency_user_interest_vector())
213    }
214
215    /// Initializes probability distributions for any uninitialized items (arms) within a bandit model.
216    ///
217    /// This method takes a `bandit` identifier and a list of `arms` (items) and ensures that each arm
218    /// in the list has an initialized probability distribution in the database. For each arm, if the
219    /// probability distribution does not already exist, it will be created, using Beta(1,1) as default,
220    /// which represents uniform distribution.
221    pub fn bandit_init(&self, bandit: String, arms: &[String]) -> Result<()> {
222        self.db.read_write(|dao| {
223            for arm in arms {
224                dao.initialize_multi_armed_bandit(&bandit, arm)?;
225            }
226            Ok(())
227        })?;
228
229        Ok(())
230    }
231
232    /// Selects the optimal item (arm) to display to the user based on a multi-armed bandit model.
233    ///
234    /// This method takes in a `bandit` identifier and a list of possible `arms` (items) and uses a
235    /// Thompson sampling approach to select the arm with the highest probability of success.
236    /// For each arm, it retrieves the Beta distribution parameters (alpha and beta) from the
237    /// database, creates a Beta distribution, and samples from it to estimate the arm's probability
238    /// of success. The arm with the highest sampled probability is selected and returned.
239    pub fn bandit_select(&self, bandit: String, arms: &[String]) -> Result<String> {
240        let mut cache = self.cache.lock();
241        let mut best_sample = f64::MIN;
242        let mut selected_arm = String::new();
243
244        for arm in arms {
245            let (alpha, beta) = cache.get_beta_distribution(&bandit, arm, &self.db)?;
246            // this creates a Beta distribution for an alpha & beta pair
247            let beta_dist = Beta::new(alpha as f64, beta as f64)
248                .expect("computing betas dist unexpectedly failed");
249
250            // Sample from the Beta distribution
251            let sampled_prob = beta_dist.sample(&mut rand::thread_rng());
252
253            if sampled_prob > best_sample {
254                best_sample = sampled_prob;
255                selected_arm.clone_from(arm);
256            }
257        }
258
259        Ok(selected_arm)
260    }
261
262    /// Updates the bandit model's arm data based on user interaction (selection or non-selection).
263    ///
264    /// This method takes in a `bandit` identifier, an `arm` identifier, and a `selected` flag.
265    /// If `selected` is true, it updates the model to reflect a successful selection of the arm,
266    /// reinforcing its positive reward probability. If `selected` is false, it updates the
267    /// beta (failure) distribution of the arm, reflecting a lack of selection and reinforcing
268    /// its likelihood of a negative outcome.
269    pub fn bandit_update(&self, bandit: String, arm: String, selected: bool) -> Result<()> {
270        let mut cache = self.cache.lock();
271
272        cache.clear(&bandit, &arm);
273
274        self.db
275            .read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, selected))?;
276
277        Ok(())
278    }
279
280    /// Retrieves the data for a specific bandit and arm.
281    pub fn get_bandit_data(&self, bandit: String, arm: String) -> Result<BanditData> {
282        let bandit_data = self
283            .db
284            .read(|dao| dao.retrieve_bandit_data(&bandit, &arm))?;
285
286        Ok(bandit_data)
287    }
288
289    pub fn ensure_interest_data_populated(&self) -> Result<()> {
290        ingest::ensure_interest_data_populated(&self.db, &self.client)
291    }
292}
293
/// In-memory cache of Beta distribution parameters used for Thompson sampling.
///
/// Avoids a database read per arm on every `bandit_select` call; entries are
/// invalidated by `clear` whenever an arm's data is updated.
#[derive(Default)]
pub struct BanditCache {
    // Maps a (bandit, arm) pair to its (alpha, beta) parameters.
    cache: HashMap<(String, String), (u64, u64)>,
}
298
299impl BanditCache {
300    /// Creates a new, empty `BanditCache`.
301    ///
302    /// The cache is initialized as an empty `HashMap` and is used to store
303    /// precomputed Beta distribution parameters for faster access during
304    /// Thompson Sampling operations.
305    pub fn new() -> Self {
306        Self::default()
307    }
308
309    /// Retrieves the Beta distribution parameters for a given bandit and arm.
310    ///
311    /// If the parameters for the specified `bandit` and `arm` are already cached,
312    /// they are returned directly. Otherwise, the parameters are fetched from
313    /// the database, added to the cache, and then returned.
314    pub fn get_beta_distribution(
315        &mut self,
316        bandit: &str,
317        arm: &str,
318        db: &RelevancyDb,
319    ) -> Result<(u64, u64)> {
320        let key = (bandit.to_string(), arm.to_string());
321
322        // Check if the distribution is already cached
323        if let Some(&params) = self.cache.get(&key) {
324            return Ok(params);
325        }
326
327        let params = db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(bandit, arm))?;
328
329        // Cache the retrieved parameters for future use
330        self.cache.insert(key, params);
331
332        Ok(params)
333    }
334
335    /// Clears the cached Beta distribution parameters for a given bandit and arm.
336    ///
337    /// This removes the cached values for the specified `bandit` and `arm` from the cache.
338    /// Use this method if the cached parameters are no longer valid or need to be refreshed.
339    pub fn clear(&mut self, bandit: &str, arm: &str) {
340        let key = (bandit.to_string(), arm.to_string());
341
342        self.cache.remove(&key);
343    }
344}
345
/// Interest metrics that we want to send to Glean as part of the validation process.  These contain
/// the cosine similarity when comparing the user's interest against various interest vectors that
/// consumers may use.
///
/// Cosine similarity was chosen because it seems easy to calculate.  This was then matched against
/// some semi-plausible real-world interest vectors that consumers might use.  This is all up for
/// debate and we may decide to switch to some other metrics.
///
/// Similarity values are transformed to integers by multiplying the floating point value by 1000 and
/// rounding.  This is to make them compatible with Glean's distribution metrics.
#[derive(uniffi::Record)]
pub struct InterestMetrics {
    /// Similarity between the user's interest vector and an interest vector where the element for
    /// the user's top interest is copied, but all other interests are set to zero.  This measures
    /// the highest possible similarity with consumers that used interest vectors with a single
    /// interest set.
    pub top_single_interest_similarity: u32,
    /// The same as before, but the top 2 interests are copied. This measures the highest possible
    /// similarity with consumers that used interest vectors with two interests (note: this means
    /// they would need to choose the user's top two interests and have the exact same proportion
    /// between them as the user).
    pub top_2interest_similarity: u32,
    /// The same as before, but the top 3 interests are copied.
    pub top_3interest_similarity: u32,
}
371
#[cfg(test)]
mod test {
    use crate::url_hash::hash_url;

    use super::*;
    use crate::rs::test::NullRelavancyRemoteSettingsClient;
    use rand::Rng;
    use std::collections::HashMap;

    // Fixture URLs paired with the interest each should be classified under.
    fn make_fixture() -> Vec<(String, Interest)> {
        vec![
            ("https://food.com/".to_string(), Interest::Food),
            ("https://hello.com".to_string(), Interest::Inconclusive),
            ("https://pasta.com".to_string(), Interest::Food),
            ("https://dog.com".to_string(), Interest::Animals),
        ]
    }

    // Interest vector expected after ingesting every fixture URL: one count
    // per URL, grouped by the interests assigned in `make_fixture`.
    fn expected_interest_vector() -> InterestVector {
        InterestVector {
            inconclusive: 1,
            animals: 1,
            food: 2,
            ..InterestVector::default()
        }
    }

    // Build a store backed by a per-test shared in-memory database (keyed by
    // `test_id` so tests don't collide) and pre-populate it with the fixture
    // URL -> interest mappings.
    fn setup_store(
        test_id: &'static str,
    ) -> RelevancyStoreInner<NullRelavancyRemoteSettingsClient> {
        let relevancy_store = RelevancyStoreInner::new(
            format!("file:test_{test_id}_data?mode=memory&cache=shared"),
            NullRelavancyRemoteSettingsClient,
        );
        relevancy_store
            .db
            .read_write(|dao| {
                for (url, interest) in make_fixture() {
                    dao.add_url_interest(hash_url(&url).unwrap(), interest)?;
                }
                Ok(())
            })
            .expect("Insert should succeed");

        relevancy_store
    }

    #[test]
    fn test_ingest() {
        let relevancy_store = setup_store("ingest");
        let (top_urls, _): (Vec<String>, Vec<Interest>) = make_fixture().into_iter().unzip();

        // Ingest should both store and return the classified interest vector.
        assert_eq!(
            relevancy_store.ingest(top_urls).unwrap(),
            expected_interest_vector()
        );
    }

    #[test]
    fn test_get_user_interest_vector() {
        let relevancy_store = setup_store("get_user_interest_vector");
        let (top_urls, _): (Vec<String>, Vec<Interest>) = make_fixture().into_iter().unzip();

        relevancy_store
            .ingest(top_urls)
            .expect("Ingest should succeed");

        // The vector persisted by `ingest` should round-trip through the DB.
        assert_eq!(
            relevancy_store.user_interest_vector().unwrap(),
            expected_interest_vector()
        );
    }

    // Statistical test: with enough rounds, Thompson sampling should converge
    // on the arm with the highest simulated click-through rate.
    #[test]
    fn test_thompson_sampling_convergence() {
        let relevancy_store = setup_store("thompson_sampling_convergence");

        // Ground-truth CTRs used to simulate user clicks for each arm.
        let arms_to_ctr_map: HashMap<String, f64> = [
            ("wiki".to_string(), 0.1),        // 10% CTR
            ("geolocation".to_string(), 0.3), // 30% CTR
            ("weather".to_string(), 0.8),     // 80% CTR
        ]
        .into_iter()
        .collect();

        let arm_names: Vec<String> = arms_to_ctr_map.keys().cloned().collect();

        let bandit = "provider".to_string();

        // initialize bandit
        relevancy_store
            .bandit_init(bandit.clone(), &arm_names)
            .unwrap();

        let mut rng = rand::thread_rng();

        // Create a HashMap to map arm names to their selection counts
        let mut selection_counts: HashMap<String, usize> =
            arm_names.iter().map(|name| (name.clone(), 0)).collect();

        // Simulate 1000 rounds of Thompson Sampling
        for _ in 0..1000 {
            // Use Thompson Sampling to select an arm
            let selected_arm_name = relevancy_store
                .bandit_select(bandit.clone(), &arm_names)
                .expect("Failed to select arm");

            // increase the selection count for the selected arm
            *selection_counts.get_mut(&selected_arm_name).unwrap() += 1;

            // get the true CTR for the selected arm
            let true_ctr = &arms_to_ctr_map[&selected_arm_name];

            // simulate a click or no-click based on the true CTR
            let clicked = rng.gen_bool(*true_ctr);

            // update beta distribution for arm based on click/no click
            relevancy_store
                .bandit_update(bandit.clone(), selected_arm_name, clicked)
                .expect("Failed to update beta distribution for arm");
        }

        //retrieve arm with maximum selection count
        let most_selected_arm_name = selection_counts
            .iter()
            .max_by_key(|(_, count)| *count)
            .unwrap()
            .0;

        // NOTE: probabilistic assertion — with 1000 rounds and an 80% vs 30%
        // CTR gap, a different winner should be vanishingly unlikely.
        assert_eq!(
            most_selected_arm_name, "weather",
            "Thompson Sampling did not favor the best-performing arm"
        );
    }

    #[test]
    fn test_get_bandit_data() {
        let relevancy_store = setup_store("get_bandit_data");

        let bandit = "provider".to_string();
        let arm = "wiki".to_string();

        // initialize bandit
        relevancy_store
            .bandit_init(
                "provider".to_string(),
                &["weather".to_string(), "fakespot".to_string(), arm.clone()],
            )
            .unwrap();

        // update beta distribution for arm based on click/no click
        relevancy_store
            .bandit_update(bandit.clone(), arm.clone(), true)
            .expect("Failed to update beta distribution for arm");

        relevancy_store
            .bandit_update(bandit.clone(), arm.clone(), true)
            .expect("Failed to update beta distribution for arm");

        let bandit_data = relevancy_store
            .get_bandit_data(bandit.clone(), arm.clone())
            .unwrap();

        // Two selected updates: 2 impressions, 2 clicks; alpha grows from the
        // Beta(1,1) prior to 3 while beta stays at 1.
        let expected_bandit_data = BanditData {
            bandit: bandit.clone(),
            arm: arm.clone(),
            impressions: 2,
            clicks: 2,
            alpha: 3,
            beta: 1,
        };

        assert_eq!(bandit_data, expected_bandit_data);
    }
}