push/internal/
push_manager.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4
5//! Main entrypoint for the push component, handles push subscriptions
6//!
7//! Exposes a struct [`PushManager`] that manages push subscriptions for a client
8//!
9//! The [`PushManager`] allows users to:
//! - Create new subscriptions, persist their private keys, and return an endpoint URL to which
//!   senders can deliver payloads encrypted with the returned public key
//! - Delete existing subscriptions
//! - Update native tokens with the autopush server
//! - Routinely check subscriptions to make sure they are in a good state.
14
15use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
16use std::collections::{HashMap, HashSet};
17
18use crate::error::{self, debug, info, PushError, Result};
19use crate::internal::communications::{Connection, PersistedRateLimiter};
20use crate::internal::config::PushConfiguration;
21use crate::internal::crypto::KeyV1 as Key;
22use crate::internal::storage::{PushRecord, Storage};
23use crate::{KeyInfo, PushSubscriptionChanged, SubscriptionInfo, SubscriptionResponse};
24
25use super::crypto::{Cryptography, PushPayload};
/// Length of the `update_token` rate-limiter window, in seconds (one day).
const UPDATE_RATE_LIMITER_INTERVAL: u64 = 60 * 60 * 24;
/// Maximum number of token updates sent to the server within one rate-limiter window.
const UPDATE_RATE_LIMITER_MAX_CALLS: u16 = 500;
28
29impl From<Key> for KeyInfo {
30    fn from(key: Key) -> Self {
31        KeyInfo {
32            auth: URL_SAFE_NO_PAD.encode(key.auth_secret()),
33            p256dh: URL_SAFE_NO_PAD.encode(key.public_key()),
34        }
35    }
36}
37
38impl From<PushRecord> for PushSubscriptionChanged {
39    fn from(record: PushRecord) -> Self {
40        PushSubscriptionChanged {
41            channel_id: record.channel_id,
42            scope: record.scope,
43        }
44    }
45}
46
47impl TryFrom<PushRecord> for SubscriptionResponse {
48    type Error = PushError;
49    fn try_from(value: PushRecord) -> Result<Self, Self::Error> {
50        Ok(SubscriptionResponse {
51            channel_id: value.channel_id,
52            subscription_info: SubscriptionInfo {
53                endpoint: value.endpoint,
54                keys: Key::deserialize(&value.key)?.into(),
55            },
56        })
57    }
58}
59
/// Result of [`PushManager::decrypt`]: the decrypted message plus the scope of the
/// subscription whose channel it arrived on.
#[derive(Debug)]
pub struct DecryptResponse {
    /// Decrypted payload bytes, converted to signed bytes for the Kotlin consumer.
    pub result: Vec<i8>,
    /// Scope of the stored subscription record matching the message's channel id.
    pub scope: String,
}
65
/// Manages a client's push subscriptions: creates and removes them on the push server,
/// persists them in `store`, keeps the native registration token up to date, and decrypts
/// incoming messages.
///
/// Generic over the server connection (`Co`), the cryptography backend (`Cr`) and the
/// storage backend (`S`).
pub struct PushManager<Co, Cr, S> {
    /// Never read; only `Cr`'s associated functions (`Cr::generate_key`, `Cr::decrypt`) are
    /// used. NOTE(review): the name looks like a typo for `_crypto`.
    _crypo: Cr,
    /// Connection used for every request to the push server.
    connection: Co,
    /// User-agent ID assigned by the server on the first registration; `None` before that
    /// and after local registrations are wiped.
    uaid: Option<String>,
    /// Secret returned by the server together with `uaid`; passed alongside it on later
    /// server requests.
    auth: Option<String>,
    /// Native push token supplied via `update`; required before the first subscription.
    registration_id: Option<String>,
    /// Storage for subscription records and the persisted uaid/auth/registration id.
    store: S,
    /// Limits how often token updates are sent to the server.
    update_rate_limiter: PersistedRateLimiter,
    /// Limits how often `verify_connection` actually queries the server.
    verify_connection_rate_limiter: PersistedRateLimiter,
}
76
77impl<Co: Connection, Cr: Cryptography, S: Storage> PushManager<Co, Cr, S> {
78    pub fn new(config: PushConfiguration) -> Result<Self> {
79        let store = S::open(&config.database_path)?;
80        let uaid = store.get_uaid()?;
81        let auth = store.get_auth()?;
82        let registration_id = store.get_registration_id()?;
83        let verify_connection_rate_limiter = PersistedRateLimiter::new(
84            "verify_connection",
85            config
86                .verify_connection_rate_limiter
87                .unwrap_or(super::config::DEFAULT_VERIFY_CONNECTION_LIMITER_INTERVAL),
88            1,
89        );
90
91        let update_rate_limiter = PersistedRateLimiter::new(
92            "update_token",
93            UPDATE_RATE_LIMITER_INTERVAL,
94            UPDATE_RATE_LIMITER_MAX_CALLS,
95        );
96
97        Ok(Self {
98            connection: Co::connect(config),
99            _crypo: Default::default(),
100            uaid,
101            auth,
102            registration_id,
103            store,
104            update_rate_limiter,
105            verify_connection_rate_limiter,
106        })
107    }
108
109    fn ensure_auth_pair(&self) -> Result<(&str, &str)> {
110        if let (Some(uaid), Some(auth)) = (&self.uaid, &self.auth) {
111            Ok((uaid, auth))
112        } else {
113            Err(PushError::GeneralError(
114                "No subscriptions created yet.".into(),
115            ))
116        }
117    }
118
119    pub fn subscribe(
120        &mut self,
121        scope: &str,
122        server_key: Option<&str>,
123    ) -> Result<SubscriptionResponse> {
124        // While potentially an error, a misconfigured system may use "" as
125        // an application key. In that case, we drop the application key.
126        let server_key = if let Some("") = server_key {
127            None
128        } else {
129            server_key
130        };
131        // Don't fetch the subscription from the server if we've already got one.
132        if let Some(record) = self.store.get_record_by_scope(scope)? {
133            if self.uaid.is_none() {
134                // should be impossible - we should delete all records when we lose our uiad.
135                return Err(PushError::StorageError(
136                    "DB has a subscription but no UAID".to_string(),
137                ));
138            }
139            debug!("returning existing subscription for '{}'", scope);
140            return record.try_into();
141        }
142
143        let registration_id = self
144            .registration_id
145            .as_ref()
146            .ok_or_else(|| PushError::CommunicationError("No native id".to_string()))?
147            .clone();
148
149        self.impl_subscribe(scope, &registration_id, server_key)
150    }
151
152    pub fn get_subscription(&self, scope: &str) -> Result<Option<SubscriptionResponse>> {
153        self.store
154            .get_record_by_scope(scope)?
155            .map(TryInto::try_into)
156            .transpose()
157    }
158
159    pub fn unsubscribe(&mut self, scope: &str) -> Result<bool> {
160        let (uaid, auth) = self.ensure_auth_pair()?;
161        let record = self.store.get_record_by_scope(scope)?;
162        if let Some(record) = record {
163            self.connection
164                .unsubscribe(&record.channel_id, uaid, auth)?;
165            self.store.delete_record(&record.channel_id)?;
166            Ok(true)
167        } else {
168            Ok(false)
169        }
170    }
171
172    pub fn unsubscribe_all(&mut self) -> Result<()> {
173        let (uaid, auth) = self.ensure_auth_pair()?;
174
175        self.connection.unsubscribe_all(uaid, auth)?;
176        self.wipe_local_registrations()?;
177        Ok(())
178    }
179
180    pub fn update(&mut self, new_token: &str) -> error::Result<()> {
181        if self.registration_id.as_deref() == Some(new_token) {
182            // Already up to date!
183            // if we haven't send it to the server yet, we will on the next subscribe!
184            // if we have sent it to the server, no need to do so again. We will catch any issues
185            // through the [`PushManager::verify_connection`] check
186            return Ok(());
187        }
188
189        // It's OK if we don't have a uaid yet - that means we don't have any subscriptions,
190        // let save our registration_id, so will use it on our first subscription.
191        if self.uaid.is_none() {
192            self.store.set_registration_id(new_token)?;
193            self.registration_id = Some(new_token.to_string());
194            info!("saved the registration ID but not telling the server as we have no subs yet");
195            return Ok(());
196        }
197
198        if !self.update_rate_limiter.check(&self.store) {
199            return Ok(());
200        }
201
202        let (uaid, auth) = self.ensure_auth_pair()?;
203
204        if let Err(e) = self.connection.update(new_token, uaid, auth) {
205            match e {
206                PushError::UAIDNotRecognizedError(_) => {
207                    // Our subscriptions are dead, but for now, just let the existing mechanisms
208                    // deal with that (eg, next `subscribe()` or `verify_connection()`)
209                    info!("updating our token indicated our subscriptions are gone");
210                }
211                _ => return Err(e),
212            }
213        }
214
215        self.store.set_registration_id(new_token)?;
216        self.registration_id = Some(new_token.to_string());
217        Ok(())
218    }
219
220    pub fn verify_connection(
221        &mut self,
222        force_verify: bool,
223    ) -> Result<Vec<PushSubscriptionChanged>> {
224        if force_verify {
225            self.verify_connection_rate_limiter.reset(&self.store);
226        }
227
228        // If we were rate limited or there are no subscriptions yet, we should signal to the
229        // consumer that everything is ok
230        if self.uaid.is_none() || !self.verify_connection_rate_limiter.check(&self.store) {
231            return Ok(vec![]);
232        }
233        let channels = self.store.get_channel_list()?;
234        let (uaid, auth) = self.ensure_auth_pair()?;
235
236        let local_channels: HashSet<String> = channels.into_iter().collect();
237        let remote_channels = match self.connection.channel_list(uaid, auth) {
238            Ok(v) => Some(HashSet::from_iter(v)),
239            Err(e) => match e {
240                PushError::UAIDNotRecognizedError(_) => {
241                    // We do not unsubscribe, because the server already lost our UAID
242                    None
243                }
244                _ => return Err(e),
245            },
246        };
247
248        // verify both lists match. Either side could have lost its mind.
249        match remote_channels {
250            // Everything is OK! Lets return early
251            Some(channels) if channels == local_channels => return Ok(Vec::new()),
252            Some(_) => {
253                info!("verify_connection found a mismatch - unsubscribing");
254                // Unsubscribe all the channels (just to be sure and avoid a loop).
255                self.connection.unsubscribe_all(uaid, auth)?;
256            }
257            // Means the server lost our UAID, lets not unsubscribe,
258            // as that operation will fail
259            None => (),
260        };
261
262        let mut subscriptions: Vec<PushSubscriptionChanged> = Vec::new();
263        for channel in local_channels {
264            if let Some(record) = self.store.get_record(&channel)? {
265                subscriptions.push(record.into());
266            }
267        }
268        // we wipe all existing subscriptions and the UAID if there is a mismatch; the next
269        // `subscribe()` call will get a new UAID.
270        self.wipe_local_registrations()?;
271        Ok(subscriptions)
272    }
273
274    pub fn decrypt(&self, payload: HashMap<String, String>) -> Result<DecryptResponse> {
275        let payload = PushPayload::try_from(&payload)?;
276        let val = self
277            .store
278            .get_record(payload.channel_id)?
279            .ok_or_else(|| PushError::RecordNotFoundError(payload.channel_id.to_string()))?;
280        let key = Key::deserialize(&val.key)?;
281        let decrypted = Cr::decrypt(&key, payload)?;
282        // NOTE: this returns a `Vec<i8>` since the kotlin consumer is expecting
283        // signed bytes.
284        Ok(DecryptResponse {
285            result: decrypted.into_iter().map(|ub| ub as i8).collect(),
286            scope: val.scope,
287        })
288    }
289
290    fn wipe_local_registrations(&mut self) -> error::Result<()> {
291        self.store.delete_all_records()?;
292        self.auth = None;
293        self.uaid = None;
294        Ok(())
295    }
296
297    fn impl_subscribe(
298        &mut self,
299        scope: &str,
300        registration_id: &str,
301        server_key: Option<&str>,
302    ) -> error::Result<SubscriptionResponse> {
303        if let (Some(uaid), Some(auth)) = (&self.uaid, &self.auth) {
304            self.subscribe_with_uaid(scope, uaid, auth, registration_id, server_key)
305        } else {
306            self.register(scope, registration_id, server_key)
307        }
308    }
309
310    fn subscribe_with_uaid(
311        &self,
312        scope: &str,
313        uaid: &str,
314        auth: &str,
315        registration_id: &str,
316        app_server_key: Option<&str>,
317    ) -> error::Result<SubscriptionResponse> {
318        let app_server_key = app_server_key.map(|v| v.to_owned());
319
320        let subscription_response =
321            self.connection
322                .subscribe(uaid, auth, registration_id, &app_server_key)?;
323        let subscription_key = Cr::generate_key()?;
324        let mut record = crate::internal::storage::PushRecord::new(
325            &subscription_response.channel_id,
326            &subscription_response.endpoint,
327            scope,
328            subscription_key.clone(),
329        )?;
330        record.app_server_key = app_server_key;
331        self.store.put_record(&record)?;
332        debug!("subscribed OK");
333        Ok(SubscriptionResponse {
334            channel_id: subscription_response.channel_id,
335            subscription_info: SubscriptionInfo {
336                endpoint: subscription_response.endpoint,
337                keys: subscription_key.into(),
338            },
339        })
340    }
341
342    fn register(
343        &mut self,
344        scope: &str,
345        registration_id: &str,
346        app_server_key: Option<&str>,
347    ) -> error::Result<SubscriptionResponse> {
348        let app_server_key = app_server_key.map(|v| v.to_owned());
349        let register_response = self.connection.register(registration_id, &app_server_key)?;
350        // Registration successful! Before we return our registration, lets save our uaid and auth
351        self.store.set_uaid(&register_response.uaid)?;
352        self.store.set_auth(&register_response.secret)?;
353        self.uaid = Some(register_response.uaid.clone());
354        self.auth = Some(register_response.secret.clone());
355
356        let subscription_key = Cr::generate_key()?;
357        let mut record = crate::internal::storage::PushRecord::new(
358            &register_response.channel_id,
359            &register_response.endpoint,
360            scope,
361            subscription_key.clone(),
362        )?;
363        record.app_server_key = app_server_key;
364        self.store.put_record(&record)?;
365        debug!("subscribed OK");
366        Ok(SubscriptionResponse {
367            channel_id: register_response.channel_id,
368            subscription_info: SubscriptionInfo {
369                endpoint: register_response.endpoint,
370                keys: subscription_key.into(),
371            },
372        })
373    }
374}
375
376#[cfg(test)]
377mod test {
378    use mockall::predicate::eq;
379    use rc_crypto::ece::{self, EcKeyComponents};
380
381    use crate::internal::{
382        communications::{MockConnection, RegisterResponse, SubscribeResponse},
383        crypto::MockCryptography,
384    };
385
386    use super::*;
387    use lazy_static::lazy_static;
388    use std::sync::{Mutex, MutexGuard};
389
390    use nss::ensure_initialized;
391
392    use crate::Store;
393
394    lazy_static! {
395        static ref MTX: Mutex<()> = Mutex::new(());
396    }
397
398    // we need to run our tests in sequence. The tests mock static
399    // methods. Mocked static methods are global are susceptible to data races
400    // see: https://docs.rs/mockall/latest/mockall/#static-methods
401    fn get_lock(m: &'static Mutex<()>) -> MutexGuard<'static, ()> {
402        match m.lock() {
403            Ok(guard) => guard,
404            Err(poisoned) => poisoned.into_inner(),
405        }
406    }
407
408    const TEST_UAID: &str = "abad1d3a00000000aabbccdd00000000";
409    const DATA: &[u8] = b"Mary had a little lamb, with some nice mint jelly";
410    const TEST_CHANNEL_ID: &str = "deadbeef00000000decafbad00000000";
411    const TEST_CHANNEL_ID2: &str = "decafbad00000000deadbeef00000000";
412
413    const PRIV_KEY_D: &str = "qJkxxWGVVxy7BKvraNY3hg8Gs-Y8qi0lRaXWJ3R3aJ8";
414    // The auth token
415    const TEST_AUTH: &str = "LsuUOBKVQRY6-l7_Ajo-Ag";
416    // This would be the public key sent to the subscription service.
417    const PUB_KEY_RAW: &str =
418        "BBcJdfs1GtMyymFTtty6lIGWRFXrEtJP40Df0gOvRDR4D8CKVgqE6vlYR7tCYksIRdKD1MxDPhQVmKLnzuife50";
419
420    const ONE_DAY_AND_ONE_SECOND: u64 = (24 * 60 * 60) + 1;
421
422    fn get_test_manager() -> Result<PushManager<MockConnection, MockCryptography, Store>> {
423        let test_config = PushConfiguration {
424            sender_id: "test".to_owned(),
425            ..Default::default()
426        };
427
428        let mut pm: PushManager<MockConnection, MockCryptography, Store> =
429            PushManager::new(test_config)?;
430        pm.store.set_registration_id("native-id")?;
431        pm.registration_id = Some("native-id".to_string());
432        Ok(pm)
433    }
434
435    #[test]
436    fn basic() -> Result<()> {
437        let _m = get_lock(&MTX);
438        let ctx = MockConnection::connect_context();
439        ctx.expect().returning(|_| Default::default());
440
441        let mut pm = get_test_manager()?;
442        pm.connection
443            .expect_register()
444            .with(eq("native-id"), eq(None))
445            .times(1)
446            .returning(|_, _| {
447                Ok(RegisterResponse {
448                    uaid: TEST_UAID.to_string(),
449                    channel_id: TEST_CHANNEL_ID.to_string(),
450                    secret: TEST_AUTH.to_string(),
451                    endpoint: "https://example.com/dummy-endpoint".to_string(),
452                    sender_id: Some("test".to_string()),
453                })
454            });
455        let crypto_ctx = MockCryptography::generate_key_context();
456        crypto_ctx.expect().returning(|| {
457            let components = EcKeyComponents::new(
458                URL_SAFE_NO_PAD.decode(PRIV_KEY_D).unwrap(),
459                URL_SAFE_NO_PAD.decode(PUB_KEY_RAW).unwrap(),
460            );
461            let auth = URL_SAFE_NO_PAD.decode(TEST_AUTH).unwrap();
462            Ok(Key {
463                p256key: components,
464                auth,
465            })
466        });
467        let resp = pm.subscribe("test-scope", None)?;
468        // verify that a subsequent request for the same channel ID returns the same subscription
469        let resp2 = pm.subscribe("test-scope", None)?;
470        assert_eq!(Some(TEST_AUTH.to_owned()), pm.store.get_auth()?);
471        assert_eq!(
472            resp.subscription_info.endpoint,
473            resp2.subscription_info.endpoint
474        );
475        assert_eq!(resp.subscription_info.keys, resp2.subscription_info.keys);
476
477        pm.connection
478            .expect_unsubscribe()
479            .with(eq(TEST_CHANNEL_ID), eq(TEST_UAID), eq(TEST_AUTH))
480            .times(1)
481            .returning(|_, _, _| Ok(()));
482        pm.connection
483            .expect_unsubscribe_all()
484            .with(eq(TEST_UAID), eq(TEST_AUTH))
485            .times(1)
486            .returning(|_, _| Ok(()));
487
488        pm.unsubscribe("test-scope")?;
489        // It's already deleted, we still return an OK, but it won't trigger a network request
490        pm.unsubscribe("test-scope")?;
491        pm.unsubscribe_all()?;
492        Ok(())
493    }
494
495    #[test]
496    fn full() -> Result<()> {
497        ensure_initialized();
498        rc_crypto::ensure_initialized();
499
500        let _m = get_lock(&MTX);
501        let ctx = MockConnection::connect_context();
502        ctx.expect().returning(|_| Default::default());
503        let data_string = b"Mary had a little lamb, with some nice mint jelly";
504        let mut pm = get_test_manager()?;
505        pm.connection
506            .expect_register()
507            .with(eq("native-id"), eq(None))
508            .times(1)
509            .returning(|_, _| {
510                Ok(RegisterResponse {
511                    uaid: TEST_UAID.to_string(),
512                    channel_id: TEST_CHANNEL_ID.to_string(),
513                    secret: TEST_AUTH.to_string(),
514                    endpoint: "https://example.com/dummy-endpoint".to_string(),
515                    sender_id: Some("test".to_string()),
516                })
517            });
518        let crypto_ctx = MockCryptography::generate_key_context();
519        crypto_ctx.expect().returning(|| {
520            let components = EcKeyComponents::new(
521                URL_SAFE_NO_PAD.decode(PRIV_KEY_D).unwrap(),
522                URL_SAFE_NO_PAD.decode(PUB_KEY_RAW).unwrap(),
523            );
524            let auth = URL_SAFE_NO_PAD.decode(TEST_AUTH).unwrap();
525            Ok(Key {
526                p256key: components,
527                auth,
528            })
529        });
530
531        let resp = pm.subscribe("test-scope", None)?;
532        let key_info = resp.subscription_info.keys;
533        let remote_pub = URL_SAFE_NO_PAD.decode(&key_info.p256dh).unwrap();
534        let auth = URL_SAFE_NO_PAD.decode(&key_info.auth).unwrap();
535        // Act like a subscription provider, so create a "local" key to encrypt the data
536        let ciphertext = ece::encrypt(&remote_pub, &auth, data_string).unwrap();
537        let body = URL_SAFE_NO_PAD.encode(ciphertext);
538
539        let decryp_ctx = MockCryptography::decrypt_context();
540        let body_clone = body.clone();
541        decryp_ctx
542            .expect()
543            .withf(move |key, push_payload| {
544                *key == Key {
545                    p256key: EcKeyComponents::new(
546                        URL_SAFE_NO_PAD.decode(PRIV_KEY_D).unwrap(),
547                        URL_SAFE_NO_PAD.decode(PUB_KEY_RAW).unwrap(),
548                    ),
549                    auth: URL_SAFE_NO_PAD.decode(TEST_AUTH).unwrap(),
550                } && push_payload.body == body_clone
551                    && push_payload.encoding == "aes128gcm"
552                    && push_payload.dh.is_empty()
553                    && push_payload.salt.is_empty()
554            })
555            .returning(|_, _| Ok(data_string.to_vec()));
556
557        let payload = HashMap::from_iter(vec![
558            ("chid".to_string(), resp.channel_id),
559            ("body".to_string(), body),
560            ("con".to_string(), "aes128gcm".to_string()),
561            ("enc".to_string(), "".to_string()),
562            ("cryptokey".to_string(), "".to_string()),
563        ]);
564        pm.decrypt(payload).unwrap();
565        Ok(())
566    }
567
568    #[test]
569    fn test_aesgcm_decryption() -> Result<()> {
570        ensure_initialized();
571        rc_crypto::ensure_initialized();
572
573        let _m = get_lock(&MTX);
574
575        let ctx = MockConnection::connect_context();
576        ctx.expect().returning(|_| Default::default());
577
578        let mut pm = get_test_manager()?;
579
580        pm.connection
581            .expect_register()
582            .with(eq("native-id"), eq(None))
583            .times(1)
584            .returning(|_, _| {
585                Ok(RegisterResponse {
586                    uaid: TEST_UAID.to_string(),
587                    channel_id: TEST_CHANNEL_ID.to_string(),
588                    secret: TEST_AUTH.to_string(),
589                    endpoint: "https://example.com/dummy-endpoint".to_string(),
590                    sender_id: Some("test".to_string()),
591                })
592            });
593        let crypto_ctx = MockCryptography::generate_key_context();
594        crypto_ctx.expect().returning(|| {
595            let components = EcKeyComponents::new(
596                URL_SAFE_NO_PAD.decode(PRIV_KEY_D).unwrap(),
597                URL_SAFE_NO_PAD.decode(PUB_KEY_RAW).unwrap(),
598            );
599            let auth = URL_SAFE_NO_PAD.decode(TEST_AUTH).unwrap();
600            Ok(Key {
601                p256key: components,
602                auth,
603            })
604        });
605        let resp = pm.subscribe("test-scope", None)?;
606        let key_info = resp.subscription_info.keys;
607        let remote_pub = URL_SAFE_NO_PAD.decode(&key_info.p256dh).unwrap();
608        let auth = URL_SAFE_NO_PAD.decode(&key_info.auth).unwrap();
609        // Act like a subscription provider, so create a "local" key to encrypt the data
610        let ciphertext = ece::encrypt(&remote_pub, &auth, DATA).unwrap();
611        let body = URL_SAFE_NO_PAD.encode(ciphertext);
612
613        let decryp_ctx = MockCryptography::decrypt_context();
614        let body_clone = body.clone();
615        decryp_ctx
616            .expect()
617            .withf(move |key, push_payload| {
618                *key == Key {
619                    p256key: EcKeyComponents::new(
620                        URL_SAFE_NO_PAD.decode(PRIV_KEY_D).unwrap(),
621                        URL_SAFE_NO_PAD.decode(PUB_KEY_RAW).unwrap(),
622                    ),
623                    auth: URL_SAFE_NO_PAD.decode(TEST_AUTH).unwrap(),
624                } && push_payload.body == body_clone
625                    && push_payload.encoding == "aesgcm"
626                    && push_payload.dh.is_empty()
627                    && push_payload.salt.is_empty()
628            })
629            .returning(|_, _| Ok(DATA.to_vec()));
630
631        let payload = HashMap::from_iter(vec![
632            ("chid".to_string(), resp.channel_id),
633            ("body".to_string(), body),
634            ("con".to_string(), "aesgcm".to_string()),
635            ("enc".to_string(), "".to_string()),
636            ("cryptokey".to_string(), "".to_string()),
637        ]);
638        pm.decrypt(payload).unwrap();
639        Ok(())
640    }
641
642    #[test]
643    fn test_duplicate_subscription_requests() -> Result<()> {
644        ensure_initialized();
645        rc_crypto::ensure_initialized();
646
647        let _m = get_lock(&MTX);
648
649        let ctx = MockConnection::connect_context();
650        ctx.expect().returning(|_| Default::default());
651
652        let mut pm = get_test_manager()?;
653
654        pm.connection
655            .expect_register()
656            .with(eq("native-id"), eq(None))
657            .times(1) // only once, second time we'll hit cache!
658            .returning(|_, _| {
659                Ok(RegisterResponse {
660                    uaid: TEST_UAID.to_string(),
661                    channel_id: TEST_CHANNEL_ID.to_string(),
662                    secret: TEST_AUTH.to_string(),
663                    endpoint: "https://example.com/dummy-endpoint".to_string(),
664                    sender_id: Some("test".to_string()),
665                })
666            });
667        let crypto_ctx = MockCryptography::generate_key_context();
668        crypto_ctx.expect().returning(|| {
669            let components = EcKeyComponents::new(
670                URL_SAFE_NO_PAD.decode(PRIV_KEY_D).unwrap(),
671                URL_SAFE_NO_PAD.decode(PUB_KEY_RAW).unwrap(),
672            );
673            let auth = URL_SAFE_NO_PAD.decode(TEST_AUTH).unwrap();
674            Ok(Key {
675                p256key: components,
676                auth,
677            })
678        });
679        let sub_1 = pm.subscribe("test-scope", None)?;
680        let sub_2 = pm.subscribe("test-scope", None)?;
681        assert_eq!(sub_1, sub_2);
682        Ok(())
683    }
684    #[test]
685    fn test_verify_wipe_uaid_if_mismatch() -> Result<()> {
686        let _m = get_lock(&MTX);
687        let ctx = MockConnection::connect_context();
688        ctx.expect().returning(|_| Default::default());
689
690        let mut pm = get_test_manager()?;
691        pm.connection
692            .expect_register()
693            .with(eq("native-id"), eq(None))
694            .times(2)
695            .returning(|_, _| {
696                Ok(RegisterResponse {
697                    uaid: TEST_UAID.to_string(),
698                    channel_id: TEST_CHANNEL_ID.to_string(),
699                    secret: TEST_AUTH.to_string(),
700                    endpoint: "https://example.com/dummy-endpoint".to_string(),
701                    sender_id: Some("test".to_string()),
702                })
703            });
704
705        let crypto_ctx = MockCryptography::generate_key_context();
706        crypto_ctx.expect().returning(|| {
707            let components = EcKeyComponents::new(
708                URL_SAFE_NO_PAD.decode(PRIV_KEY_D).unwrap(),
709                URL_SAFE_NO_PAD.decode(PUB_KEY_RAW).unwrap(),
710            );
711            let auth = URL_SAFE_NO_PAD.decode(TEST_AUTH).unwrap();
712            Ok(Key {
713                p256key: components,
714                auth,
715            })
716        });
717        pm.connection
718            .expect_channel_list()
719            .with(eq(TEST_UAID), eq(TEST_AUTH))
720            .times(1)
721            .returning(|_, _| Ok(vec![TEST_CHANNEL_ID2.to_string()]));
722
723        pm.connection
724            .expect_unsubscribe_all()
725            .with(eq(TEST_UAID), eq(TEST_AUTH))
726            .times(1)
727            .returning(|_, _| Ok(()));
728        let _ = pm.subscribe("test-scope", None)?;
729        // verify that a uaid got added to our store and
730        // that there is a record associated with the channel ID provided
731        assert_eq!(pm.store.get_uaid()?.unwrap(), TEST_UAID);
732        assert_eq!(
733            pm.store.get_record(TEST_CHANNEL_ID)?.unwrap().channel_id,
734            TEST_CHANNEL_ID
735        );
736        let unsubscribed_channels = pm.verify_connection(false)?;
737        assert_eq!(unsubscribed_channels.len(), 1);
738        assert_eq!(unsubscribed_channels[0].channel_id, TEST_CHANNEL_ID);
739        // since verify_connection failed,
740        // we wipe the uaid and all associated records from our store
741        assert!(pm.store.get_uaid()?.is_none());
742        assert!(pm.store.get_record(TEST_CHANNEL_ID)?.is_none());
743
744        // we now check that a new subscription will cause us to
745        // re-generate a uaid and store it in our store
746        let _ = pm.subscribe("test-scope", None)?;
747        // verify that the uaid got added to our store and
748        // that there is a record associated with the channel ID provided
749        assert_eq!(pm.store.get_uaid()?.unwrap(), TEST_UAID);
750        assert_eq!(
751            pm.store.get_record(TEST_CHANNEL_ID)?.unwrap().channel_id,
752            TEST_CHANNEL_ID
753        );
754        Ok(())
755    }
756
757    #[test]
758    fn test_verify_server_lost_uaid_not_error() -> Result<()> {
759        let _m = get_lock(&MTX);
760        let ctx = MockConnection::connect_context();
761        ctx.expect().returning(|_| Default::default());
762
763        let mut pm = get_test_manager()?;
764        pm.connection
765            .expect_register()
766            .with(eq("native-id"), eq(None))
767            .times(1)
768            .returning(|_, _| {
769                Ok(RegisterResponse {
770                    uaid: TEST_UAID.to_string(),
771                    channel_id: TEST_CHANNEL_ID.to_string(),
772                    secret: TEST_AUTH.to_string(),
773                    endpoint: "https://example.com/dummy-endpoint".to_string(),
774                    sender_id: Some("test".to_string()),
775                })
776            });
777
778        let crypto_ctx = MockCryptography::generate_key_context();
779        crypto_ctx.expect().returning(|| {
780            let components = EcKeyComponents::new(
781                URL_SAFE_NO_PAD.decode(PRIV_KEY_D).unwrap(),
782                URL_SAFE_NO_PAD.decode(PUB_KEY_RAW).unwrap(),
783            );
784            let auth = URL_SAFE_NO_PAD.decode(TEST_AUTH).unwrap();
785            Ok(Key {
786                p256key: components,
787                auth,
788            })
789        });
790        pm.connection
791            .expect_channel_list()
792            .with(eq(TEST_UAID), eq(TEST_AUTH))
793            .times(1)
794            .returning(|_, _| {
795                Err(PushError::UAIDNotRecognizedError(
796                    "Couldn't find uaid".to_string(),
797                ))
798            });
799
800        let _ = pm.subscribe("test-scope", None)?;
801        // verify that a uaid got added to our store and
802        // that there is a record associated with the channel ID provided
803        assert_eq!(pm.store.get_uaid()?.unwrap(), TEST_UAID);
804        assert_eq!(
805            pm.store.get_record(TEST_CHANNEL_ID)?.unwrap().channel_id,
806            TEST_CHANNEL_ID
807        );
808        let unsubscribed_channels = pm.verify_connection(false)?;
809        assert_eq!(unsubscribed_channels.len(), 1);
810        assert_eq!(unsubscribed_channels[0].channel_id, TEST_CHANNEL_ID);
811        // since verify_connection failed,
812        // we wipe the uaid and all associated records from our store
813        assert!(pm.store.get_uaid()?.is_none());
814        assert!(pm.store.get_record(TEST_CHANNEL_ID)?.is_none());
815        Ok(())
816    }
817
818    #[test]
819    fn test_verify_server_hard_error() -> Result<()> {
820        let _m = get_lock(&MTX);
821        let ctx = MockConnection::connect_context();
822        ctx.expect().returning(|_| Default::default());
823
824        let mut pm = get_test_manager()?;
825        pm.connection
826            .expect_register()
827            .with(eq("native-id"), eq(None))
828            .times(1)
829            .returning(|_, _| {
830                Ok(RegisterResponse {
831                    uaid: TEST_UAID.to_string(),
832                    channel_id: TEST_CHANNEL_ID.to_string(),
833                    secret: TEST_AUTH.to_string(),
834                    endpoint: "https://example.com/dummy-endpoint".to_string(),
835                    sender_id: Some("test".to_string()),
836                })
837            });
838
839        let crypto_ctx = MockCryptography::generate_key_context();
840        crypto_ctx.expect().returning(|| {
841            let components = EcKeyComponents::new(
842                URL_SAFE_NO_PAD.decode(PRIV_KEY_D).unwrap(),
843                URL_SAFE_NO_PAD.decode(PUB_KEY_RAW).unwrap(),
844            );
845            let auth = URL_SAFE_NO_PAD.decode(TEST_AUTH).unwrap();
846            Ok(Key {
847                p256key: components,
848                auth,
849            })
850        });
851        pm.connection
852            .expect_channel_list()
853            .with(eq(TEST_UAID), eq(TEST_AUTH))
854            .times(1)
855            .returning(|_, _| {
856                Err(PushError::CommunicationError(
857                    "Unrecoverable error".to_string(),
858                ))
859            });
860
861        let _ = pm.subscribe("test-scope", None)?;
862        // verify that a uaid got added to our store and
863        // that there is a record associated with the channel ID provided
864        assert_eq!(pm.store.get_uaid()?.unwrap(), TEST_UAID);
865        assert_eq!(
866            pm.store.get_record(TEST_CHANNEL_ID)?.unwrap().channel_id,
867            TEST_CHANNEL_ID
868        );
869        let err = pm.verify_connection(false).unwrap_err();
870
871        // the same error got propagated
872        assert!(matches!(err, PushError::CommunicationError(_)));
873        Ok(())
874    }
875
876    #[test]
877    fn test_verify_no_local_uaid_ok() -> Result<()> {
878        let _m = get_lock(&MTX);
879        let ctx = MockConnection::connect_context();
880        ctx.expect().returning(|_| Default::default());
881
882        let mut pm = get_test_manager()?;
883        let channel_list = pm
884            .verify_connection(true)
885            .expect("There are no subscriptions, so verify connection should not fail");
886        assert!(channel_list.is_empty());
887        Ok(())
888    }
889
890    #[test]
891    fn test_second_subscribe_hits_subscribe_endpoint() -> Result<()> {
892        let _m = get_lock(&MTX);
893        let ctx = MockConnection::connect_context();
894        ctx.expect().returning(|_| Default::default());
895
896        let mut pm = get_test_manager()?;
897        pm.connection
898            .expect_register()
899            .with(eq("native-id"), eq(None))
900            .times(1)
901            .returning(|_, _| {
902                Ok(RegisterResponse {
903                    uaid: TEST_UAID.to_string(),
904                    channel_id: TEST_CHANNEL_ID.to_string(),
905                    secret: TEST_AUTH.to_string(),
906                    endpoint: "https://example.com/dummy-endpoint".to_string(),
907                    sender_id: Some("test".to_string()),
908                })
909            });
910
911        pm.connection
912            .expect_subscribe()
913            .with(eq(TEST_UAID), eq(TEST_AUTH), eq("native-id"), eq(None))
914            .times(1)
915            .returning(|_, _, _, _| {
916                Ok(SubscribeResponse {
917                    channel_id: TEST_CHANNEL_ID2.to_string(),
918                    endpoint: "https://example.com/different-dummy-endpoint".to_string(),
919                    sender_id: Some("test".to_string()),
920                })
921            });
922
923        let crypto_ctx = MockCryptography::generate_key_context();
924        crypto_ctx.expect().returning(|| {
925            let components = EcKeyComponents::new(
926                URL_SAFE_NO_PAD.decode(PRIV_KEY_D).unwrap(),
927                URL_SAFE_NO_PAD.decode(PUB_KEY_RAW).unwrap(),
928            );
929            let auth = URL_SAFE_NO_PAD.decode(TEST_AUTH).unwrap();
930            Ok(Key {
931                p256key: components,
932                auth,
933            })
934        });
935
936        let resp_1 = pm.subscribe("test-scope", None)?;
937        let resp_2 = pm.subscribe("another-scope", None)?;
938        assert_eq!(
939            resp_1.subscription_info.endpoint,
940            "https://example.com/dummy-endpoint"
941        );
942        assert_eq!(
943            resp_2.subscription_info.endpoint,
944            "https://example.com/different-dummy-endpoint"
945        );
946        Ok(())
947    }
948
949    #[test]
950    fn test_verify_connection_rate_limiter() -> Result<()> {
951        let _m = get_lock(&MTX);
952        let ctx = MockConnection::connect_context();
953        ctx.expect().returning(|_| Default::default());
954
955        let mut pm = get_test_manager()?;
956        pm.connection
957            .expect_register()
958            .with(eq("native-id"), eq(None))
959            .times(1)
960            .returning(|_, _| {
961                Ok(RegisterResponse {
962                    uaid: TEST_UAID.to_string(),
963                    channel_id: TEST_CHANNEL_ID.to_string(),
964                    secret: TEST_AUTH.to_string(),
965                    endpoint: "https://example.com/dummy-endpoint".to_string(),
966                    sender_id: Some("test".to_string()),
967                })
968            });
969        let crypto_ctx = MockCryptography::generate_key_context();
970        crypto_ctx.expect().returning(|| {
971            let components = EcKeyComponents::new(
972                URL_SAFE_NO_PAD.decode(PRIV_KEY_D).unwrap(),
973                URL_SAFE_NO_PAD.decode(PUB_KEY_RAW).unwrap(),
974            );
975            let auth = URL_SAFE_NO_PAD.decode(TEST_AUTH).unwrap();
976            Ok(Key {
977                p256key: components,
978                auth,
979            })
980        });
981        let _ = pm.subscribe("test-scope", None)?;
982        pm.connection
983            .expect_channel_list()
984            .with(eq(TEST_UAID), eq(TEST_AUTH))
985            .times(3)
986            .returning(|_, _| Ok(vec![TEST_CHANNEL_ID.to_string()]));
987        let _ = pm.verify_connection(false)?;
988        let (_, count) = pm.verify_connection_rate_limiter.get_counters(&pm.store);
989        assert_eq!(count, 1);
990        let _ = pm.verify_connection(false)?;
991        let (timestamp, count) = pm.verify_connection_rate_limiter.get_counters(&pm.store);
992
993        assert_eq!(count, 2);
994
995        pm.verify_connection_rate_limiter.persist_counters(
996            &pm.store,
997            timestamp - ONE_DAY_AND_ONE_SECOND,
998            count,
999        );
1000
1001        let _ = pm.verify_connection(false)?;
1002        let (_, count) = pm.verify_connection_rate_limiter.get_counters(&pm.store);
1003        assert_eq!(count, 1);
1004
1005        // Even though a day hasn't passed, we passed `true` to force verify
1006        // so the counter is now reset
1007        let _ = pm.verify_connection(true)?;
1008        let (_, count) = pm.verify_connection_rate_limiter.get_counters(&pm.store);
1009        assert_eq!(count, 1);
1010
1011        Ok(())
1012    }
1013}