relevancy/lib.rs
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
//! Proposed API for the relevancy component (validation phase)
//!
//! The goal here is to allow us to validate that we can reliably detect user interests from
//! history data, without spending too much time building the API out. There's some hand-waving
//! towards how we would use this data to rank search results, but we don't need to come to a final
//! decision on that yet.
mod db;
mod error;
mod ingest;
mod interest;
mod ranker;
mod rs;
mod schema;
pub mod url_hash;
use rand_distr::{Beta, Distribution};
pub use db::RelevancyDb;
pub use error::{ApiResult, Error, RelevancyApiError, Result};
pub use interest::{Interest, InterestVector};
use parking_lot::Mutex;
pub use ranker::score;
use error_support::handle_error;
use db::BanditData;
use std::collections::HashMap;
uniffi::setup_scaffolding!();
#[derive(uniffi::Object)]
pub struct RelevancyStore {
db: RelevancyDb,
cache: Mutex<BanditCache>,
}
/// Top-level API for the Relevancy component
// Impl block to be exported via `UniFFI`.
#[uniffi::export]
impl RelevancyStore {
/// Construct a new RelevancyStore
///
/// This is non-blocking since databases and other resources are lazily opened.
#[uniffi::constructor]
pub fn new(db_path: String) -> Self {
Self {
db: RelevancyDb::new(db_path),
cache: Mutex::new(BanditCache::new()),
}
}
/// Close any open resources (for example databases)
///
/// Calling `close` will interrupt any in-progress queries on other threads.
pub fn close(&self) {
self.db.close()
}
/// Interrupt any current database queries
pub fn interrupt(&self) {
self.db.interrupt()
}
/// Ingest top URLs to build the user's interest vector.
///
/// Consumer should pass a list of the user's top URLs by frecency to this method. It will
/// then:
///
/// - Download the URL interest data from remote settings. Eventually this should be cached /
/// stored in the database, but for now it would be fine to download fresh data each time.
/// - Match the user's top URLs against the interest data to build up their interest vector.
/// - Store the user's interest vector in the database.
///
/// This method may execute for a long time and should only be called from a worker thread.
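///
/// # Example
///
/// A minimal usage sketch; the URLs are illustrative:
///
/// ```ignore
/// let store = RelevancyStore::new("relevancy.db".to_string());
/// let interest_vec = store.ingest(vec![
///     "https://cooking.example.com/".to_string(),
///     "https://dogs.example.com/".to_string(),
/// ])?;
/// ```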
#[handle_error(Error)]
pub fn ingest(&self, top_urls_by_frecency: Vec<String>) -> ApiResult<InterestVector> {
ingest::ensure_interest_data_populated(&self.db)?;
let interest_vec = self.classify(top_urls_by_frecency)?;
self.db
.read_write(|dao| dao.update_frecency_user_interest_vector(&interest_vec))?;
Ok(interest_vec)
}
/// Calculate metrics for the validation phase
///
/// This runs after [Self::ingest]. It takes the interest vector that ingest created and
/// calculates a set of metrics that we can report to glean.
#[handle_error(Error)]
pub fn calculate_metrics(&self) -> ApiResult<InterestMetrics> {
todo!()
}
/// Get the user's interest vector directly.
///
/// This runs after [Self::ingest]. It returns the interest vector directly so that the
/// consumer can show it in an `about:` page.
#[handle_error(Error)]
pub fn user_interest_vector(&self) -> ApiResult<InterestVector> {
self.db.read(|dao| dao.get_frecency_user_interest_vector())
}
/// Initializes probability distributions for any uninitialized items (arms) within a bandit model.
///
/// This method takes a `bandit` identifier and a list of `arms` (items) and ensures that each arm
/// in the list has an initialized probability distribution in the database. For each arm, if the
/// probability distribution does not already exist, it will be created using Beta(1, 1) as the
/// default, which represents a uniform distribution.
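///
/// # Example
///
/// A minimal usage sketch; the bandit and arm identifiers are illustrative:
///
/// ```ignore
/// // Any arm without existing data starts from the uniform Beta(1, 1) prior.
/// store.bandit_init(
///     "provider".to_string(),
///     &["wiki".to_string(), "weather".to_string()],
/// )?;
/// ```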
#[handle_error(Error)]
pub fn bandit_init(&self, bandit: String, arms: &[String]) -> ApiResult<()> {
self.db.read_write(|dao| {
for arm in arms {
dao.initialize_multi_armed_bandit(&bandit, arm)?;
}
Ok(())
})?;
Ok(())
}
/// Selects the optimal item (arm) to display to the user based on a multi-armed bandit model.
///
/// This method takes in a `bandit` identifier and a list of possible `arms` (items) and uses a
/// Thompson sampling approach to select the arm with the highest probability of success.
/// For each arm, it retrieves the Beta distribution parameters (alpha and beta) from the
/// database, creates a Beta distribution, and samples from it to estimate the arm's probability
/// of success. The arm with the highest sampled probability is selected and returned.
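///
/// # Example
///
/// A sketch of one selection round, assuming the bandit was initialized as above:
///
/// ```ignore
/// // Draw one sample from each arm's Beta posterior and keep the best arm.
/// let arm = store.bandit_select(
///     "provider".to_string(),
///     &["wiki".to_string(), "weather".to_string()],
/// )?;
/// ```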
#[handle_error(Error)]
pub fn bandit_select(&self, bandit: String, arms: &[String]) -> ApiResult<String> {
let mut cache = self.cache.lock();
let mut best_sample = f64::MIN;
let mut selected_arm = String::new();
for arm in arms {
let (alpha, beta) = cache.get_beta_distribution(&bandit, arm, &self.db)?;
// Build a Beta distribution from the arm's (alpha, beta) parameters
let beta_dist = Beta::new(alpha as f64, beta as f64)
.expect("computing beta distribution unexpectedly failed");
// Sample from the Beta distribution
let sampled_prob = beta_dist.sample(&mut rand::thread_rng());
if sampled_prob > best_sample {
best_sample = sampled_prob;
selected_arm.clone_from(arm);
}
}
Ok(selected_arm)
}
/// Updates the bandit model's arm data based on user interaction (selection or non-selection).
///
/// This method takes in a `bandit` identifier, an `arm` identifier, and a `selected` flag.
/// If `selected` is true, it updates the model to reflect a successful selection of the arm,
/// incrementing its alpha (success) parameter. If `selected` is false, it increments the
/// arm's beta (failure) parameter, reflecting a lack of selection and reinforcing its
/// likelihood of a negative outcome.
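///
/// # Example
///
/// A sketch of recording the outcome of one impression (identifiers are illustrative):
///
/// ```ignore
/// // The user clicked the "weather" suggestion: its alpha (success) count grows.
/// store.bandit_update("provider".to_string(), "weather".to_string(), true)?;
/// // The user ignored it on the next impression: its beta (failure) count grows.
/// store.bandit_update("provider".to_string(), "weather".to_string(), false)?;
/// ```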
#[handle_error(Error)]
pub fn bandit_update(&self, bandit: String, arm: String, selected: bool) -> ApiResult<()> {
let mut cache = self.cache.lock();
cache.clear(&bandit, &arm);
self.db
.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, selected))?;
Ok(())
}
/// Retrieves the data for a specific bandit and arm.
#[handle_error(Error)]
pub fn get_bandit_data(&self, bandit: String, arm: String) -> ApiResult<BanditData> {
let bandit_data = self
.db
.read(|dao| dao.retrieve_bandit_data(&bandit, &arm))?;
Ok(bandit_data)
}
}
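/// In-memory cache of Beta distribution parameters, keyed by `(bandit, arm)` pairs,
/// so that repeated Thompson sampling draws avoid a database read per arm.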
#[derive(Default)]
pub struct BanditCache {
cache: HashMap<(String, String), (u64, u64)>,
}
impl BanditCache {
/// Creates a new, empty `BanditCache`.
///
/// The cache is initialized as an empty `HashMap` and is used to store
/// precomputed Beta distribution parameters for faster access during
/// Thompson Sampling operations.
pub fn new() -> Self {
Self::default()
}
/// Retrieves the Beta distribution parameters for a given bandit and arm.
///
/// If the parameters for the specified `bandit` and `arm` are already cached,
/// they are returned directly. Otherwise, the parameters are fetched from
/// the database, added to the cache, and then returned.
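///
/// # Example
///
/// A cache hit/miss sketch; `db` is assumed to be an open [RelevancyDb]:
///
/// ```ignore
/// let mut cache = BanditCache::new();
/// // The first call reads (alpha, beta) from the database and caches it.
/// let params = cache.get_beta_distribution("provider", "wiki", &db)?;
/// // Later calls for the same key are served from the cache.
/// assert_eq!(params, cache.get_beta_distribution("provider", "wiki", &db)?);
/// ```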
pub fn get_beta_distribution(
&mut self,
bandit: &str,
arm: &str,
db: &RelevancyDb,
) -> Result<(u64, u64)> {
let key = (bandit.to_string(), arm.to_string());
// Check if the distribution is already cached
if let Some(&params) = self.cache.get(&key) {
return Ok(params);
}
let params = db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(bandit, arm))?;
// Cache the retrieved parameters for future use
self.cache.insert(key, params);
Ok(params)
}
/// Clears the cached Beta distribution parameters for a given bandit and arm.
///
/// This removes the cached values for the specified `bandit` and `arm` from the cache.
/// Use this method if the cached parameters are no longer valid or need to be refreshed.
pub fn clear(&mut self, bandit: &str, arm: &str) {
let key = (bandit.to_string(), arm.to_string());
self.cache.remove(&key);
}
}
impl RelevancyStore {
/// Download the interest data from remote settings if needed
#[handle_error(Error)]
pub fn ensure_interest_data_populated(&self) -> ApiResult<()> {
ingest::ensure_interest_data_populated(&self.db)?;
Ok(())
}
pub fn classify(&self, top_urls_by_frecency: Vec<String>) -> Result<InterestVector> {
let mut interest_vector = InterestVector::default();
for url in top_urls_by_frecency {
let interest_count = self.db.read(|dao| dao.get_url_interest_vector(&url))?;
log::trace!("classified: {url} {}", interest_count.summary());
interest_vector = interest_vector + interest_count;
}
Ok(interest_vector)
}
}
/// Interest metrics that we want to send to Glean as part of the validation process. These contain
/// the cosine similarity when comparing the user's interest against various interest vectors that
/// consumers may use.
///
/// Cosine similarity was chosen because it seems easy to calculate. This was then matched against
/// some semi-plausible real-world interest vectors that consumers might use. This is all up for
/// debate and we may decide to switch to some other metrics.
///
/// Similarity values are transformed to integers by multiplying the floating point value by 1000 and
/// rounding. This is to make them compatible with Glean's distribution metrics.
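///
/// # Example
///
/// A sketch of the scaling applied to a raw cosine similarity value:
///
/// ```ignore
/// let similarity: f64 = 0.8732;
/// let reported = (similarity * 1000.0).round() as u32; // 873
/// ```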
#[derive(uniffi::Record)]
pub struct InterestMetrics {
/// Similarity between the user's interest vector and an interest vector where the element for
/// the user's top interest is copied, but all other interests are set to zero. This measures
/// the highest possible similarity with consumers that used interest vectors with a single
/// interest set.
pub top_single_interest_similarity: u32,
/// The same as before, but the top 2 interests are copied. This measures the highest possible
/// similarity with consumers that used interest vectors with two interests (note: this means
/// they would need to choose the user's top two interests and have the exact same proportion
/// between them as the user).
pub top_2interest_similarity: u32,
/// The same as before, but the top 3 interests are copied.
pub top_3interest_similarity: u32,
}
#[cfg(test)]
mod test {
use crate::url_hash::hash_url;
use super::*;
use rand::Rng;
use std::collections::HashMap;
fn make_fixture() -> Vec<(String, Interest)> {
vec![
("https://food.com/".to_string(), Interest::Food),
("https://hello.com".to_string(), Interest::Inconclusive),
("https://pasta.com".to_string(), Interest::Food),
("https://dog.com".to_string(), Interest::Animals),
]
}
fn expected_interest_vector() -> InterestVector {
InterestVector {
inconclusive: 1,
animals: 1,
food: 2,
..InterestVector::default()
}
}
fn setup_store(test_id: &'static str) -> RelevancyStore {
let relevancy_store =
RelevancyStore::new(format!("file:test_{test_id}_data?mode=memory&cache=shared"));
relevancy_store
.db
.read_write(|dao| {
for (url, interest) in make_fixture() {
dao.add_url_interest(hash_url(&url).unwrap(), interest)?;
}
Ok(())
})
.expect("Insert should succeed");
relevancy_store
}
#[test]
fn test_ingest() {
let relevancy_store = setup_store("ingest");
let (top_urls, _): (Vec<String>, Vec<Interest>) = make_fixture().into_iter().unzip();
assert_eq!(
relevancy_store.ingest(top_urls).unwrap(),
expected_interest_vector()
);
}
#[test]
fn test_get_user_interest_vector() {
let relevancy_store = setup_store("get_user_interest_vector");
let (top_urls, _): (Vec<String>, Vec<Interest>) = make_fixture().into_iter().unzip();
relevancy_store
.ingest(top_urls)
.expect("Ingest should succeed");
assert_eq!(
relevancy_store.user_interest_vector().unwrap(),
expected_interest_vector()
);
}
#[test]
fn test_thompson_sampling_convergence() {
let relevancy_store = setup_store("thompson_sampling_convergence");
let arms_to_ctr_map: HashMap<String, f64> = [
("wiki".to_string(), 0.1), // 10% CTR
("geolocation".to_string(), 0.3), // 30% CTR
("weather".to_string(), 0.8), // 80% CTR
]
.into_iter()
.collect();
let arm_names: Vec<String> = arms_to_ctr_map.keys().cloned().collect();
let bandit = "provider".to_string();
// initialize bandit
relevancy_store
.bandit_init(bandit.clone(), &arm_names)
.unwrap();
let mut rng = rand::thread_rng();
// Create a HashMap to map arm names to their selection counts
let mut selection_counts: HashMap<String, usize> =
arm_names.iter().map(|name| (name.clone(), 0)).collect();
// Simulate 1000 rounds of Thompson Sampling
for _ in 0..1000 {
// Use Thompson Sampling to select an arm
let selected_arm_name = relevancy_store
.bandit_select(bandit.clone(), &arm_names)
.expect("Failed to select arm");
// increase the selection count for the selected arm
*selection_counts.get_mut(&selected_arm_name).unwrap() += 1;
// get the true CTR for the selected arm
let true_ctr = &arms_to_ctr_map[&selected_arm_name];
// simulate a click or no-click based on the true CTR
let clicked = rng.gen_bool(*true_ctr);
// update beta distribution for arm based on click/no click
relevancy_store
.bandit_update(bandit.clone(), selected_arm_name, clicked)
.expect("Failed to update beta distribution for arm");
}
// retrieve the arm with the maximum selection count
let most_selected_arm_name = selection_counts
.iter()
.max_by_key(|(_, count)| *count)
.unwrap()
.0;
assert_eq!(
most_selected_arm_name, "weather",
"Thompson Sampling did not favor the best-performing arm"
);
}
#[test]
fn test_get_bandit_data() {
let relevancy_store = setup_store("get_bandit_data");
let bandit = "provider".to_string();
let arm = "wiki".to_string();
// initialize bandit
relevancy_store
.bandit_init(
"provider".to_string(),
&["weather".to_string(), "fakespot".to_string(), arm.clone()],
)
.unwrap();
// update beta distribution for arm based on click/no click
relevancy_store
.bandit_update(bandit.clone(), arm.clone(), true)
.expect("Failed to update beta distribution for arm");
relevancy_store
.bandit_update(bandit.clone(), arm.clone(), true)
.expect("Failed to update beta distribution for arm");
let bandit_data = relevancy_store
.get_bandit_data(bandit.clone(), arm.clone())
.unwrap();
let expected_bandit_data = BanditData {
bandit: bandit.clone(),
arm: arm.clone(),
impressions: 2,
clicks: 2,
alpha: 3,
beta: 1,
};
assert_eq!(bandit_data, expected_bandit_data);
}
}