sync15/client/
token.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4
5use crate::error::{self, debug, trace, warn, Error as ErrorKind, Result};
6use crate::ServerTimestamp;
7use rc_crypto::hawk;
8use serde_derive::*;
9use std::borrow::{Borrow, Cow};
10use std::cell::RefCell;
11use std::fmt;
12use std::time::{Duration, SystemTime};
13use url::Url;
14use viaduct::{header_names, Request};
15
/// Back-off, in milliseconds, used when the token server sends a
/// `Retry-After` header whose value cannot be parsed as a number.
const RETRY_AFTER_DEFAULT_MS: u64 = 10_000;
17
18// The TokenserverToken is the token as received directly from the token server
19// and deserialized from JSON.
20#[derive(Deserialize, Clone, PartialEq, Eq)]
21struct TokenserverToken {
22    id: String,
23    key: String,
24    api_endpoint: String,
25    uid: u64,
26    duration: u64,
27    hashed_fxa_uid: String,
28}
29
30impl std::fmt::Debug for TokenserverToken {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("TokenserverToken")
33            .field("api_endpoint", &self.api_endpoint)
34            .field("uid", &self.uid)
35            .field("duration", &self.duration)
36            .field("hashed_fxa_uid", &self.hashed_fxa_uid)
37            .finish()
38    }
39}
40
41// The struct returned by the TokenFetcher - the token itself and the
42// server timestamp.
43struct TokenFetchResult {
44    token: TokenserverToken,
45    server_timestamp: ServerTimestamp,
46}
47
48// The trait for fetching tokens - we'll provide a "real" implementation but
49// tests will re-implement it.
50trait TokenFetcher {
51    fn fetch_token(&self) -> crate::Result<TokenFetchResult>;
52    // We allow the trait to tell us what the time is so tests can get funky.
53    fn now(&self) -> SystemTime;
54}
55
56// Our "real" token fetcher, implementing the TokenFetcher trait, which hits
57// the token server
58#[derive(Debug)]
59struct TokenServerFetcher {
60    // The stuff needed to fetch a token.
61    server_url: Url,
62    access_token: String,
63    key_id: String,
64}
65
66fn fixup_server_url(mut url: Url) -> url::Url {
67    // The given `url` is the end-point as returned by .well-known/fxa-client-configuration,
68    // or as directly specified by self-hosters. As a result, it may or may not have
69    // the sync 1.5 suffix of "/1.0/sync/1.5", so add it on here if it does not.
70    if url.as_str().ends_with("1.0/sync/1.5") {
71        // ok!
72    } else if url.as_str().ends_with("1.0/sync/1.5/") {
73        // Shouldn't ever be Err() here, but the result is `Result<PathSegmentsMut, ()>`
74        // and I don't want to unwrap or add a new error type just for PathSegmentsMut failing.
75        if let Ok(mut path) = url.path_segments_mut() {
76            path.pop();
77        }
78    } else {
79        // We deliberately don't use `.join()` here in order to preserve all path components.
80        // For example, "http://example.com/token" should produce "http://example.com/token/1.0/sync/1.5"
81        // but using `.join()` would produce "http://example.com/1.0/sync/1.5".
82        if let Ok(mut path) = url.path_segments_mut() {
83            path.pop_if_empty();
84            path.extend(&["1.0", "sync", "1.5"]);
85        }
86    };
87    url
88}
89
90impl TokenServerFetcher {
91    fn new(base_url: Url, access_token: String, key_id: String) -> TokenServerFetcher {
92        TokenServerFetcher {
93            server_url: fixup_server_url(base_url),
94            access_token,
95            key_id,
96        }
97    }
98}
99
100impl TokenFetcher for TokenServerFetcher {
101    fn fetch_token(&self) -> Result<TokenFetchResult> {
102        debug!("Fetching token from {}", self.server_url);
103        let resp = Request::get(self.server_url.clone())
104            .header(
105                header_names::AUTHORIZATION,
106                format!("Bearer {}", self.access_token),
107            )?
108            .header(header_names::X_KEYID, self.key_id.clone())?
109            .send()?;
110
111        if !resp.is_success() {
112            warn!("Non-success status when fetching token: {}", resp.status);
113            // TODO: the body should be JSON and contain a status parameter we might need?
114            trace!("  Response body {}", resp.text());
115            // XXX - shouldn't we "chain" these errors - ie, a BackoffError could
116            // have a TokenserverHttpError as its cause?
117            if let Some(res) = resp.headers.get_as::<f64, _>(header_names::RETRY_AFTER) {
118                let ms = res
119                    .ok()
120                    .map_or(RETRY_AFTER_DEFAULT_MS, |f| (f * 1000f64) as u64);
121                let when = self.now() + Duration::from_millis(ms);
122                return Err(ErrorKind::BackoffError(when));
123            }
124            let status = resp.status;
125            return Err(ErrorKind::TokenserverHttpError(status));
126        }
127
128        let token: TokenserverToken = resp.json()?;
129        let server_timestamp = resp
130            .headers
131            .try_get::<ServerTimestamp, _>(header_names::X_TIMESTAMP)
132            .ok_or(ErrorKind::MissingServerTimestamp)?;
133        Ok(TokenFetchResult {
134            token,
135            server_timestamp,
136        })
137    }
138
139    fn now(&self) -> SystemTime {
140        SystemTime::now()
141    }
142}
143
144// The context stored by our TokenProvider when it has a TokenState::Token
145// state.
146struct TokenContext {
147    token: TokenserverToken,
148    credentials: hawk::Credentials,
149    server_timestamp: ServerTimestamp,
150    valid_until: SystemTime,
151}
152
153// hawk::Credentials doesn't implement debug -_-
154impl fmt::Debug for TokenContext {
155    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> ::std::result::Result<(), fmt::Error> {
156        f.debug_struct("TokenContext")
157            .field("token", &self.token)
158            .field("credentials", &"(omitted)")
159            .field("server_timestamp", &self.server_timestamp)
160            .field("valid_until", &self.valid_until)
161            .finish()
162    }
163}
164
165impl TokenContext {
166    fn new(
167        token: TokenserverToken,
168        credentials: hawk::Credentials,
169        server_timestamp: ServerTimestamp,
170        valid_until: SystemTime,
171    ) -> Self {
172        Self {
173            token,
174            credentials,
175            server_timestamp,
176            valid_until,
177        }
178    }
179
180    fn is_valid(&self, now: SystemTime) -> bool {
181        // We could consider making the duration a little shorter - if it
182        // only has 1 second validity there seems a reasonable chance it will
183        // have expired by the time it gets presented to the remote that wants
184        // it.
185        // Either way though, we will eventually need to handle a token being
186        // rejected as a non-fatal error and recover, so maybe we don't care?
187        now < self.valid_until
188    }
189
190    fn authorization(&self, req: &Request) -> Result<String> {
191        let url = &req.url;
192
193        let path_and_query = match url.query() {
194            None => Cow::from(url.path()),
195            Some(qs) => Cow::from(format!("{}?{}", url.path(), qs)),
196        };
197
198        let host = url
199            .host_str()
200            .ok_or_else(|| ErrorKind::UnacceptableUrl("Storage URL has no host".into()))?;
201
202        // Known defaults exist for https? (among others), so this should be impossible
203        let port = url.port_or_known_default().ok_or_else(|| {
204            ErrorKind::UnacceptableUrl(
205                "Storage URL has no port and no default port is known for the protocol".into(),
206            )
207        })?;
208
209        let header =
210            hawk::RequestBuilder::new(req.method.as_str(), host, port, path_and_query.borrow())
211                .request()
212                .make_header(&self.credentials)?;
213
214        Ok(format!("Hawk {}", header))
215    }
216}
217
218// The state our TokenProvider holds to reflect the state of the token.
219#[derive(Debug)]
220enum TokenState {
221    // We've never fetched a token.
222    NoToken,
223    // Have a token and last we checked it remained valid.
224    Token(TokenContext),
225    // We failed to fetch a token. First elt is the error, second elt is
226    // the api_endpoint we had before we failed to fetch a new token (or
227    // None if the very first attempt at fetching a token failed)
228    Failed(Option<error::Error>, Option<String>),
229    // Previously failed and told to back-off for SystemTime duration. Second
230    // elt is the api_endpoint we had before we hit the backoff error.
231    // XXX - should we roll Backoff and Failed together?
232    Backoff(SystemTime, Option<String>),
233    // api_endpoint changed - we are never going to get a token nor move out
234    // of this state.
235    NodeReassigned,
236}
237
238/// The generic TokenProvider implementation - long lived and fetches tokens
239/// on demand (eg, when first needed, or when an existing one expires.)
240#[derive(Debug)]
241struct TokenProviderImpl<TF: TokenFetcher> {
242    fetcher: TF,
243    // Our token state (ie, whether we have a token, and if not, why not)
244    current_state: RefCell<TokenState>,
245}
246
247impl<TF: TokenFetcher> TokenProviderImpl<TF> {
248    fn new(fetcher: TF) -> Self {
249        // We check this at the real entrypoint of the application, but tests
250        // can/do bypass that, so check this here too.
251        rc_crypto::ensure_initialized();
252        TokenProviderImpl {
253            fetcher,
254            current_state: RefCell::new(TokenState::NoToken),
255        }
256    }
257
258    // Uses our fetcher to grab a new token and if successful, derives other
259    // info from that token into a usable TokenContext.
260    fn fetch_context(&self) -> Result<TokenContext> {
261        let result = self.fetcher.fetch_token()?;
262        let token = result.token;
263        let valid_until = SystemTime::now() + Duration::from_secs(token.duration);
264
265        let credentials = hawk::Credentials {
266            id: token.id.clone(),
267            key: hawk::Key::new(token.key.as_bytes(), hawk::SHA256)?,
268        };
269
270        Ok(TokenContext::new(
271            token,
272            credentials,
273            result.server_timestamp,
274            valid_until,
275        ))
276    }
277
278    // Attempt to fetch a new token and return a new state reflecting that
279    // operation. If it worked a TokenState will be returned, but errors may
280    // cause other states.
281    fn fetch_token(&self, previous_endpoint: Option<&str>) -> TokenState {
282        match self.fetch_context() {
283            Ok(tc) => {
284                // We got a new token - check that the endpoint is the same
285                // as a previous endpoint we saw (if any)
286                match previous_endpoint {
287                    Some(prev) => {
288                        if prev == tc.token.api_endpoint {
289                            TokenState::Token(tc)
290                        } else {
291                            warn!(
292                                "api_endpoint changed from {} to {}",
293                                prev, tc.token.api_endpoint
294                            );
295                            TokenState::NodeReassigned
296                        }
297                    }
298                    None => {
299                        // Never had an api_endpoint in the past, so this is OK.
300                        TokenState::Token(tc)
301                    }
302                }
303            }
304            Err(e) => {
305                // Early to avoid nll issues...
306                if let ErrorKind::BackoffError(be) = e {
307                    return TokenState::Backoff(be, previous_endpoint.map(ToString::to_string));
308                }
309                TokenState::Failed(Some(e), previous_endpoint.map(ToString::to_string))
310            }
311        }
312    }
313
314    // Given the state we are currently in, return a new current state.
315    // Returns None if the current state should be used (eg, if we are
316    // holding a token that remains valid) or Some() if the state has changed
317    // (which may have changed to a state with a token or an error state)
318    fn advance_state(&self, state: &TokenState) -> Option<TokenState> {
319        match state {
320            TokenState::NoToken => Some(self.fetch_token(None)),
321            TokenState::Failed(_, existing_endpoint) => {
322                Some(self.fetch_token(existing_endpoint.as_ref().map(String::as_str)))
323            }
324            TokenState::Token(existing_context) => {
325                if existing_context.is_valid(self.fetcher.now()) {
326                    None
327                } else {
328                    Some(self.fetch_token(Some(existing_context.token.api_endpoint.as_str())))
329                }
330            }
331            TokenState::Backoff(ref until, ref existing_endpoint) => {
332                if let Ok(remaining) = until.duration_since(self.fetcher.now()) {
333                    debug!("enforcing existing backoff - {:?} remains", remaining);
334                    None
335                } else {
336                    // backoff period is over
337                    Some(self.fetch_token(existing_endpoint.as_ref().map(String::as_str)))
338                }
339            }
340            TokenState::NodeReassigned => {
341                // We never leave this state.
342                None
343            }
344        }
345    }
346
347    fn with_token<T, F>(&self, func: F) -> Result<T>
348    where
349        F: FnOnce(&TokenContext) -> Result<T>,
350    {
351        // first get a mutable ref to our existing state, advance to the
352        // state we will use, then re-stash that state for next time.
353        let state: &mut TokenState = &mut self.current_state.borrow_mut();
354        if let Some(new_state) = self.advance_state(state) {
355            *state = new_state;
356        }
357
358        // Now re-fetch the state we should use for this call - if it's
359        // anything other than TokenState::Token we will fail.
360        match state {
361            TokenState::NoToken => {
362                // it should be impossible to get here.
363                panic!("Can't be in NoToken state after advancing");
364            }
365            TokenState::Token(ref token_context) => {
366                // make the call.
367                func(token_context)
368            }
369            TokenState::Failed(e, _) => {
370                // We swap the error out of the state enum and return it.
371                Err(e.take().unwrap())
372            }
373            TokenState::NodeReassigned => {
374                // this is unrecoverable.
375                Err(ErrorKind::StorageResetError)
376            }
377            TokenState::Backoff(ref remaining, _) => Err(ErrorKind::BackoffError(*remaining)),
378        }
379    }
380
381    fn hashed_uid(&self) -> Result<String> {
382        self.with_token(|ctx| Ok(ctx.token.hashed_fxa_uid.clone()))
383    }
384
385    fn authorization(&self, req: &Request) -> Result<String> {
386        self.with_token(|ctx| ctx.authorization(req))
387    }
388
389    fn api_endpoint(&self) -> Result<String> {
390        self.with_token(|ctx| Ok(ctx.token.api_endpoint.clone()))
391    }
392    // TODO: we probably want a "drop_token/context" type method so that when
393    // using a token with some validity fails the caller can force a new one
394    // (in which case the new token request will probably fail with a 401)
395}
396
397// The public concrete object exposed by this module
398#[derive(Debug)]
399pub struct TokenProvider {
400    imp: TokenProviderImpl<TokenServerFetcher>,
401}
402
403impl TokenProvider {
404    pub fn new(url: Url, access_token: String, key_id: String) -> Self {
405        let fetcher = TokenServerFetcher::new(url, access_token, key_id);
406        Self {
407            imp: TokenProviderImpl::new(fetcher),
408        }
409    }
410
411    pub fn hashed_uid(&self) -> Result<String> {
412        self.imp.hashed_uid()
413    }
414
415    pub fn authorization(&self, req: &Request) -> Result<String> {
416        self.imp.authorization(req)
417    }
418
419    pub fn api_endpoint(&self) -> Result<String> {
420        self.imp.api_endpoint()
421    }
422}
423
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    // A `TokenFetcher` whose fetch result and clock are both supplied by the
    // test as closures.
    struct TestFetcher<FF, FN>
    where
        FF: Fn() -> Result<TokenFetchResult>,
        FN: Fn() -> SystemTime,
    {
        fetch: FF,
        now: FN,
    }
    impl<FF, FN> TokenFetcher for TestFetcher<FF, FN>
    where
        FF: Fn() -> Result<TokenFetchResult>,
        FN: Fn() -> SystemTime,
    {
        fn fetch_token(&self) -> Result<TokenFetchResult> {
            (self.fetch)()
        }
        fn now(&self) -> SystemTime {
            (self.now)()
        }
    }

    fn make_tsc<FF, FN>(fetch: FF, now: FN) -> Result<TokenProviderImpl<TestFetcher<FF, FN>>>
    where
        FF: Fn() -> Result<TokenFetchResult>,
        FN: Fn() -> SystemTime,
    {
        let fetcher: TestFetcher<FF, FN> = TestFetcher { fetch, now };
        Ok(TokenProviderImpl::new(fetcher))
    }

    // The canned fetch result shared by the tests; only the validity
    // `duration` (in seconds) varies between them.
    fn sample_fetch_result(duration: u64) -> TokenFetchResult {
        TokenFetchResult {
            token: TokenserverToken {
                id: "id".to_string(),
                key: "key".to_string(),
                api_endpoint: "api_endpoint".to_string(),
                uid: 1,
                duration,
                hashed_fxa_uid: "hash".to_string(),
            },
            server_timestamp: ServerTimestamp(0i64),
        }
    }

    // Fails the test unless `err` is a `BackoffError`.
    fn assert_is_backoff(err: ErrorKind) {
        match err {
            ErrorKind::BackoffError(_) => {}
            other => panic!("expected a BackoffError, got {:?}", other),
        }
    }

    #[test]
    fn test_endpoint() {
        nss::ensure_initialized();
        // Use a cell to avoid the closure having a mutable ref to this scope.
        let counter: Cell<u32> = Cell::new(0);
        let fetch = || {
            counter.set(counter.get() + 1);
            Ok(sample_fetch_result(1000))
        };

        let tsc = make_tsc(fetch, SystemTime::now).unwrap();

        let e = tsc.api_endpoint().expect("should work");
        assert_eq!(e, "api_endpoint".to_string());
        assert_eq!(counter.get(), 1);

        let e2 = tsc.api_endpoint().expect("should work");
        assert_eq!(e2, "api_endpoint".to_string());
        // should not have re-fetched.
        assert_eq!(counter.get(), 1);
    }

    #[test]
    fn test_backoff() {
        nss::ensure_initialized();
        let counter: Cell<u32> = Cell::new(0);
        let fetch = || {
            counter.set(counter.get() + 1);
            let when = SystemTime::now() + Duration::from_millis(10000);
            Err(ErrorKind::BackoffError(when))
        };
        let now: Cell<SystemTime> = Cell::new(SystemTime::now());
        let tsc = make_tsc(fetch, || now.get()).unwrap();

        assert_is_backoff(tsc.api_endpoint().expect_err("should bail"));
        assert_eq!(counter.get(), 1);
        // try and get another token - should not re-fetch as backoff is still
        // in progress.
        assert_is_backoff(tsc.api_endpoint().expect_err("should bail"));
        assert_eq!(counter.get(), 1);

        // Advance the clock.
        now.set(now.get() + Duration::new(20, 0));

        // Our token fetch mock is still returning a backoff error, so we
        // still fail, but should have re-hit the fetch function.
        assert_is_backoff(tsc.api_endpoint().expect_err("should bail"));
        assert_eq!(counter.get(), 2);
    }

    #[test]
    fn test_validity() {
        nss::ensure_initialized();
        let counter: Cell<u32> = Cell::new(0);
        let fetch = || {
            counter.set(counter.get() + 1);
            Ok(sample_fetch_result(10))
        };
        let now: Cell<SystemTime> = Cell::new(SystemTime::now());
        let tsc = make_tsc(fetch, || now.get()).unwrap();

        tsc.api_endpoint().expect("should get a valid token");
        assert_eq!(counter.get(), 1);

        // try and get another token - should not re-fetch as the old one
        // remains valid.
        tsc.api_endpoint().expect("should reuse existing token");
        assert_eq!(counter.get(), 1);

        // Advance the clock.
        now.set(now.get() + Duration::new(20, 0));

        // We should discard our token and fetch a new one.
        tsc.api_endpoint().expect("should re-fetch");
        assert_eq!(counter.get(), 2);
    }

    #[test]
    fn test_server_url() {
        // (input, expected result of `fixup_server_url`)
        let cases: [(&str, &str); 8] = [
            (
                "https://token.services.mozilla.com/1.0/sync/1.5",
                "https://token.services.mozilla.com/1.0/sync/1.5",
            ),
            (
                "https://token.services.mozilla.com/1.0/sync/1.5/",
                "https://token.services.mozilla.com/1.0/sync/1.5",
            ),
            (
                "https://token.services.mozilla.com",
                "https://token.services.mozilla.com/1.0/sync/1.5",
            ),
            (
                "https://token.services.mozilla.com/",
                "https://token.services.mozilla.com/1.0/sync/1.5",
            ),
            (
                "https://selfhosted.example.com/token/1.0/sync/1.5",
                "https://selfhosted.example.com/token/1.0/sync/1.5",
            ),
            (
                "https://selfhosted.example.com/token/1.0/sync/1.5/",
                "https://selfhosted.example.com/token/1.0/sync/1.5",
            ),
            (
                "https://selfhosted.example.com/token/",
                "https://selfhosted.example.com/token/1.0/sync/1.5",
            ),
            (
                "https://selfhosted.example.com/token",
                "https://selfhosted.example.com/token/1.0/sync/1.5",
            ),
        ];

        for (input, expected) in cases.iter() {
            let fixed = fixup_server_url(Url::parse(input).unwrap());
            assert_eq!(fixed.as_str(), *expected, "input was {}", input);
        }
    }
}