1use crate::error::{self, debug, trace, warn, Error as ErrorKind, Result};
6use crate::ServerTimestamp;
7use rc_crypto::hawk;
8use serde_derive::*;
9use std::borrow::{Borrow, Cow};
10use std::cell::RefCell;
11use std::fmt;
12use std::time::{Duration, SystemTime};
13use url::Url;
14use viaduct::{header_names, Request};
15
/// Back-off, in milliseconds, applied when a failed token request carries a
/// `Retry-After` header whose value is unusable as a number of seconds.
const RETRY_AFTER_DEFAULT_MS: u64 = 10_000;
17
18#[derive(Deserialize, Clone, PartialEq, Eq)]
21struct TokenserverToken {
22 id: String,
23 key: String,
24 api_endpoint: String,
25 uid: u64,
26 duration: u64,
27 hashed_fxa_uid: String,
28}
29
30impl std::fmt::Debug for TokenserverToken {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("TokenserverToken")
33 .field("api_endpoint", &self.api_endpoint)
34 .field("uid", &self.uid)
35 .field("duration", &self.duration)
36 .field("hashed_fxa_uid", &self.hashed_fxa_uid)
37 .finish()
38 }
39}
40
41struct TokenFetchResult {
44 token: TokenserverToken,
45 server_timestamp: ServerTimestamp,
46}
47
48trait TokenFetcher {
51 fn fetch_token(&self) -> crate::Result<TokenFetchResult>;
52 fn now(&self) -> SystemTime;
54}
55
56#[derive(Debug)]
59struct TokenServerFetcher {
60 server_url: Url,
62 access_token: String,
63 key_id: String,
64}
65
66fn fixup_server_url(mut url: Url) -> url::Url {
67 if url.as_str().ends_with("1.0/sync/1.5") {
71 } else if url.as_str().ends_with("1.0/sync/1.5/") {
73 if let Ok(mut path) = url.path_segments_mut() {
76 path.pop();
77 }
78 } else {
79 if let Ok(mut path) = url.path_segments_mut() {
83 path.pop_if_empty();
84 path.extend(&["1.0", "sync", "1.5"]);
85 }
86 };
87 url
88}
89
90impl TokenServerFetcher {
91 fn new(base_url: Url, access_token: String, key_id: String) -> TokenServerFetcher {
92 TokenServerFetcher {
93 server_url: fixup_server_url(base_url),
94 access_token,
95 key_id,
96 }
97 }
98}
99
100impl TokenFetcher for TokenServerFetcher {
101 fn fetch_token(&self) -> Result<TokenFetchResult> {
102 debug!("Fetching token from {}", self.server_url);
103 let resp = Request::get(self.server_url.clone())
104 .header(
105 header_names::AUTHORIZATION,
106 format!("Bearer {}", self.access_token),
107 )?
108 .header(header_names::X_KEYID, self.key_id.clone())?
109 .send()?;
110
111 if !resp.is_success() {
112 warn!("Non-success status when fetching token: {}", resp.status);
113 trace!(" Response body {}", resp.text());
115 if let Some(res) = resp.headers.get_as::<f64, _>(header_names::RETRY_AFTER) {
118 let ms = res
119 .ok()
120 .map_or(RETRY_AFTER_DEFAULT_MS, |f| (f * 1000f64) as u64);
121 let when = self.now() + Duration::from_millis(ms);
122 return Err(ErrorKind::BackoffError(when));
123 }
124 let status = resp.status;
125 return Err(ErrorKind::TokenserverHttpError(status));
126 }
127
128 let token: TokenserverToken = resp.json()?;
129 let server_timestamp = resp
130 .headers
131 .try_get::<ServerTimestamp, _>(header_names::X_TIMESTAMP)
132 .ok_or(ErrorKind::MissingServerTimestamp)?;
133 Ok(TokenFetchResult {
134 token,
135 server_timestamp,
136 })
137 }
138
139 fn now(&self) -> SystemTime {
140 SystemTime::now()
141 }
142}
143
144struct TokenContext {
147 token: TokenserverToken,
148 credentials: hawk::Credentials,
149 server_timestamp: ServerTimestamp,
150 valid_until: SystemTime,
151}
152
153impl fmt::Debug for TokenContext {
155 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> ::std::result::Result<(), fmt::Error> {
156 f.debug_struct("TokenContext")
157 .field("token", &self.token)
158 .field("credentials", &"(omitted)")
159 .field("server_timestamp", &self.server_timestamp)
160 .field("valid_until", &self.valid_until)
161 .finish()
162 }
163}
164
165impl TokenContext {
166 fn new(
167 token: TokenserverToken,
168 credentials: hawk::Credentials,
169 server_timestamp: ServerTimestamp,
170 valid_until: SystemTime,
171 ) -> Self {
172 Self {
173 token,
174 credentials,
175 server_timestamp,
176 valid_until,
177 }
178 }
179
180 fn is_valid(&self, now: SystemTime) -> bool {
181 now < self.valid_until
188 }
189
190 fn authorization(&self, req: &Request) -> Result<String> {
191 let url = &req.url;
192
193 let path_and_query = match url.query() {
194 None => Cow::from(url.path()),
195 Some(qs) => Cow::from(format!("{}?{}", url.path(), qs)),
196 };
197
198 let host = url
199 .host_str()
200 .ok_or_else(|| ErrorKind::UnacceptableUrl("Storage URL has no host".into()))?;
201
202 let port = url.port_or_known_default().ok_or_else(|| {
204 ErrorKind::UnacceptableUrl(
205 "Storage URL has no port and no default port is known for the protocol".into(),
206 )
207 })?;
208
209 let header =
210 hawk::RequestBuilder::new(req.method.as_str(), host, port, path_and_query.borrow())
211 .request()
212 .make_header(&self.credentials)?;
213
214 Ok(format!("Hawk {}", header))
215 }
216}
217
218#[derive(Debug)]
220enum TokenState {
221 NoToken,
223 Token(TokenContext),
225 Failed(Option<error::Error>, Option<String>),
229 Backoff(SystemTime, Option<String>),
233 NodeReassigned,
236}
237
238#[derive(Debug)]
241struct TokenProviderImpl<TF: TokenFetcher> {
242 fetcher: TF,
243 current_state: RefCell<TokenState>,
245}
246
247impl<TF: TokenFetcher> TokenProviderImpl<TF> {
248 fn new(fetcher: TF) -> Self {
249 rc_crypto::ensure_initialized();
252 TokenProviderImpl {
253 fetcher,
254 current_state: RefCell::new(TokenState::NoToken),
255 }
256 }
257
258 fn fetch_context(&self) -> Result<TokenContext> {
261 let result = self.fetcher.fetch_token()?;
262 let token = result.token;
263 let valid_until = SystemTime::now() + Duration::from_secs(token.duration);
264
265 let credentials = hawk::Credentials {
266 id: token.id.clone(),
267 key: hawk::Key::new(token.key.as_bytes(), hawk::SHA256)?,
268 };
269
270 Ok(TokenContext::new(
271 token,
272 credentials,
273 result.server_timestamp,
274 valid_until,
275 ))
276 }
277
278 fn fetch_token(&self, previous_endpoint: Option<&str>) -> TokenState {
282 match self.fetch_context() {
283 Ok(tc) => {
284 match previous_endpoint {
287 Some(prev) => {
288 if prev == tc.token.api_endpoint {
289 TokenState::Token(tc)
290 } else {
291 warn!(
292 "api_endpoint changed from {} to {}",
293 prev, tc.token.api_endpoint
294 );
295 TokenState::NodeReassigned
296 }
297 }
298 None => {
299 TokenState::Token(tc)
301 }
302 }
303 }
304 Err(e) => {
305 if let ErrorKind::BackoffError(be) = e {
307 return TokenState::Backoff(be, previous_endpoint.map(ToString::to_string));
308 }
309 TokenState::Failed(Some(e), previous_endpoint.map(ToString::to_string))
310 }
311 }
312 }
313
314 fn advance_state(&self, state: &TokenState) -> Option<TokenState> {
319 match state {
320 TokenState::NoToken => Some(self.fetch_token(None)),
321 TokenState::Failed(_, existing_endpoint) => {
322 Some(self.fetch_token(existing_endpoint.as_ref().map(String::as_str)))
323 }
324 TokenState::Token(existing_context) => {
325 if existing_context.is_valid(self.fetcher.now()) {
326 None
327 } else {
328 Some(self.fetch_token(Some(existing_context.token.api_endpoint.as_str())))
329 }
330 }
331 TokenState::Backoff(ref until, ref existing_endpoint) => {
332 if let Ok(remaining) = until.duration_since(self.fetcher.now()) {
333 debug!("enforcing existing backoff - {:?} remains", remaining);
334 None
335 } else {
336 Some(self.fetch_token(existing_endpoint.as_ref().map(String::as_str)))
338 }
339 }
340 TokenState::NodeReassigned => {
341 None
343 }
344 }
345 }
346
347 fn with_token<T, F>(&self, func: F) -> Result<T>
348 where
349 F: FnOnce(&TokenContext) -> Result<T>,
350 {
351 let state: &mut TokenState = &mut self.current_state.borrow_mut();
354 if let Some(new_state) = self.advance_state(state) {
355 *state = new_state;
356 }
357
358 match state {
361 TokenState::NoToken => {
362 panic!("Can't be in NoToken state after advancing");
364 }
365 TokenState::Token(ref token_context) => {
366 func(token_context)
368 }
369 TokenState::Failed(e, _) => {
370 Err(e.take().unwrap())
372 }
373 TokenState::NodeReassigned => {
374 Err(ErrorKind::StorageResetError)
376 }
377 TokenState::Backoff(ref remaining, _) => Err(ErrorKind::BackoffError(*remaining)),
378 }
379 }
380
381 fn hashed_uid(&self) -> Result<String> {
382 self.with_token(|ctx| Ok(ctx.token.hashed_fxa_uid.clone()))
383 }
384
385 fn authorization(&self, req: &Request) -> Result<String> {
386 self.with_token(|ctx| ctx.authorization(req))
387 }
388
389 fn api_endpoint(&self) -> Result<String> {
390 self.with_token(|ctx| Ok(ctx.token.api_endpoint.clone()))
391 }
392 }
396
397#[derive(Debug)]
399pub struct TokenProvider {
400 imp: TokenProviderImpl<TokenServerFetcher>,
401}
402
403impl TokenProvider {
404 pub fn new(url: Url, access_token: String, key_id: String) -> Self {
405 let fetcher = TokenServerFetcher::new(url, access_token, key_id);
406 Self {
407 imp: TokenProviderImpl::new(fetcher),
408 }
409 }
410
411 pub fn hashed_uid(&self) -> Result<String> {
412 self.imp.hashed_uid()
413 }
414
415 pub fn authorization(&self, req: &Request) -> Result<String> {
416 self.imp.authorization(req)
417 }
418
419 pub fn api_endpoint(&self) -> Result<String> {
420 self.imp.api_endpoint()
421 }
422}
423
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    /// A `TokenFetcher` whose token source and clock are plain closures.
    struct TestFetcher<FF, FN>
    where
        FF: Fn() -> Result<TokenFetchResult>,
        FN: Fn() -> SystemTime,
    {
        fetch: FF,
        now: FN,
    }

    impl<FF, FN> TokenFetcher for TestFetcher<FF, FN>
    where
        FF: Fn() -> Result<TokenFetchResult>,
        FN: Fn() -> SystemTime,
    {
        fn fetch_token(&self) -> Result<TokenFetchResult> {
            (self.fetch)()
        }

        fn now(&self) -> SystemTime {
            (self.now)()
        }
    }

    /// Builds a provider driven by the given closures.
    fn make_tsc<FF, FN>(fetch: FF, now: FN) -> Result<TokenProviderImpl<TestFetcher<FF, FN>>>
    where
        FF: Fn() -> Result<TokenFetchResult>,
        FN: Fn() -> SystemTime,
    {
        Ok(TokenProviderImpl::new(TestFetcher { fetch, now }))
    }

    /// A canned fetch result whose token is valid for `duration` seconds.
    fn canned_result(duration: u64) -> TokenFetchResult {
        TokenFetchResult {
            token: TokenserverToken {
                id: "id".to_string(),
                key: "key".to_string(),
                api_endpoint: "api_endpoint".to_string(),
                uid: 1,
                duration,
                hashed_fxa_uid: "hash".to_string(),
            },
            server_timestamp: ServerTimestamp(0i64),
        }
    }

    #[test]
    fn test_endpoint() {
        nss::ensure_initialized();
        let counter: Cell<u32> = Cell::new(0);
        let fetch = || {
            counter.set(counter.get() + 1);
            Ok(canned_result(1000))
        };

        let tsc = make_tsc(fetch, SystemTime::now).unwrap();

        let e = tsc.api_endpoint().expect("should work");
        assert_eq!(e, "api_endpoint".to_string());
        assert_eq!(counter.get(), 1);

        // The second request reuses the token: no further fetch.
        let e2 = tsc.api_endpoint().expect("should work");
        assert_eq!(e2, "api_endpoint".to_string());
        assert_eq!(counter.get(), 1);
    }

    #[test]
    fn test_backoff() {
        nss::ensure_initialized();
        let counter: Cell<u32> = Cell::new(0);
        let fetch = || {
            counter.set(counter.get() + 1);
            let when = SystemTime::now() + Duration::from_millis(10000);
            Err(ErrorKind::BackoffError(when))
        };
        let now: Cell<SystemTime> = Cell::new(SystemTime::now());
        let tsc = make_tsc(fetch, || now.get()).unwrap();

        tsc.api_endpoint().expect_err("should bail");
        assert_eq!(counter.get(), 1);

        // Still inside the back-off window: no new fetch.
        tsc.api_endpoint().expect_err("should bail");
        assert_eq!(counter.get(), 1);

        // Move the clock past the back-off.
        now.set(now.get() + Duration::new(20, 0));

        tsc.api_endpoint().expect_err("should bail");
        assert_eq!(counter.get(), 2);
    }

    #[test]
    fn test_validity() {
        nss::ensure_initialized();
        let counter: Cell<u32> = Cell::new(0);
        let fetch = || {
            counter.set(counter.get() + 1);
            Ok(canned_result(10))
        };
        let now: Cell<SystemTime> = Cell::new(SystemTime::now());
        let tsc = make_tsc(fetch, || now.get()).unwrap();

        tsc.api_endpoint().expect("should get a valid token");
        assert_eq!(counter.get(), 1);

        tsc.api_endpoint().expect("should reuse existing token");
        assert_eq!(counter.get(), 1);

        // Move the clock past the token's lifetime.
        now.set(now.get() + Duration::new(20, 0));

        tsc.api_endpoint().expect("should re-fetch");
        assert_eq!(counter.get(), 2);
    }

    #[test]
    fn test_server_url() {
        // (input, expected) pairs.
        let cases = [
            (
                "https://token.services.mozilla.com/1.0/sync/1.5",
                "https://token.services.mozilla.com/1.0/sync/1.5",
            ),
            (
                "https://token.services.mozilla.com/1.0/sync/1.5/",
                "https://token.services.mozilla.com/1.0/sync/1.5",
            ),
            (
                "https://token.services.mozilla.com",
                "https://token.services.mozilla.com/1.0/sync/1.5",
            ),
            (
                "https://token.services.mozilla.com/",
                "https://token.services.mozilla.com/1.0/sync/1.5",
            ),
            (
                "https://selfhosted.example.com/token/1.0/sync/1.5",
                "https://selfhosted.example.com/token/1.0/sync/1.5",
            ),
            (
                "https://selfhosted.example.com/token/1.0/sync/1.5/",
                "https://selfhosted.example.com/token/1.0/sync/1.5",
            ),
            (
                "https://selfhosted.example.com/token/",
                "https://selfhosted.example.com/token/1.0/sync/1.5",
            ),
            (
                "https://selfhosted.example.com/token",
                "https://selfhosted.example.com/token/1.0/sync/1.5",
            ),
        ];

        for &(input, expected) in cases.iter() {
            assert_eq!(
                fixup_server_url(Url::parse(input).unwrap()).as_str(),
                expected
            );
        }
    }
}