1use crate::{
6 error::{BehaviorError, NimbusError, Result},
7 stateful::persistence::{Database, StoreId},
8};
9use chrono::{DateTime, Datelike, Duration, TimeZone, Utc};
10use serde::{Deserialize, Serialize};
11use serde_json::{json, Value};
12use std::collections::vec_deque::Iter;
13use std::collections::{HashMap, VecDeque};
14use std::fmt;
15use std::hash::{Hash, Hasher};
16use std::str::FromStr;
17use std::sync::{Arc, Mutex};
18
19#[derive(Clone, Debug, Serialize, Deserialize)]
20pub enum Interval {
21 Minutes,
22 Hours,
23 Days,
24 Weeks,
25 Months,
26 Years,
27}
28
29impl Interval {
30 pub fn num_rotations(&self, then: DateTime<Utc>, now: DateTime<Utc>) -> Result<i32> {
31 let date_diff = now - then;
32 Ok(i32::try_from(match self {
33 Interval::Minutes => date_diff.num_minutes(),
34 Interval::Hours => date_diff.num_hours(),
35 Interval::Days => date_diff.num_days(),
36 Interval::Weeks => date_diff.num_weeks(),
37 Interval::Months => date_diff.num_days() / 28,
38 Interval::Years => date_diff.num_days() / 365,
39 })?)
40 }
41
42 pub fn to_duration(&self, count: i64) -> Duration {
43 match self {
44 Interval::Minutes => Duration::minutes(count),
45 Interval::Hours => Duration::hours(count),
46 Interval::Days => Duration::days(count),
47 Interval::Weeks => Duration::weeks(count),
48 Interval::Months => Duration::days(28 * count),
49 Interval::Years => Duration::days(365 * count),
50 }
51 }
52}
53
54impl fmt::Display for Interval {
55 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
56 fmt::Debug::fmt(self, f)
57 }
58}
59
60impl PartialEq for Interval {
61 fn eq(&self, other: &Self) -> bool {
62 self.to_string() == other.to_string()
63 }
64}
65
66impl Eq for Interval {}
67
68impl Hash for Interval {
69 fn hash<H: Hasher>(&self, state: &mut H) {
70 self.to_string().as_bytes().hash(state);
71 }
72}
73
74impl FromStr for Interval {
75 type Err = NimbusError;
76
77 fn from_str(input: &str) -> Result<Self> {
78 Ok(match input {
79 "Minutes" => Self::Minutes,
80 "Hours" => Self::Hours,
81 "Days" => Self::Days,
82 "Weeks" => Self::Weeks,
83 "Months" => Self::Months,
84 "Years" => Self::Years,
85 _ => {
86 return Err(NimbusError::BehaviorError(
87 BehaviorError::IntervalParseError(input.to_string()),
88 ));
89 }
90 })
91 }
92}
93
94#[derive(Clone, Serialize, Deserialize, Debug)]
95pub struct IntervalConfig {
96 bucket_count: usize,
97 interval: Interval,
98}
99
100impl Default for IntervalConfig {
101 fn default() -> Self {
102 Self::new(7, Interval::Days)
103 }
104}
105
106impl IntervalConfig {
107 pub fn new(bucket_count: usize, interval: Interval) -> Self {
108 Self {
109 bucket_count,
110 interval,
111 }
112 }
113}
114
115#[derive(Clone, Serialize, Deserialize, Debug)]
116pub struct IntervalData {
117 pub(crate) buckets: VecDeque<u64>,
118 pub(crate) bucket_count: usize,
119 pub(crate) starting_instant: DateTime<Utc>,
120}
121
122impl Default for IntervalData {
123 fn default() -> Self {
124 Self::new(1)
125 }
126}
127
128impl IntervalData {
129 pub fn new(bucket_count: usize) -> Self {
130 let mut buckets = VecDeque::with_capacity(bucket_count);
131 buckets.push_front(0);
132 let starting_instant = Utc.from_utc_datetime(
134 &Utc::now()
135 .with_month(1)
136 .unwrap()
137 .with_day(1)
138 .unwrap()
139 .date_naive()
140 .and_hms_opt(0, 0, 0)
141 .unwrap(),
142 );
143 Self {
144 buckets,
145 bucket_count,
146 starting_instant,
147 }
148 }
149
150 pub fn increment(&mut self, count: u64) -> Result<()> {
151 self.increment_at(0, count)
152 }
153
154 pub fn increment_at(&mut self, index: usize, count: u64) -> Result<()> {
155 if index < self.bucket_count {
156 let buckets = &mut self.buckets;
157 match buckets.get_mut(index) {
158 Some(x) => *x += count,
159 None => {
160 for _ in buckets.len()..index {
161 buckets.push_back(0);
162 }
163 self.buckets.insert(index, count)
164 }
165 };
166 }
167 Ok(())
168 }
169
170 pub fn rotate(&mut self, num_rotations: i32) -> Result<()> {
171 let num_rotations = usize::min(self.bucket_count, num_rotations as usize);
172 if num_rotations + self.buckets.len() > self.bucket_count {
173 self.buckets.drain((self.bucket_count - num_rotations)..);
174 }
175 for _ in 1..=num_rotations {
176 self.buckets.push_front(0);
177 }
178 Ok(())
179 }
180}
181
182#[derive(Clone, Serialize, Deserialize, Debug)]
183pub struct SingleIntervalCounter {
184 pub data: IntervalData,
185 pub config: IntervalConfig,
186}
187
188impl SingleIntervalCounter {
189 pub fn new(config: IntervalConfig) -> Self {
190 Self {
191 data: IntervalData::new(config.bucket_count),
192 config,
193 }
194 }
195
196 pub fn from_config(bucket_count: usize, interval: Interval) -> Self {
197 let config = IntervalConfig {
198 bucket_count,
199 interval,
200 };
201 Self::new(config)
202 }
203
204 pub fn increment_then(&mut self, then: DateTime<Utc>, count: u64) -> Result<()> {
205 use std::cmp::Ordering;
206 let now = self.data.starting_instant;
207 let rotations = self.config.interval.num_rotations(then, now)?;
208 match rotations.cmp(&0) {
209 Ordering::Less => {
210 return Err(NimbusError::BehaviorError(BehaviorError::InvalidState(
212 "Cannot increment events far into the future".to_string(),
213 )));
214 }
215 Ordering::Equal => {
216 if now < then {
217 self.data.increment_at(0, count)?;
218 } else {
219 self.data.increment_at(1, count)?;
220 }
221 }
222 Ordering::Greater => self.data.increment_at(1 + rotations as usize, count)?,
223 }
224 Ok(())
225 }
226
227 pub fn increment(&mut self, count: u64) -> Result<()> {
228 self.data.increment(count)
229 }
230
231 pub fn maybe_advance(&mut self, now: DateTime<Utc>) -> Result<()> {
232 let rotations = self
233 .config
234 .interval
235 .num_rotations(self.data.starting_instant, now)?;
236 if rotations > 0 {
237 self.data.starting_instant += self.config.interval.to_duration(rotations.into());
238 return self.data.rotate(rotations);
239 }
240 Ok(())
241 }
242}
243
244#[derive(Serialize, Deserialize, Clone, Debug)]
245pub struct MultiIntervalCounter {
246 pub intervals: HashMap<Interval, SingleIntervalCounter>,
247}
248
249impl MultiIntervalCounter {
250 pub fn new(intervals: Vec<SingleIntervalCounter>) -> Self {
251 Self {
252 intervals: intervals
253 .into_iter()
254 .map(|v| (v.config.interval.clone(), v))
255 .collect::<HashMap<Interval, SingleIntervalCounter>>(),
256 }
257 }
258
259 pub fn increment_then(&mut self, then: DateTime<Utc>, count: u64) -> Result<()> {
260 self.intervals
261 .iter_mut()
262 .try_for_each(|(_, v)| v.increment_then(then, count))
263 }
264
265 pub fn increment(&mut self, count: u64) -> Result<()> {
266 self.intervals
267 .iter_mut()
268 .try_for_each(|(_, v)| v.increment(count))
269 }
270
271 pub fn maybe_advance(&mut self, now: DateTime<Utc>) -> Result<()> {
272 self.intervals
273 .iter_mut()
274 .try_for_each(|(_, v)| v.maybe_advance(now))
275 }
276}
277
278impl Default for MultiIntervalCounter {
279 fn default() -> Self {
280 Self::new(vec![
281 SingleIntervalCounter::new(IntervalConfig {
282 bucket_count: 60,
283 interval: Interval::Minutes,
284 }),
285 SingleIntervalCounter::new(IntervalConfig {
286 bucket_count: 72,
287 interval: Interval::Hours,
288 }),
289 SingleIntervalCounter::new(IntervalConfig {
290 bucket_count: 56,
291 interval: Interval::Days,
292 }),
293 SingleIntervalCounter::new(IntervalConfig {
294 bucket_count: 52,
295 interval: Interval::Weeks,
296 }),
297 SingleIntervalCounter::new(IntervalConfig {
298 bucket_count: 12,
299 interval: Interval::Months,
300 }),
301 SingleIntervalCounter::new(IntervalConfig {
302 bucket_count: 4,
303 interval: Interval::Years,
304 }),
305 ])
306 }
307}
308
309#[derive(Debug)]
310pub enum EventQueryType {
311 Sum,
312 CountNonZero,
313 AveragePerInterval,
314 AveragePerNonZeroInterval,
315 LastSeen,
316}
317
318impl fmt::Display for EventQueryType {
319 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
320 fmt::Debug::fmt(self, f)
321 }
322}
323
324impl EventQueryType {
325 pub fn perform_query(&self, buckets: Iter<u64>, num_buckets: usize) -> Result<f64> {
326 Ok(match self {
327 Self::Sum => buckets.sum::<u64>() as f64,
328 Self::CountNonZero => buckets.filter(|v| v > &&0u64).count() as f64,
329 Self::AveragePerInterval => buckets.sum::<u64>() as f64 / num_buckets as f64,
330 Self::AveragePerNonZeroInterval => {
331 let values = buckets.fold((0, 0), |accum, item| {
332 (
333 accum.0 + item,
334 if item > &0 { accum.1 + 1 } else { accum.1 },
335 )
336 });
337 if values.1 == 0 {
338 0.0
339 } else {
340 values.0 as f64 / values.1 as f64
341 }
342 }
343 Self::LastSeen => match buckets.into_iter().position(|v| v > &0) {
344 Some(v) => v as f64,
345 None => f64::MAX,
346 },
347 })
348 }
349
350 fn validate_counting_arguments(
351 &self,
352 args: &[Value],
353 ) -> Result<(String, Interval, usize, usize)> {
354 if args.len() < 3 || args.len() > 4 {
355 return Err(NimbusError::TransformParameterError(format!(
356 "event transform {} requires 2-3 parameters",
357 self
358 )));
359 }
360 let event = match serde_json::from_value::<String>(args.first().unwrap().clone()) {
361 Ok(v) => v,
362 Err(e) => return Err(NimbusError::JSONError("event = nimbus::stateful::behavior::EventQueryType::validate_counting_arguments::serde_json::from_value".into(), e.to_string()))
363 };
364 let interval = match serde_json::from_value::<String>(args.get(1).unwrap().clone()) {
365 Ok(v) => v,
366 Err(e) => return Err(NimbusError::JSONError("interval = nimbus::stateful::behavior::EventQueryType::validate_counting_arguments::serde_json::from_value".into(), e.to_string()))
367 };
368 let interval = Interval::from_str(&interval)?;
369 let num_buckets = match args.get(2).unwrap().as_f64() {
370 Some(v) => v,
371 None => {
372 return Err(NimbusError::TransformParameterError(format!(
373 "event transform {} requires a positive number as the second parameter",
374 self
375 )));
376 }
377 } as usize;
378 let zero = &Value::from(0);
379 let starting_bucket = match args.get(3).unwrap_or(zero).as_f64() {
380 Some(v) => v,
381 None => {
382 return Err(NimbusError::TransformParameterError(format!(
383 "event transform {} requires a positive number as the third parameter",
384 self
385 )));
386 }
387 } as usize;
388
389 Ok((event, interval, num_buckets, starting_bucket))
390 }
391
392 fn validate_last_seen_arguments(
393 &self,
394 args: &[Value],
395 ) -> Result<(String, Interval, usize, usize)> {
396 if args.len() < 2 || args.len() > 3 {
397 return Err(NimbusError::TransformParameterError(format!(
398 "event transform {} requires 1-2 parameters",
399 self
400 )));
401 }
402 let event = match serde_json::from_value::<String>(args.first().unwrap().clone()) {
403 Ok(v) => v,
404 Err(e) => return Err(NimbusError::JSONError("event = nimbus::stateful::behavior::EventQueryType::validate_last_seen_arguments::serde_json::from_value".into(), e.to_string()))
405 };
406 let interval = match serde_json::from_value::<String>(args.get(1).unwrap().clone()) {
407 Ok(v) => v,
408 Err(e) => return Err(NimbusError::JSONError("interval = nimbus::stateful::behavior::EventQueryType::validate_last_seen_arguments::serde_json::from_value".into(), e.to_string()))
409 };
410 let interval = Interval::from_str(&interval)?;
411 let zero = &Value::from(0);
412 let starting_bucket = match args.get(2).unwrap_or(zero).as_f64() {
413 Some(v) => v,
414 None => {
415 return Err(NimbusError::TransformParameterError(format!(
416 "event transform {} requires a positive number as the second parameter",
417 self
418 )));
419 }
420 } as usize;
421
422 Ok((
423 event,
424 interval,
425 usize::MAX - starting_bucket,
426 starting_bucket,
427 ))
428 }
429
430 pub fn validate_arguments(&self, args: &[Value]) -> Result<(String, Interval, usize, usize)> {
431 Ok(match self {
435 Self::Sum
436 | Self::CountNonZero
437 | Self::AveragePerInterval
438 | Self::AveragePerNonZeroInterval => self.validate_counting_arguments(args)?,
439 Self::LastSeen => self.validate_last_seen_arguments(args)?,
440 })
441 }
442
443 pub fn validate_query(maybe_query: &str) -> Result<bool> {
444 let regex = regex::Regex::new(
445 r#"^(?:"[^"']+"|'[^"']+')\|event(?:Sum|LastSeen|CountNonZero|Average|AveragePerNonZeroInterval)\(["'](?:Years|Months|Weeks|Days|Hours|Minutes)["'],\s*\d+\s*(?:,\s*\d+\s*)?\)$"#,
446 )?;
447 Ok(regex.is_match(maybe_query))
448 }
449
450 fn error_value(&self) -> f64 {
451 match self {
452 Self::LastSeen => f64::MAX,
453 _ => 0.0,
454 }
455 }
456}
457
458#[derive(Default, Serialize, Deserialize, Debug, Clone)]
459pub struct EventStore {
460 pub(crate) events: HashMap<String, MultiIntervalCounter>,
461 datum: Option<DateTime<Utc>>,
462}
463
464impl From<Vec<(String, MultiIntervalCounter)>> for EventStore {
465 fn from(event_store: Vec<(String, MultiIntervalCounter)>) -> Self {
466 Self {
467 events: HashMap::from_iter(event_store),
468 datum: None,
469 }
470 }
471}
472
473impl From<HashMap<String, MultiIntervalCounter>> for EventStore {
474 fn from(event_store: HashMap<String, MultiIntervalCounter>) -> Self {
475 Self {
476 events: event_store,
477 datum: None,
478 }
479 }
480}
481
482impl TryFrom<&Database> for EventStore {
483 type Error = NimbusError;
484
485 fn try_from(db: &Database) -> Result<Self, NimbusError> {
486 let reader = db.read()?;
487 let events = db
488 .get_store(StoreId::EventCounts)
489 .collect_all::<(String, MultiIntervalCounter), _>(&reader)?;
490 Ok(EventStore::from(events))
491 }
492}
493
494impl EventStore {
495 pub fn new() -> Self {
496 Self {
497 events: HashMap::<String, MultiIntervalCounter>::new(),
498 datum: None,
499 }
500 }
501
502 fn now(&self) -> DateTime<Utc> {
503 self.datum.unwrap_or_else(Utc::now)
504 }
505
506 pub fn advance_datum(&mut self, duration: Duration) {
507 self.datum = Some(self.now() + duration);
508 }
509
510 pub fn read_from_db(&mut self, db: &Database) -> Result<()> {
511 let reader = db.read()?;
512
513 self.events =
514 HashMap::from_iter(
515 db.get_store(StoreId::EventCounts)
516 .collect_all::<(String, MultiIntervalCounter), _>(&reader)?,
517 );
518
519 Ok(())
520 }
521
522 pub fn record_event(
523 &mut self,
524 count: u64,
525 event_id: &str,
526 now: Option<DateTime<Utc>>,
527 ) -> Result<()> {
528 let now = now.unwrap_or_else(|| self.now());
529 let counter = self.get_or_create_counter(event_id);
530 counter.maybe_advance(now)?;
531 counter.increment(count)
532 }
533
534 pub fn record_past_event(
535 &mut self,
536 count: u64,
537 event_id: &str,
538 now: Option<DateTime<Utc>>,
539 duration: Duration,
540 ) -> Result<()> {
541 let now = now.unwrap_or_else(|| self.now());
542 let then = now - duration;
543 let counter = self.get_or_create_counter(event_id);
544 counter.maybe_advance(now)?;
545 counter.increment_then(then, count)
546 }
547
548 fn get_or_create_counter(&mut self, event_id: &str) -> &mut MultiIntervalCounter {
549 if !self.events.contains_key(event_id) {
550 let new_counter = Default::default();
551 self.events.insert(event_id.to_string(), new_counter);
552 }
553 self.events.get_mut(event_id).unwrap()
554 }
555
556 pub fn persist_data(&self, db: &Database) -> Result<()> {
557 let mut writer = db.write()?;
558 self.events.iter().try_for_each(|(key, value)| {
559 db.get_store(StoreId::EventCounts)
560 .put(&mut writer, key, &(key.clone(), value.clone()))
561 })?;
562 writer.commit()?;
563 Ok(())
564 }
565
566 pub fn clear(&mut self, db: &Database) -> Result<()> {
567 self.events = HashMap::<String, MultiIntervalCounter>::new();
568 self.datum = None;
569 self.persist_data(db)?;
570 Ok(())
571 }
572
573 pub fn query(
574 &mut self,
575 event_id: &str,
576 interval: Interval,
577 num_buckets: usize,
578 starting_bucket: usize,
579 query_type: EventQueryType,
580 ) -> Result<f64> {
581 let now = self.now();
582 if let Some(counter) = self.events.get_mut(event_id) {
583 counter.maybe_advance(now)?;
584 if let Some(single_counter) = counter.intervals.get(&interval) {
585 let safe_range = 0..single_counter.data.buckets.len();
586 if !safe_range.contains(&starting_bucket) {
587 return Ok(query_type.error_value());
588 }
589 let max = usize::min(
590 num_buckets + starting_bucket,
591 single_counter.data.buckets.len(),
592 );
593 let buckets = single_counter.data.buckets.range(starting_bucket..max);
594 return query_type.perform_query(buckets, num_buckets);
595 }
596 }
597 Ok(query_type.error_value())
598 }
599}
600
601pub fn query_event_store(
602 event_store: Arc<Mutex<EventStore>>,
603 query_type: EventQueryType,
604 args: &[Value],
605) -> Result<Value> {
606 let (event, interval, num_buckets, starting_bucket) = query_type.validate_arguments(args)?;
607
608 Ok(json!(event_store.lock().unwrap().query(
609 &event,
610 interval,
611 num_buckets,
612 starting_bucket,
613 query_type,
614 )?))
615}