nimbus/sampling.rs
1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
4
5//! This module implements the sampling logic required to hash,
6//! randomize and pick branches using pre-set ratios.
7
8use crate::error::{NimbusError, Result};
9use sha2::{Digest, Sha256};
10
/// Number of bits of the SHA-256 digest that take part in sampling.
const HASH_BITS: u32 = 48;
/// Number of hexadecimal digits needed to write a `HASH_BITS`-bit value.
const HASH_LENGTH: u32 = {
    // Every hexadecimal digit encodes exactly four bits.
    const BITS_PER_HEX_DIGIT: u32 = 4;
    HASH_BITS / BITS_PER_HEX_DIGIT
};
13
14/// Sample by splitting the input space into a series of buckets, checking
15/// if the given input is in a range of buckets
16///
17/// The range to check is defined by a start point and length, and can wrap around
18/// the input space. For example, if there are 100 buckets, and we ask to check 50 buckets
19/// starting from bucket 70, then buckets 70-99 and 0-19 will be checked
20///
21/// # Arguments:
22///
23/// - `input` What will be hashed and matched against the range of the buckets
24/// - `start` the index of the bucket to start checking
25/// - `count` then number of buckets to check
26/// - `total` The total number of buckets to group inputs into
27///
28/// # Returns:
29///
30/// Returns true if the hash generated from the input belongs within the range
31/// otherwise false
32///
33/// # Errors:
34///
35/// Could error in the following cases (but not limited to)
36/// - An error occurred in the hashing process
37/// - an error occurred while checking if the hash belongs in the bucket
38pub(crate) fn bucket_sample<T: serde::Serialize>(
39 input: T,
40 start: u32,
41 count: u32,
42 total: u32,
43) -> Result<bool> {
44 let input_hash = hex::encode(truncated_hash(input)?);
45 let wrapped_start = start % total;
46 let end = wrapped_start + count;
47
48 Ok(if end > total {
49 is_hash_in_bucket(&input_hash, 0, end % total, total)?
50 || is_hash_in_bucket(&input_hash, wrapped_start, total, total)?
51 } else {
52 is_hash_in_bucket(&input_hash, wrapped_start, end, total)?
53 })
54}
55
56/// Sample over a list of ratios such that, over the input space, each
57/// ratio has a number of matches in correct proportion to the other ratios
58///
59/// # Arguments:
60/// - `input`: the input used in the sampling process
61/// - `ratios`: The list of ratios associated with each option
62///
63/// # Example:
64///
65/// Assuming the ratios: `[1, 2, 3, 4]`
66/// 10% of all inputs will return 0, 20% will return 1 and so on
67///
68/// # Returns
69/// Returns an index of the ratio that matched the input
70///
71/// # Errors
72/// Could return an error if the input couldn't be hashed
73pub(crate) fn ratio_sample<T: serde::Serialize>(input: T, ratios: &[u32]) -> Result<usize> {
74 if ratios.is_empty() {
75 return Err(NimbusError::EmptyRatiosError);
76 }
77 let input_hash = hex::encode(truncated_hash(input)?);
78 let ratio_total: u32 = ratios.iter().sum();
79 let mut sample_point = 0;
80 for (i, ratio) in ratios.iter().enumerate() {
81 sample_point += ratio;
82 if input_hash <= fraction_to_key(sample_point as f64 / ratio_total as f64)? {
83 return Ok(i);
84 }
85 }
86 Ok(ratios.len() - 1)
87}
88
89/// Provides a hash of `data`, truncated to the 6 most significant bytes
90/// For consistency with: https://searchfox.org/mozilla-central/source/toolkit/components/utils/Sampling.jsm#79
91/// # Arguments:
92/// - `data`: The data to be hashed
93///
94/// # Returns:
95/// Returns the 6 bytes associated with the SHA-256 of the data
96///
97/// # Errors:
98/// Would return an error if the hashing function fails to generate a hash
99/// that is larger than 6 bytes (Should never occur)
100pub(crate) fn truncated_hash<T: serde::Serialize>(data: T) -> Result<[u8; 6]> {
101 let mut hasher = Sha256::new();
102 let data_str = match serde_json::to_string(&data) {
103 Ok(v) => v,
104 Err(e) => {
105 return Err(NimbusError::JSONError(
106 "data_str = nimbus::sampling::truncated_hash::serde_json::to_string".into(),
107 e.to_string(),
108 ))
109 }
110 };
111 hasher.update(data_str.as_bytes());
112 Ok(hasher.finalize()[0..6].try_into()?)
113}
114
115/// Checks if a given hash (represented as a 6 byte hex string) fits within a bucket range
116///
117/// # Arguments:
118/// - `input_hash_num`: The hash as a 6 byte hex string (12 hex digits)
119/// - `min_bucket`: The minimum bucket number
120/// - `max_bucket`: The maximum bucket number
121/// - `bucket_count`: The number of buckets
122///
123/// # Returns
124/// Returns true if the has fits in the bucket range,
125/// otherwise false
126///
127/// # Errors:
128///
129/// Could return an error if bucket numbers are higher than the bucket count
130fn is_hash_in_bucket(
131 input_hash_num: &str,
132 min_bucket: u32,
133 max_bucket: u32,
134 bucket_count: u32,
135) -> Result<bool> {
136 let min_hash = fraction_to_key(min_bucket as f64 / bucket_count as f64)?;
137 let max_hash = fraction_to_key(max_bucket as f64 / bucket_count as f64)?;
138 Ok(min_hash.as_str() <= input_hash_num && input_hash_num < max_hash.as_str())
139}
140
141/// Maps from the range [0, 1] to [0, 2^48]
142///
143/// # Argument:
144/// - `fraction`: float in the range 0-1
145///
146/// # Returns
147/// returns a hex string representing the fraction multiplied to be within the
148/// [0, 2^48] range
149///
150/// # Errors
151/// returns an error if the fraction not within the 0-1 range
152fn fraction_to_key(fraction: f64) -> Result<String> {
153 if !(0.0..=1.0).contains(&fraction) {
154 return Err(NimbusError::InvalidFraction);
155 }
156 let multiplied = (fraction * (2u64.pow(HASH_BITS) - 1) as f64).floor();
157 let multiplied = format!("{:x}", multiplied as u64);
158 let padding = vec!['0'; HASH_LENGTH as usize - multiplied.len()];
159 let res = padding
160 .into_iter()
161 .chain(multiplied.chars())
162 .collect::<String>();
163 Ok(res)
164}