push/internal/communications/rate_limiter.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4
5use crate::error::{info, warn};
6use crate::internal::storage::Storage;
7use std::{
8    str::FromStr,
9    time::{SystemTime, UNIX_EPOCH},
10};
11
/// A rate limiter whose state is persisted in the database.
///
/// Implementation notes: this stores the timestamp at which the current
/// rate-limiting window started and the number of times
/// [`PersistedRateLimiter::check`] has been called since then. A window lasts
/// `periodic_interval` seconds; the stored timestamp only changes when a new
/// window is started, not on every call.
#[derive(Debug)]
pub struct PersistedRateLimiter {
    /// Name of the limited operation; embedded in the persisted storage keys.
    op_name: String,
    /// Length of one rate-limiting window, in seconds.
    periodic_interval: u64,
    /// Maximum number of calls allowed within one window.
    max_requests_in_interval: u16,
}
20
21impl PersistedRateLimiter {
22    pub fn new(op_name: &str, periodic_interval: u64, max_requests_in_interval: u16) -> Self {
23        Self {
24            op_name: op_name.to_owned(),
25            periodic_interval,
26            max_requests_in_interval,
27        }
28    }
29
30    pub fn check<S: Storage>(&self, store: &S) -> bool {
31        let (mut timestamp, mut count) = self.impl_get_counters(store);
32
33        let now = now_secs();
34        if (now - timestamp) >= self.periodic_interval {
35            info!(
36                "Resetting. now({}) - {} < {} for {}.",
37                now, timestamp, self.periodic_interval, &self.op_name
38            );
39            count = 0;
40            timestamp = now;
41        } else {
42            info!(
43                "No need to reset inner timestamp and count for {}.",
44                &self.op_name
45            )
46        }
47
48        count += 1;
49        self.impl_persist_counters(store, timestamp, count);
50
51        // within interval counter
52        if count > self.max_requests_in_interval {
53            info!(
54                "Not allowed: count({}) > {} for {}.",
55                count, self.max_requests_in_interval, &self.op_name
56            );
57            return false;
58        }
59
60        info!("Allowed to pass through for {}!", &self.op_name);
61
62        true
63    }
64
65    pub fn reset<S: Storage>(&self, store: &S) {
66        self.impl_persist_counters(store, now_secs(), 0)
67    }
68
69    fn db_meta_keys(&self) -> (String, String) {
70        (
71            format!("ratelimit_{}_timestamp", &self.op_name),
72            format!("ratelimit_{}_count", &self.op_name),
73        )
74    }
75
76    fn impl_get_counters<S: Storage>(&self, store: &S) -> (u64, u16) {
77        let (timestamp_key, count_key) = self.db_meta_keys();
78        (
79            Self::get_meta_integer(store, &timestamp_key),
80            Self::get_meta_integer(store, &count_key),
81        )
82    }
83
84    #[cfg(test)]
85    pub(crate) fn get_counters<S: Storage>(&self, store: &S) -> (u64, u16) {
86        self.impl_get_counters(store)
87    }
88
89    fn get_meta_integer<S: Storage, T: FromStr + Default>(store: &S, key: &str) -> T {
90        store
91            .get_meta(key)
92            .ok()
93            .flatten()
94            .map(|s| s.parse())
95            .transpose()
96            .ok()
97            .flatten()
98            .unwrap_or_default()
99    }
100
101    fn impl_persist_counters<S: Storage>(&self, store: &S, timestamp: u64, count: u16) {
102        let (timestamp_key, count_key) = self.db_meta_keys();
103        let r1 = store.set_meta(&timestamp_key, &timestamp.to_string());
104        let r2 = store.set_meta(&count_key, &count.to_string());
105        if r1.is_err() || r2.is_err() {
106            warn!("Error updating persisted counters for {}.", &self.op_name);
107        }
108    }
109
110    #[cfg(test)]
111    pub(crate) fn persist_counters<S: Storage>(&self, store: &S, timestamp: u64, count: u16) {
112        self.impl_persist_counters(store, timestamp, count)
113    }
114}
115
/// Current wall-clock time, in whole seconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock reports a time earlier than the Unix epoch.
fn now_secs() -> u64 {
    let current = SystemTime::now();
    let since_epoch = current
        .duration_since(UNIX_EPOCH)
        .expect("Current date before unix epoch.");
    since_epoch.as_secs()
}
122
#[cfg(test)]
mod test {
    use super::*;
    use crate::error::Result;
    use crate::Store;

    const PERIODIC_INTERVAL: u64 = 24 * 3600;
    // Longer than `PERIODIC_INTERVAL`, so a window started this long ago has expired.
    const VERIFY_NOW_INTERVAL: u64 = PERIODIC_INTERVAL + 3600;
    const MAX_REQUESTS: u16 = 500;

    fn new_limiter() -> PersistedRateLimiter {
        PersistedRateLimiter::new("op1", PERIODIC_INTERVAL, MAX_REQUESTS)
    }

    #[test]
    fn test_persisted_rate_limiter_store_counters_roundtrip() -> Result<()> {
        let limiter = new_limiter();
        let store = Store::open_in_memory()?;
        limiter.impl_persist_counters(&store, 123, 321);
        assert_eq!((123, 321), limiter.impl_get_counters(&store));
        Ok(())
    }

    #[test]
    fn test_persisted_rate_limiter_missing_counters_default_to_zero() -> Result<()> {
        let limiter = new_limiter();
        let store = Store::open_in_memory()?;
        assert_eq!((0, 0), limiter.impl_get_counters(&store));
        Ok(())
    }

    #[test]
    fn test_persisted_rate_limiter_after_interval_counter_resets() -> Result<()> {
        let limiter = new_limiter();
        let store = Store::open_in_memory()?;
        limiter.impl_persist_counters(&store, now_secs() - VERIFY_NOW_INTERVAL, 50);
        assert!(limiter.check(&store));
        assert_eq!(1, limiter.impl_get_counters(&store).1);
        Ok(())
    }

    #[test]
    fn test_persisted_rate_limiter_within_interval_keeps_window_start() -> Result<()> {
        let limiter = new_limiter();
        let store = Store::open_in_memory()?;
        let window_start = now_secs() - 60;
        limiter.impl_persist_counters(&store, window_start, 10);
        assert!(limiter.check(&store));
        // The stored timestamp is the window start, not the latest call.
        assert_eq!((window_start, 11), limiter.impl_get_counters(&store));
        Ok(())
    }

    #[test]
    fn test_persisted_rate_limiter_allows_exactly_max_requests() -> Result<()> {
        let limiter = new_limiter();
        let store = Store::open_in_memory()?;
        limiter.impl_persist_counters(&store, now_secs(), MAX_REQUESTS - 1);
        // The MAX_REQUESTS-th call in the window is still allowed...
        assert!(limiter.check(&store));
        assert_eq!(MAX_REQUESTS, limiter.impl_get_counters(&store).1);
        // ...but the one after it is not.
        assert!(!limiter.check(&store));
        assert_eq!(MAX_REQUESTS + 1, limiter.impl_get_counters(&store).1);
        Ok(())
    }

    #[test]
    fn test_persisted_rate_limiter_false_above_rate_limit() -> Result<()> {
        let limiter = new_limiter();
        let store = Store::open_in_memory()?;
        limiter.impl_persist_counters(&store, now_secs(), MAX_REQUESTS + 1);
        assert!(!limiter.check(&store));
        assert_eq!(MAX_REQUESTS + 2, limiter.impl_get_counters(&store).1);
        Ok(())
    }

    #[test]
    fn test_persisted_rate_limiter_reset_above_rate_limit_and_interval() -> Result<()> {
        let limiter = new_limiter();
        let store = Store::open_in_memory()?;
        limiter.impl_persist_counters(
            &store,
            now_secs() - VERIFY_NOW_INTERVAL,
            MAX_REQUESTS + 1,
        );
        assert!(limiter.check(&store));
        assert_eq!(1, limiter.impl_get_counters(&store).1);
        Ok(())
    }

    #[test]
    fn test_persisted_rate_limiter_explicit_reset_clears_count() -> Result<()> {
        let limiter = new_limiter();
        let store = Store::open_in_memory()?;
        limiter.impl_persist_counters(&store, now_secs() - 60, MAX_REQUESTS + 1);
        limiter.reset(&store);
        assert_eq!(0, limiter.impl_get_counters(&store).1);
        assert!(limiter.check(&store));
        Ok(())
    }

    // An empty store has a window start of 0, so the first check starts a new window.
    #[test]
    fn test_persisted_rate_limiter_empty_store_allows() -> Result<()> {
        let limiter = new_limiter();
        let store = Store::open_in_memory()?;
        assert!(limiter.check(&store));
        assert_eq!(1, limiter.impl_get_counters(&store).1);
        Ok(())
    }
}