nimbus/stateful/
behavior.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
4
5use crate::{
6    error::{BehaviorError, NimbusError, Result},
7    stateful::persistence::{Database, StoreId},
8};
9use chrono::{DateTime, Datelike, Duration, TimeZone, Utc};
10use serde::{Deserialize, Serialize};
11use serde_json::{json, Value};
12use std::collections::vec_deque::Iter;
13use std::collections::{HashMap, VecDeque};
14use std::fmt;
15use std::hash::{Hash, Hasher};
16use std::str::FromStr;
17use std::sync::{Arc, Mutex};
18
19#[derive(Clone, Debug, Serialize, Deserialize)]
20pub enum Interval {
21    Minutes,
22    Hours,
23    Days,
24    Weeks,
25    Months,
26    Years,
27}
28
29impl Interval {
30    pub fn num_rotations(&self, then: DateTime<Utc>, now: DateTime<Utc>) -> Result<i32> {
31        let date_diff = now - then;
32        Ok(i32::try_from(match self {
33            Interval::Minutes => date_diff.num_minutes(),
34            Interval::Hours => date_diff.num_hours(),
35            Interval::Days => date_diff.num_days(),
36            Interval::Weeks => date_diff.num_weeks(),
37            Interval::Months => date_diff.num_days() / 28,
38            Interval::Years => date_diff.num_days() / 365,
39        })?)
40    }
41
42    pub fn to_duration(&self, count: i64) -> Duration {
43        match self {
44            Interval::Minutes => Duration::minutes(count),
45            Interval::Hours => Duration::hours(count),
46            Interval::Days => Duration::days(count),
47            Interval::Weeks => Duration::weeks(count),
48            Interval::Months => Duration::days(28 * count),
49            Interval::Years => Duration::days(365 * count),
50        }
51    }
52}
53
54impl fmt::Display for Interval {
55    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
56        fmt::Debug::fmt(self, f)
57    }
58}
59
60impl PartialEq for Interval {
61    fn eq(&self, other: &Self) -> bool {
62        self.to_string() == other.to_string()
63    }
64}
65
66impl Eq for Interval {}
67
68impl Hash for Interval {
69    fn hash<H: Hasher>(&self, state: &mut H) {
70        self.to_string().as_bytes().hash(state);
71    }
72}
73
74impl FromStr for Interval {
75    type Err = NimbusError;
76
77    fn from_str(input: &str) -> Result<Self> {
78        Ok(match input {
79            "Minutes" => Self::Minutes,
80            "Hours" => Self::Hours,
81            "Days" => Self::Days,
82            "Weeks" => Self::Weeks,
83            "Months" => Self::Months,
84            "Years" => Self::Years,
85            _ => {
86                return Err(NimbusError::BehaviorError(
87                    BehaviorError::IntervalParseError(input.to_string()),
88                ));
89            }
90        })
91    }
92}
93
94#[derive(Clone, Serialize, Deserialize, Debug)]
95pub struct IntervalConfig {
96    bucket_count: usize,
97    interval: Interval,
98}
99
100impl Default for IntervalConfig {
101    fn default() -> Self {
102        Self::new(7, Interval::Days)
103    }
104}
105
106impl IntervalConfig {
107    pub fn new(bucket_count: usize, interval: Interval) -> Self {
108        Self {
109            bucket_count,
110            interval,
111        }
112    }
113}
114
115#[derive(Clone, Serialize, Deserialize, Debug)]
116pub struct IntervalData {
117    pub(crate) buckets: VecDeque<u64>,
118    pub(crate) bucket_count: usize,
119    pub(crate) starting_instant: DateTime<Utc>,
120}
121
122impl Default for IntervalData {
123    fn default() -> Self {
124        Self::new(1)
125    }
126}
127
128impl IntervalData {
129    pub fn new(bucket_count: usize) -> Self {
130        let mut buckets = VecDeque::with_capacity(bucket_count);
131        buckets.push_front(0);
132        // Set the starting instant to Jan 1 00:00:00 in order to sync rotations
133        let starting_instant = Utc.from_utc_datetime(
134            &Utc::now()
135                .with_month(1)
136                .unwrap()
137                .with_day(1)
138                .unwrap()
139                .date_naive()
140                .and_hms_opt(0, 0, 0)
141                .unwrap(),
142        );
143        Self {
144            buckets,
145            bucket_count,
146            starting_instant,
147        }
148    }
149
150    pub fn increment(&mut self, count: u64) -> Result<()> {
151        self.increment_at(0, count)
152    }
153
154    pub fn increment_at(&mut self, index: usize, count: u64) -> Result<()> {
155        if index < self.bucket_count {
156            let buckets = &mut self.buckets;
157            match buckets.get_mut(index) {
158                Some(x) => *x += count,
159                None => {
160                    for _ in buckets.len()..index {
161                        buckets.push_back(0);
162                    }
163                    self.buckets.insert(index, count)
164                }
165            };
166        }
167        Ok(())
168    }
169
170    pub fn rotate(&mut self, num_rotations: i32) -> Result<()> {
171        let num_rotations = usize::min(self.bucket_count, num_rotations as usize);
172        if num_rotations + self.buckets.len() > self.bucket_count {
173            self.buckets.drain((self.bucket_count - num_rotations)..);
174        }
175        for _ in 1..=num_rotations {
176            self.buckets.push_front(0);
177        }
178        Ok(())
179    }
180}
181
182#[derive(Clone, Serialize, Deserialize, Debug)]
183pub struct SingleIntervalCounter {
184    pub data: IntervalData,
185    pub config: IntervalConfig,
186}
187
188impl SingleIntervalCounter {
189    pub fn new(config: IntervalConfig) -> Self {
190        Self {
191            data: IntervalData::new(config.bucket_count),
192            config,
193        }
194    }
195
196    pub fn from_config(bucket_count: usize, interval: Interval) -> Self {
197        let config = IntervalConfig {
198            bucket_count,
199            interval,
200        };
201        Self::new(config)
202    }
203
204    pub fn increment_then(&mut self, then: DateTime<Utc>, count: u64) -> Result<()> {
205        use std::cmp::Ordering;
206        let now = self.data.starting_instant;
207        let rotations = self.config.interval.num_rotations(then, now)?;
208        match rotations.cmp(&0) {
209            Ordering::Less => {
210                /* We can't increment in the future */
211                return Err(NimbusError::BehaviorError(BehaviorError::InvalidState(
212                    "Cannot increment events far into the future".to_string(),
213                )));
214            }
215            Ordering::Equal => {
216                if now < then {
217                    self.data.increment_at(0, count)?;
218                } else {
219                    self.data.increment_at(1, count)?;
220                }
221            }
222            Ordering::Greater => self.data.increment_at(1 + rotations as usize, count)?,
223        }
224        Ok(())
225    }
226
227    pub fn increment(&mut self, count: u64) -> Result<()> {
228        self.data.increment(count)
229    }
230
231    pub fn maybe_advance(&mut self, now: DateTime<Utc>) -> Result<()> {
232        let rotations = self
233            .config
234            .interval
235            .num_rotations(self.data.starting_instant, now)?;
236        if rotations > 0 {
237            self.data.starting_instant += self.config.interval.to_duration(rotations.into());
238            return self.data.rotate(rotations);
239        }
240        Ok(())
241    }
242}
243
244#[derive(Serialize, Deserialize, Clone, Debug)]
245pub struct MultiIntervalCounter {
246    pub intervals: HashMap<Interval, SingleIntervalCounter>,
247}
248
249impl MultiIntervalCounter {
250    pub fn new(intervals: Vec<SingleIntervalCounter>) -> Self {
251        Self {
252            intervals: intervals
253                .into_iter()
254                .map(|v| (v.config.interval.clone(), v))
255                .collect::<HashMap<Interval, SingleIntervalCounter>>(),
256        }
257    }
258
259    pub fn increment_then(&mut self, then: DateTime<Utc>, count: u64) -> Result<()> {
260        self.intervals
261            .iter_mut()
262            .try_for_each(|(_, v)| v.increment_then(then, count))
263    }
264
265    pub fn increment(&mut self, count: u64) -> Result<()> {
266        self.intervals
267            .iter_mut()
268            .try_for_each(|(_, v)| v.increment(count))
269    }
270
271    pub fn maybe_advance(&mut self, now: DateTime<Utc>) -> Result<()> {
272        self.intervals
273            .iter_mut()
274            .try_for_each(|(_, v)| v.maybe_advance(now))
275    }
276}
277
278impl Default for MultiIntervalCounter {
279    fn default() -> Self {
280        Self::new(vec![
281            SingleIntervalCounter::new(IntervalConfig {
282                bucket_count: 60,
283                interval: Interval::Minutes,
284            }),
285            SingleIntervalCounter::new(IntervalConfig {
286                bucket_count: 72,
287                interval: Interval::Hours,
288            }),
289            SingleIntervalCounter::new(IntervalConfig {
290                bucket_count: 56,
291                interval: Interval::Days,
292            }),
293            SingleIntervalCounter::new(IntervalConfig {
294                bucket_count: 52,
295                interval: Interval::Weeks,
296            }),
297            SingleIntervalCounter::new(IntervalConfig {
298                bucket_count: 12,
299                interval: Interval::Months,
300            }),
301            SingleIntervalCounter::new(IntervalConfig {
302                bucket_count: 4,
303                interval: Interval::Years,
304            }),
305        ])
306    }
307}
308
/// The aggregations that can be computed over a range of event buckets.
#[derive(Debug)]
pub enum EventQueryType {
    Sum,
    CountNonZero,
    AveragePerInterval,
    AveragePerNonZeroInterval,
    LastSeen,
}

/// Displays a query type as its variant name, i.e. the derived `Debug` text.
impl fmt::Display for EventQueryType {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{:?}", self)
    }
}
323
324impl EventQueryType {
325    pub fn perform_query(&self, buckets: Iter<u64>, num_buckets: usize) -> Result<f64> {
326        Ok(match self {
327            Self::Sum => buckets.sum::<u64>() as f64,
328            Self::CountNonZero => buckets.filter(|v| v > &&0u64).count() as f64,
329            Self::AveragePerInterval => buckets.sum::<u64>() as f64 / num_buckets as f64,
330            Self::AveragePerNonZeroInterval => {
331                let values = buckets.fold((0, 0), |accum, item| {
332                    (
333                        accum.0 + item,
334                        if item > &0 { accum.1 + 1 } else { accum.1 },
335                    )
336                });
337                if values.1 == 0 {
338                    0.0
339                } else {
340                    values.0 as f64 / values.1 as f64
341                }
342            }
343            Self::LastSeen => match buckets.into_iter().position(|v| v > &0) {
344                Some(v) => v as f64,
345                None => f64::MAX,
346            },
347        })
348    }
349
350    fn validate_counting_arguments(
351        &self,
352        args: &[Value],
353    ) -> Result<(String, Interval, usize, usize)> {
354        if args.len() < 3 || args.len() > 4 {
355            return Err(NimbusError::TransformParameterError(format!(
356                "event transform {} requires 2-3 parameters",
357                self
358            )));
359        }
360        let event = match serde_json::from_value::<String>(args.first().unwrap().clone()) {
361            Ok(v) => v,
362            Err(e) => return Err(NimbusError::JSONError("event = nimbus::stateful::behavior::EventQueryType::validate_counting_arguments::serde_json::from_value".into(), e.to_string()))
363        };
364        let interval = match serde_json::from_value::<String>(args.get(1).unwrap().clone()) {
365            Ok(v) => v,
366            Err(e) => return Err(NimbusError::JSONError("interval = nimbus::stateful::behavior::EventQueryType::validate_counting_arguments::serde_json::from_value".into(), e.to_string()))
367        };
368        let interval = Interval::from_str(&interval)?;
369        let num_buckets = match args.get(2).unwrap().as_f64() {
370            Some(v) => v,
371            None => {
372                return Err(NimbusError::TransformParameterError(format!(
373                    "event transform {} requires a positive number as the second parameter",
374                    self
375                )));
376            }
377        } as usize;
378        let zero = &Value::from(0);
379        let starting_bucket = match args.get(3).unwrap_or(zero).as_f64() {
380            Some(v) => v,
381            None => {
382                return Err(NimbusError::TransformParameterError(format!(
383                    "event transform {} requires a positive number as the third parameter",
384                    self
385                )));
386            }
387        } as usize;
388
389        Ok((event, interval, num_buckets, starting_bucket))
390    }
391
392    fn validate_last_seen_arguments(
393        &self,
394        args: &[Value],
395    ) -> Result<(String, Interval, usize, usize)> {
396        if args.len() < 2 || args.len() > 3 {
397            return Err(NimbusError::TransformParameterError(format!(
398                "event transform {} requires 1-2 parameters",
399                self
400            )));
401        }
402        let event = match serde_json::from_value::<String>(args.first().unwrap().clone()) {
403            Ok(v) => v,
404            Err(e) => return Err(NimbusError::JSONError("event = nimbus::stateful::behavior::EventQueryType::validate_last_seen_arguments::serde_json::from_value".into(), e.to_string()))
405        };
406        let interval = match serde_json::from_value::<String>(args.get(1).unwrap().clone()) {
407            Ok(v) => v,
408            Err(e) => return Err(NimbusError::JSONError("interval = nimbus::stateful::behavior::EventQueryType::validate_last_seen_arguments::serde_json::from_value".into(), e.to_string()))
409        };
410        let interval = Interval::from_str(&interval)?;
411        let zero = &Value::from(0);
412        let starting_bucket = match args.get(2).unwrap_or(zero).as_f64() {
413            Some(v) => v,
414            None => {
415                return Err(NimbusError::TransformParameterError(format!(
416                    "event transform {} requires a positive number as the second parameter",
417                    self
418                )));
419            }
420        } as usize;
421
422        Ok((
423            event,
424            interval,
425            usize::MAX - starting_bucket,
426            starting_bucket,
427        ))
428    }
429
430    pub fn validate_arguments(&self, args: &[Value]) -> Result<(String, Interval, usize, usize)> {
431        // `args` is an array of values sent by the evaluator for a JEXL transform.
432        // The first parameter will always be the event_id, and subsequent parameters are up to the developer's discretion.
433        // All parameters should be validated, and a `TransformParameterError` should be sent when there is an error.
434        Ok(match self {
435            Self::Sum
436            | Self::CountNonZero
437            | Self::AveragePerInterval
438            | Self::AveragePerNonZeroInterval => self.validate_counting_arguments(args)?,
439            Self::LastSeen => self.validate_last_seen_arguments(args)?,
440        })
441    }
442
443    pub fn validate_query(maybe_query: &str) -> Result<bool> {
444        let regex = regex::Regex::new(
445            r#"^(?:"[^"']+"|'[^"']+')\|event(?:Sum|LastSeen|CountNonZero|Average|AveragePerNonZeroInterval)\(["'](?:Years|Months|Weeks|Days|Hours|Minutes)["'],\s*\d+\s*(?:,\s*\d+\s*)?\)$"#,
446        )?;
447        Ok(regex.is_match(maybe_query))
448    }
449
450    fn error_value(&self) -> f64 {
451        match self {
452            Self::LastSeen => f64::MAX,
453            _ => 0.0,
454        }
455    }
456}
457
458#[derive(Default, Serialize, Deserialize, Debug, Clone)]
459pub struct EventStore {
460    pub(crate) events: HashMap<String, MultiIntervalCounter>,
461    datum: Option<DateTime<Utc>>,
462}
463
464impl From<Vec<(String, MultiIntervalCounter)>> for EventStore {
465    fn from(event_store: Vec<(String, MultiIntervalCounter)>) -> Self {
466        Self {
467            events: HashMap::from_iter(event_store),
468            datum: None,
469        }
470    }
471}
472
473impl From<HashMap<String, MultiIntervalCounter>> for EventStore {
474    fn from(event_store: HashMap<String, MultiIntervalCounter>) -> Self {
475        Self {
476            events: event_store,
477            datum: None,
478        }
479    }
480}
481
482impl TryFrom<&Database> for EventStore {
483    type Error = NimbusError;
484
485    fn try_from(db: &Database) -> Result<Self, NimbusError> {
486        let reader = db.read()?;
487        let events = db
488            .get_store(StoreId::EventCounts)
489            .collect_all::<(String, MultiIntervalCounter), _>(&reader)?;
490        Ok(EventStore::from(events))
491    }
492}
493
494impl EventStore {
495    pub fn new() -> Self {
496        Self {
497            events: HashMap::<String, MultiIntervalCounter>::new(),
498            datum: None,
499        }
500    }
501
502    fn now(&self) -> DateTime<Utc> {
503        self.datum.unwrap_or_else(Utc::now)
504    }
505
506    pub fn advance_datum(&mut self, duration: Duration) {
507        self.datum = Some(self.now() + duration);
508    }
509
510    pub fn read_from_db(&mut self, db: &Database) -> Result<()> {
511        let reader = db.read()?;
512
513        self.events =
514            HashMap::from_iter(
515                db.get_store(StoreId::EventCounts)
516                    .collect_all::<(String, MultiIntervalCounter), _>(&reader)?,
517            );
518
519        Ok(())
520    }
521
522    pub fn record_event(
523        &mut self,
524        count: u64,
525        event_id: &str,
526        now: Option<DateTime<Utc>>,
527    ) -> Result<()> {
528        let now = now.unwrap_or_else(|| self.now());
529        let counter = self.get_or_create_counter(event_id);
530        counter.maybe_advance(now)?;
531        counter.increment(count)
532    }
533
534    pub fn record_past_event(
535        &mut self,
536        count: u64,
537        event_id: &str,
538        now: Option<DateTime<Utc>>,
539        duration: Duration,
540    ) -> Result<()> {
541        let now = now.unwrap_or_else(|| self.now());
542        let then = now - duration;
543        let counter = self.get_or_create_counter(event_id);
544        counter.maybe_advance(now)?;
545        counter.increment_then(then, count)
546    }
547
548    fn get_or_create_counter(&mut self, event_id: &str) -> &mut MultiIntervalCounter {
549        if !self.events.contains_key(event_id) {
550            let new_counter = Default::default();
551            self.events.insert(event_id.to_string(), new_counter);
552        }
553        self.events.get_mut(event_id).unwrap()
554    }
555
556    pub fn persist_data(&self, db: &Database) -> Result<()> {
557        let mut writer = db.write()?;
558        self.events.iter().try_for_each(|(key, value)| {
559            db.get_store(StoreId::EventCounts)
560                .put(&mut writer, key, &(key.clone(), value.clone()))
561        })?;
562        writer.commit()?;
563        Ok(())
564    }
565
566    pub fn clear(&mut self, db: &Database) -> Result<()> {
567        self.events = HashMap::<String, MultiIntervalCounter>::new();
568        self.datum = None;
569        self.persist_data(db)?;
570        Ok(())
571    }
572
573    pub fn query(
574        &mut self,
575        event_id: &str,
576        interval: Interval,
577        num_buckets: usize,
578        starting_bucket: usize,
579        query_type: EventQueryType,
580    ) -> Result<f64> {
581        let now = self.now();
582        if let Some(counter) = self.events.get_mut(event_id) {
583            counter.maybe_advance(now)?;
584            if let Some(single_counter) = counter.intervals.get(&interval) {
585                let safe_range = 0..single_counter.data.buckets.len();
586                if !safe_range.contains(&starting_bucket) {
587                    return Ok(query_type.error_value());
588                }
589                let max = usize::min(
590                    num_buckets + starting_bucket,
591                    single_counter.data.buckets.len(),
592                );
593                let buckets = single_counter.data.buckets.range(starting_bucket..max);
594                return query_type.perform_query(buckets, num_buckets);
595            }
596        }
597        Ok(query_type.error_value())
598    }
599}
600
601pub fn query_event_store(
602    event_store: Arc<Mutex<EventStore>>,
603    query_type: EventQueryType,
604    args: &[Value],
605) -> Result<Value> {
606    let (event, interval, num_buckets, starting_bucket) = query_type.validate_arguments(args)?;
607
608    Ok(json!(event_store.lock().unwrap().query(
609        &event,
610        interval,
611        num_buckets,
612        starting_bucket,
613        query_type,
614    )?))
615}