1use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
16use std::collections::{HashMap, HashSet};
17
18use crate::error::{self, debug, info, PushError, Result};
19use crate::internal::communications::{Connection, PersistedRateLimiter};
20use crate::internal::config::PushConfiguration;
21use crate::internal::crypto::KeyV1 as Key;
22use crate::internal::storage::{PushRecord, Storage};
23use crate::{KeyInfo, PushSubscriptionChanged, SubscriptionInfo, SubscriptionResponse};
24
25use super::crypto::{Cryptography, PushPayload};
26const UPDATE_RATE_LIMITER_INTERVAL: u64 = 24 * 60 * 60; const UPDATE_RATE_LIMITER_MAX_CALLS: u16 = 500; impl From<Key> for KeyInfo {
30 fn from(key: Key) -> Self {
31 KeyInfo {
32 auth: URL_SAFE_NO_PAD.encode(key.auth_secret()),
33 p256dh: URL_SAFE_NO_PAD.encode(key.public_key()),
34 }
35 }
36}
37
38impl From<PushRecord> for PushSubscriptionChanged {
39 fn from(record: PushRecord) -> Self {
40 PushSubscriptionChanged {
41 channel_id: record.channel_id,
42 scope: record.scope,
43 }
44 }
45}
46
47impl TryFrom<PushRecord> for SubscriptionResponse {
48 type Error = PushError;
49 fn try_from(value: PushRecord) -> Result<Self, Self::Error> {
50 Ok(SubscriptionResponse {
51 channel_id: value.channel_id,
52 subscription_info: SubscriptionInfo {
53 endpoint: value.endpoint,
54 keys: Key::deserialize(&value.key)?.into(),
55 },
56 })
57 }
58}
59
/// Result of decrypting an incoming push payload.
///
/// Derives `Clone`, `PartialEq` and `Eq` in addition to `Debug` so callers and
/// tests can copy and compare responses directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecryptResponse {
    /// The decrypted bytes, reinterpreted as signed 8-bit values.
    pub result: Vec<i8>,
    /// Scope of the subscription the payload was addressed to.
    pub scope: String,
}
65
/// Coordinates push subscriptions across three pluggable parts: a server
/// connection (`Co`), a cryptography backend (`Cr`) and a persistent store (`S`).
///
/// `uaid`, `auth` and `registration_id` are in-memory copies of values that are
/// also written to the store; they are loaded from the store on construction.
pub struct PushManager<Co, Cr, S> {
    // Never read in this file: `Cr`'s functions are invoked as associated
    // functions (`Cr::generate_key`, `Cr::decrypt`).
    // NOTE(review): the name looks like a typo of `_crypto`.
    _crypo: Cr,
    // Connection used for every server call (register, subscribe, update, ...).
    connection: Co,
    // Identifier handed out by the server on registration; `None` until then
    // and again after local registrations are wiped.
    uaid: Option<String>,
    // Secret paired with `uaid`; set and cleared together with it.
    auth: Option<String>,
    // Token passed to the server when registering or subscribing.
    registration_id: Option<String>,
    // Persistent storage for subscription records and the values above.
    store: S,
    // Limits how often `update` forwards a new token to the server.
    update_rate_limiter: PersistedRateLimiter,
    // Limits how often `verify_connection` queries the server.
    verify_connection_rate_limiter: PersistedRateLimiter,
}
76
77impl<Co: Connection, Cr: Cryptography, S: Storage> PushManager<Co, Cr, S> {
78 pub fn new(config: PushConfiguration) -> Result<Self> {
79 let store = S::open(&config.database_path)?;
80 let uaid = store.get_uaid()?;
81 let auth = store.get_auth()?;
82 let registration_id = store.get_registration_id()?;
83 let verify_connection_rate_limiter = PersistedRateLimiter::new(
84 "verify_connection",
85 config
86 .verify_connection_rate_limiter
87 .unwrap_or(super::config::DEFAULT_VERIFY_CONNECTION_LIMITER_INTERVAL),
88 1,
89 );
90
91 let update_rate_limiter = PersistedRateLimiter::new(
92 "update_token",
93 UPDATE_RATE_LIMITER_INTERVAL,
94 UPDATE_RATE_LIMITER_MAX_CALLS,
95 );
96
97 Ok(Self {
98 connection: Co::connect(config),
99 _crypo: Default::default(),
100 uaid,
101 auth,
102 registration_id,
103 store,
104 update_rate_limiter,
105 verify_connection_rate_limiter,
106 })
107 }
108
109 fn ensure_auth_pair(&self) -> Result<(&str, &str)> {
110 if let (Some(uaid), Some(auth)) = (&self.uaid, &self.auth) {
111 Ok((uaid, auth))
112 } else {
113 Err(PushError::GeneralError(
114 "No subscriptions created yet.".into(),
115 ))
116 }
117 }
118
119 pub fn subscribe(
120 &mut self,
121 scope: &str,
122 server_key: Option<&str>,
123 ) -> Result<SubscriptionResponse> {
124 let server_key = if let Some("") = server_key {
127 None
128 } else {
129 server_key
130 };
131 if let Some(record) = self.store.get_record_by_scope(scope)? {
133 if self.uaid.is_none() {
134 return Err(PushError::StorageError(
136 "DB has a subscription but no UAID".to_string(),
137 ));
138 }
139 debug!("returning existing subscription for '{}'", scope);
140 return record.try_into();
141 }
142
143 let registration_id = self
144 .registration_id
145 .as_ref()
146 .ok_or_else(|| PushError::CommunicationError("No native id".to_string()))?
147 .clone();
148
149 self.impl_subscribe(scope, ®istration_id, server_key)
150 }
151
152 pub fn get_subscription(&self, scope: &str) -> Result<Option<SubscriptionResponse>> {
153 self.store
154 .get_record_by_scope(scope)?
155 .map(TryInto::try_into)
156 .transpose()
157 }
158
159 pub fn unsubscribe(&mut self, scope: &str) -> Result<bool> {
160 let (uaid, auth) = self.ensure_auth_pair()?;
161 let record = self.store.get_record_by_scope(scope)?;
162 if let Some(record) = record {
163 self.connection
164 .unsubscribe(&record.channel_id, uaid, auth)?;
165 self.store.delete_record(&record.channel_id)?;
166 Ok(true)
167 } else {
168 Ok(false)
169 }
170 }
171
172 pub fn unsubscribe_all(&mut self) -> Result<()> {
173 let (uaid, auth) = self.ensure_auth_pair()?;
174
175 self.connection.unsubscribe_all(uaid, auth)?;
176 self.wipe_local_registrations()?;
177 Ok(())
178 }
179
180 pub fn update(&mut self, new_token: &str) -> error::Result<()> {
181 if self.registration_id.as_deref() == Some(new_token) {
182 return Ok(());
187 }
188
189 if self.uaid.is_none() {
192 self.store.set_registration_id(new_token)?;
193 self.registration_id = Some(new_token.to_string());
194 info!("saved the registration ID but not telling the server as we have no subs yet");
195 return Ok(());
196 }
197
198 if !self.update_rate_limiter.check(&self.store) {
199 return Ok(());
200 }
201
202 let (uaid, auth) = self.ensure_auth_pair()?;
203
204 if let Err(e) = self.connection.update(new_token, uaid, auth) {
205 match e {
206 PushError::UAIDNotRecognizedError(_) => {
207 info!("updating our token indicated our subscriptions are gone");
210 }
211 _ => return Err(e),
212 }
213 }
214
215 self.store.set_registration_id(new_token)?;
216 self.registration_id = Some(new_token.to_string());
217 Ok(())
218 }
219
220 pub fn verify_connection(
221 &mut self,
222 force_verify: bool,
223 ) -> Result<Vec<PushSubscriptionChanged>> {
224 if force_verify {
225 self.verify_connection_rate_limiter.reset(&self.store);
226 }
227
228 if self.uaid.is_none() || !self.verify_connection_rate_limiter.check(&self.store) {
231 return Ok(vec![]);
232 }
233 let channels = self.store.get_channel_list()?;
234 let (uaid, auth) = self.ensure_auth_pair()?;
235
236 let local_channels: HashSet<String> = channels.into_iter().collect();
237 let remote_channels = match self.connection.channel_list(uaid, auth) {
238 Ok(v) => Some(HashSet::from_iter(v)),
239 Err(e) => match e {
240 PushError::UAIDNotRecognizedError(_) => {
241 None
243 }
244 _ => return Err(e),
245 },
246 };
247
248 match remote_channels {
250 Some(channels) if channels == local_channels => return Ok(Vec::new()),
252 Some(_) => {
253 info!("verify_connection found a mismatch - unsubscribing");
254 self.connection.unsubscribe_all(uaid, auth)?;
256 }
257 None => (),
260 };
261
262 let mut subscriptions: Vec<PushSubscriptionChanged> = Vec::new();
263 for channel in local_channels {
264 if let Some(record) = self.store.get_record(&channel)? {
265 subscriptions.push(record.into());
266 }
267 }
268 self.wipe_local_registrations()?;
271 Ok(subscriptions)
272 }
273
274 pub fn decrypt(&self, payload: HashMap<String, String>) -> Result<DecryptResponse> {
275 let payload = PushPayload::try_from(&payload)?;
276 let val = self
277 .store
278 .get_record(payload.channel_id)?
279 .ok_or_else(|| PushError::RecordNotFoundError(payload.channel_id.to_string()))?;
280 let key = Key::deserialize(&val.key)?;
281 let decrypted = Cr::decrypt(&key, payload)?;
282 Ok(DecryptResponse {
285 result: decrypted.into_iter().map(|ub| ub as i8).collect(),
286 scope: val.scope,
287 })
288 }
289
290 fn wipe_local_registrations(&mut self) -> error::Result<()> {
291 self.store.delete_all_records()?;
292 self.auth = None;
293 self.uaid = None;
294 Ok(())
295 }
296
297 fn impl_subscribe(
298 &mut self,
299 scope: &str,
300 registration_id: &str,
301 server_key: Option<&str>,
302 ) -> error::Result<SubscriptionResponse> {
303 if let (Some(uaid), Some(auth)) = (&self.uaid, &self.auth) {
304 self.subscribe_with_uaid(scope, uaid, auth, registration_id, server_key)
305 } else {
306 self.register(scope, registration_id, server_key)
307 }
308 }
309
310 fn subscribe_with_uaid(
311 &self,
312 scope: &str,
313 uaid: &str,
314 auth: &str,
315 registration_id: &str,
316 app_server_key: Option<&str>,
317 ) -> error::Result<SubscriptionResponse> {
318 let app_server_key = app_server_key.map(|v| v.to_owned());
319
320 let subscription_response =
321 self.connection
322 .subscribe(uaid, auth, registration_id, &app_server_key)?;
323 let subscription_key = Cr::generate_key()?;
324 let mut record = crate::internal::storage::PushRecord::new(
325 &subscription_response.channel_id,
326 &subscription_response.endpoint,
327 scope,
328 subscription_key.clone(),
329 )?;
330 record.app_server_key = app_server_key;
331 self.store.put_record(&record)?;
332 debug!("subscribed OK");
333 Ok(SubscriptionResponse {
334 channel_id: subscription_response.channel_id,
335 subscription_info: SubscriptionInfo {
336 endpoint: subscription_response.endpoint,
337 keys: subscription_key.into(),
338 },
339 })
340 }
341
342 fn register(
343 &mut self,
344 scope: &str,
345 registration_id: &str,
346 app_server_key: Option<&str>,
347 ) -> error::Result<SubscriptionResponse> {
348 let app_server_key = app_server_key.map(|v| v.to_owned());
349 let register_response = self.connection.register(registration_id, &app_server_key)?;
350 self.store.set_uaid(®ister_response.uaid)?;
352 self.store.set_auth(®ister_response.secret)?;
353 self.uaid = Some(register_response.uaid.clone());
354 self.auth = Some(register_response.secret.clone());
355
356 let subscription_key = Cr::generate_key()?;
357 let mut record = crate::internal::storage::PushRecord::new(
358 ®ister_response.channel_id,
359 ®ister_response.endpoint,
360 scope,
361 subscription_key.clone(),
362 )?;
363 record.app_server_key = app_server_key;
364 self.store.put_record(&record)?;
365 debug!("subscribed OK");
366 Ok(SubscriptionResponse {
367 channel_id: register_response.channel_id,
368 subscription_info: SubscriptionInfo {
369 endpoint: register_response.endpoint,
370 keys: subscription_key.into(),
371 },
372 })
373 }
374}
375
#[cfg(test)]
mod test {
    use mockall::predicate::eq;
    use rc_crypto::ece::{self, EcKeyComponents};

    use crate::internal::{
        communications::{MockConnection, RegisterResponse, SubscribeResponse},
        crypto::MockCryptography,
    };

    use super::*;
    use lazy_static::lazy_static;
    use std::sync::{Mutex, MutexGuard};

    use nss::ensure_initialized;

    use crate::Store;

    lazy_static! {
        // Every test takes this lock before installing expectations on the
        // mocked static methods (`connect`, `generate_key`, `decrypt`), so
        // tests that share those mocks do not run concurrently.
        static ref MTX: Mutex<()> = Mutex::new(());
    }

    /// Acquires `m`, recovering the guard if an earlier test panicked while
    /// holding it, so one failing test does not poison all the others.
    fn get_lock(m: &'static Mutex<()>) -> MutexGuard<'static, ()> {
        match m.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }

    const TEST_UAID: &str = "abad1d3a00000000aabbccdd00000000";
    const DATA: &[u8] = b"Mary had a little lamb, with some nice mint jelly";
    const TEST_CHANNEL_ID: &str = "deadbeef00000000decafbad00000000";
    const TEST_CHANNEL_ID2: &str = "decafbad00000000deadbeef00000000";

    const PRIV_KEY_D: &str = "qJkxxWGVVxy7BKvraNY3hg8Gs-Y8qi0lRaXWJ3R3aJ8";
    const TEST_AUTH: &str = "LsuUOBKVQRY6-l7_Ajo-Ag";
    const PUB_KEY_RAW: &str =
        "BBcJdfs1GtMyymFTtty6lIGWRFXrEtJP40Df0gOvRDR4D8CKVgqE6vlYR7tCYksIRdKD1MxDPhQVmKLnzuife50";

    const ONE_DAY_AND_ONE_SECOND: u64 = (24 * 60 * 60) + 1;

    /// Manager type shared by all tests: mocked connection and cryptography,
    /// real store.
    type TestManager = PushManager<MockConnection, MockCryptography, Store>;

    /// Creates a manager with the registration id preset to `"native-id"`.
    ///
    /// The caller must already have installed a `connect` expectation, since
    /// `PushManager::new` calls `Co::connect`.
    fn get_test_manager() -> Result<TestManager> {
        let test_config = PushConfiguration {
            sender_id: "test".to_owned(),
            ..Default::default()
        };

        let mut pm: TestManager = PushManager::new(test_config)?;
        pm.store.set_registration_id("native-id")?;
        pm.registration_id = Some("native-id".to_string());
        Ok(pm)
    }

    /// The fixed key every mocked `generate_key` call returns.
    fn test_key() -> Key {
        Key {
            p256key: EcKeyComponents::new(
                URL_SAFE_NO_PAD.decode(PRIV_KEY_D).unwrap(),
                URL_SAFE_NO_PAD.decode(PUB_KEY_RAW).unwrap(),
            ),
            auth: URL_SAFE_NO_PAD.decode(TEST_AUTH).unwrap(),
        }
    }

    /// Expects exactly `times` register calls for `"native-id"` without a
    /// server key, each answered with the canned test registration.
    fn expect_register(connection: &mut MockConnection, times: usize) {
        connection
            .expect_register()
            .with(eq("native-id"), eq(None))
            .times(times)
            .returning(|_, _| {
                Ok(RegisterResponse {
                    uaid: TEST_UAID.to_string(),
                    channel_id: TEST_CHANNEL_ID.to_string(),
                    secret: TEST_AUTH.to_string(),
                    endpoint: "https://example.com/dummy-endpoint".to_string(),
                    sender_id: Some("test".to_string()),
                })
            });
    }

    /// Asserts that the store holds the test UAID and the test channel record.
    #[track_caller]
    fn assert_subscription_stored(pm: &TestManager) -> Result<()> {
        assert_eq!(pm.store.get_uaid()?.unwrap(), TEST_UAID);
        assert_eq!(
            pm.store.get_record(TEST_CHANNEL_ID)?.unwrap().channel_id,
            TEST_CHANNEL_ID
        );
        Ok(())
    }

    /// Asserts that neither the UAID nor the test channel record is stored.
    #[track_caller]
    fn assert_local_state_wiped(pm: &TestManager) -> Result<()> {
        assert!(pm.store.get_uaid()?.is_none());
        assert!(pm.store.get_record(TEST_CHANNEL_ID)?.is_none());
        Ok(())
    }

    /// Subscribes, encrypts `DATA` for the returned keys and checks that
    /// `decrypt` hands the stored key and a payload tagged with `encoding` to
    /// the cryptography backend.
    fn check_decrypt_with_encoding(encoding: &'static str) -> Result<()> {
        ensure_initialized();
        rc_crypto::ensure_initialized();

        let _m = get_lock(&MTX);
        let ctx = MockConnection::connect_context();
        ctx.expect().returning(|_| Default::default());

        let mut pm = get_test_manager()?;
        expect_register(&mut pm.connection, 1);
        let crypto_ctx = MockCryptography::generate_key_context();
        crypto_ctx.expect().returning(|| Ok(test_key()));

        let resp = pm.subscribe("test-scope", None)?;
        let key_info = resp.subscription_info.keys;
        let remote_pub = URL_SAFE_NO_PAD.decode(&key_info.p256dh).unwrap();
        let auth = URL_SAFE_NO_PAD.decode(&key_info.auth).unwrap();
        let ciphertext = ece::encrypt(&remote_pub, &auth, DATA).unwrap();
        let body = URL_SAFE_NO_PAD.encode(ciphertext);

        let decryp_ctx = MockCryptography::decrypt_context();
        let body_clone = body.clone();
        decryp_ctx
            .expect()
            .withf(move |key, push_payload| {
                *key == test_key()
                    && push_payload.body == body_clone
                    && push_payload.encoding == encoding
                    && push_payload.dh.is_empty()
                    && push_payload.salt.is_empty()
            })
            .returning(|_, _| Ok(DATA.to_vec()));

        let payload = HashMap::from_iter(vec![
            ("chid".to_string(), resp.channel_id),
            ("body".to_string(), body),
            ("con".to_string(), encoding.to_string()),
            ("enc".to_string(), "".to_string()),
            ("cryptokey".to_string(), "".to_string()),
        ]);
        pm.decrypt(payload).unwrap();
        Ok(())
    }

    #[test]
    fn basic() -> Result<()> {
        let _m = get_lock(&MTX);
        let ctx = MockConnection::connect_context();
        ctx.expect().returning(|_| Default::default());

        let mut pm = get_test_manager()?;
        expect_register(&mut pm.connection, 1);
        let crypto_ctx = MockCryptography::generate_key_context();
        crypto_ctx.expect().returning(|| Ok(test_key()));

        // The second call for the same scope must be served from the store:
        // only one register call is expected above.
        let resp = pm.subscribe("test-scope", None)?;
        let resp2 = pm.subscribe("test-scope", None)?;
        assert_eq!(Some(TEST_AUTH.to_owned()), pm.store.get_auth()?);
        assert_eq!(
            resp.subscription_info.endpoint,
            resp2.subscription_info.endpoint
        );
        assert_eq!(resp.subscription_info.keys, resp2.subscription_info.keys);

        pm.connection
            .expect_unsubscribe()
            .with(eq(TEST_CHANNEL_ID), eq(TEST_UAID), eq(TEST_AUTH))
            .times(1)
            .returning(|_, _, _| Ok(()));
        pm.connection
            .expect_unsubscribe_all()
            .with(eq(TEST_UAID), eq(TEST_AUTH))
            .times(1)
            .returning(|_, _| Ok(()));

        pm.unsubscribe("test-scope")?;
        // The record is already gone, so this must not reach the server
        // (the unsubscribe expectation allows a single call).
        pm.unsubscribe("test-scope")?;
        pm.unsubscribe_all()?;
        Ok(())
    }

    #[test]
    fn full() -> Result<()> {
        check_decrypt_with_encoding("aes128gcm")
    }

    #[test]
    fn test_aesgcm_decryption() -> Result<()> {
        check_decrypt_with_encoding("aesgcm")
    }

    #[test]
    fn test_duplicate_subscription_requests() -> Result<()> {
        ensure_initialized();
        rc_crypto::ensure_initialized();

        let _m = get_lock(&MTX);

        let ctx = MockConnection::connect_context();
        ctx.expect().returning(|_| Default::default());

        let mut pm = get_test_manager()?;

        // A single register call: the duplicate request must not hit the server.
        expect_register(&mut pm.connection, 1);
        let crypto_ctx = MockCryptography::generate_key_context();
        crypto_ctx.expect().returning(|| Ok(test_key()));

        let sub_1 = pm.subscribe("test-scope", None)?;
        let sub_2 = pm.subscribe("test-scope", None)?;
        assert_eq!(sub_1, sub_2);
        Ok(())
    }

    #[test]
    fn test_verify_wipe_uaid_if_mismatch() -> Result<()> {
        let _m = get_lock(&MTX);
        let ctx = MockConnection::connect_context();
        ctx.expect().returning(|_| Default::default());

        let mut pm = get_test_manager()?;
        // Two registrations: the initial one and the one after the wipe.
        expect_register(&mut pm.connection, 2);

        let crypto_ctx = MockCryptography::generate_key_context();
        crypto_ctx.expect().returning(|| Ok(test_key()));

        // The server reports a different channel than the one stored locally.
        pm.connection
            .expect_channel_list()
            .with(eq(TEST_UAID), eq(TEST_AUTH))
            .times(1)
            .returning(|_, _| Ok(vec![TEST_CHANNEL_ID2.to_string()]));

        pm.connection
            .expect_unsubscribe_all()
            .with(eq(TEST_UAID), eq(TEST_AUTH))
            .times(1)
            .returning(|_, _| Ok(()));

        let _ = pm.subscribe("test-scope", None)?;
        assert_subscription_stored(&pm)?;

        let unsubscribed_channels = pm.verify_connection(false)?;
        assert_eq!(unsubscribed_channels.len(), 1);
        assert_eq!(unsubscribed_channels[0].channel_id, TEST_CHANNEL_ID);
        assert_local_state_wiped(&pm)?;

        // Subscribing again registers from scratch and restores the state.
        let _ = pm.subscribe("test-scope", None)?;
        assert_subscription_stored(&pm)?;
        Ok(())
    }

    #[test]
    fn test_verify_server_lost_uaid_not_error() -> Result<()> {
        let _m = get_lock(&MTX);
        let ctx = MockConnection::connect_context();
        ctx.expect().returning(|_| Default::default());

        let mut pm = get_test_manager()?;
        expect_register(&mut pm.connection, 1);

        let crypto_ctx = MockCryptography::generate_key_context();
        crypto_ctx.expect().returning(|| Ok(test_key()));

        pm.connection
            .expect_channel_list()
            .with(eq(TEST_UAID), eq(TEST_AUTH))
            .times(1)
            .returning(|_, _| {
                Err(PushError::UAIDNotRecognizedError(
                    "Couldn't find uaid".to_string(),
                ))
            });

        let _ = pm.subscribe("test-scope", None)?;
        assert_subscription_stored(&pm)?;

        // No `unsubscribe_all` expectation is installed: an unknown UAID must
        // wipe local state without asking the server to unsubscribe.
        let unsubscribed_channels = pm.verify_connection(false)?;
        assert_eq!(unsubscribed_channels.len(), 1);
        assert_eq!(unsubscribed_channels[0].channel_id, TEST_CHANNEL_ID);
        assert_local_state_wiped(&pm)?;
        Ok(())
    }

    #[test]
    fn test_verify_server_hard_error() -> Result<()> {
        let _m = get_lock(&MTX);
        let ctx = MockConnection::connect_context();
        ctx.expect().returning(|_| Default::default());

        let mut pm = get_test_manager()?;
        expect_register(&mut pm.connection, 1);

        let crypto_ctx = MockCryptography::generate_key_context();
        crypto_ctx.expect().returning(|| Ok(test_key()));

        pm.connection
            .expect_channel_list()
            .with(eq(TEST_UAID), eq(TEST_AUTH))
            .times(1)
            .returning(|_, _| {
                Err(PushError::CommunicationError(
                    "Unrecoverable error".to_string(),
                ))
            });

        let _ = pm.subscribe("test-scope", None)?;
        assert_subscription_stored(&pm)?;

        // Errors other than "UAID not recognized" are passed to the caller.
        let err = pm.verify_connection(false).unwrap_err();

        assert!(matches!(err, PushError::CommunicationError(_)));
        Ok(())
    }

    #[test]
    fn test_verify_no_local_uaid_ok() -> Result<()> {
        let _m = get_lock(&MTX);
        let ctx = MockConnection::connect_context();
        ctx.expect().returning(|_| Default::default());

        let mut pm = get_test_manager()?;
        let channel_list = pm
            .verify_connection(true)
            .expect("There are no subscriptions, so verify connection should not fail");
        assert!(channel_list.is_empty());
        Ok(())
    }

    #[test]
    fn test_second_subscribe_hits_subscribe_endpoint() -> Result<()> {
        let _m = get_lock(&MTX);
        let ctx = MockConnection::connect_context();
        ctx.expect().returning(|_| Default::default());

        let mut pm = get_test_manager()?;
        expect_register(&mut pm.connection, 1);

        pm.connection
            .expect_subscribe()
            .with(eq(TEST_UAID), eq(TEST_AUTH), eq("native-id"), eq(None))
            .times(1)
            .returning(|_, _, _, _| {
                Ok(SubscribeResponse {
                    channel_id: TEST_CHANNEL_ID2.to_string(),
                    endpoint: "https://example.com/different-dummy-endpoint".to_string(),
                    sender_id: Some("test".to_string()),
                })
            });

        let crypto_ctx = MockCryptography::generate_key_context();
        crypto_ctx.expect().returning(|| Ok(test_key()));

        // First scope registers; the second scope reuses the UAID and goes
        // through the subscribe call instead.
        let resp_1 = pm.subscribe("test-scope", None)?;
        let resp_2 = pm.subscribe("another-scope", None)?;
        assert_eq!(
            resp_1.subscription_info.endpoint,
            "https://example.com/dummy-endpoint"
        );
        assert_eq!(
            resp_2.subscription_info.endpoint,
            "https://example.com/different-dummy-endpoint"
        );
        Ok(())
    }

    #[test]
    fn test_verify_connection_rate_limiter() -> Result<()> {
        let _m = get_lock(&MTX);
        let ctx = MockConnection::connect_context();
        ctx.expect().returning(|_| Default::default());

        let mut pm = get_test_manager()?;
        expect_register(&mut pm.connection, 1);
        let crypto_ctx = MockCryptography::generate_key_context();
        crypto_ctx.expect().returning(|| Ok(test_key()));

        let _ = pm.subscribe("test-scope", None)?;
        // Four verify calls follow, but only three channel_list calls are
        // expected on the mock.
        pm.connection
            .expect_channel_list()
            .with(eq(TEST_UAID), eq(TEST_AUTH))
            .times(3)
            .returning(|_, _| Ok(vec![TEST_CHANNEL_ID.to_string()]));

        let _ = pm.verify_connection(false)?;
        let (_, count) = pm.verify_connection_rate_limiter.get_counters(&pm.store);
        assert_eq!(count, 1);

        let _ = pm.verify_connection(false)?;
        let (timestamp, count) = pm.verify_connection_rate_limiter.get_counters(&pm.store);
        assert_eq!(count, 2);

        // Move the stored timestamp back by a day and a second.
        pm.verify_connection_rate_limiter.persist_counters(
            &pm.store,
            timestamp - ONE_DAY_AND_ONE_SECOND,
            count,
        );

        let _ = pm.verify_connection(false)?;
        let (_, count) = pm.verify_connection_rate_limiter.get_counters(&pm.store);
        assert_eq!(count, 1);

        // Forcing verification resets the limiter first, so the counter ends
        // at 1 again.
        let _ = pm.verify_connection(true)?;
        let (_, count) = pm.verify_connection_rate_limiter.get_counters(&pm.store);
        assert_eq!(count, 1);

        Ok(())
    }
}