1use serde::{Deserialize, Serialize};
20use url::Url;
21use viaduct::{header_names, status_codes, Headers, Request};
22
23use crate::error::{
24 self, info,
25 PushError::{
26 AlreadyRegisteredError, CommunicationError, CommunicationServerError,
27 UAIDNotRecognizedError,
28 },
29};
30use crate::internal::config::PushConfiguration;
31use crate::internal::storage::Store;
32
33mod rate_limiter;
34pub use rate_limiter::PersistedRateLimiter;
35
/// Server `errno` value which, when it accompanies an HTTP 410 (Gone) status,
/// makes `ConnectHttp::check_response_error` report `UAIDNotRecognizedError`.
const UAID_NOT_FOUND_ERRNO: u32 = 103;
37#[derive(Deserialize, Debug)]
38pub struct RegisterResponse {
40 pub uaid: String,
42
43 #[serde(rename = "channelID")]
48 pub channel_id: String,
49
50 pub secret: String,
52
53 pub endpoint: String,
55
56 #[allow(dead_code)]
58 #[serde(rename = "senderid")]
59 pub sender_id: Option<String>,
60}
61
62#[derive(Deserialize, Debug)]
63pub struct SubscribeResponse {
65 #[serde(rename = "channelID")]
71 pub channel_id: String,
72
73 pub endpoint: String,
75
76 #[allow(dead_code)]
78 #[serde(rename = "senderid")]
79 pub sender_id: Option<String>,
80}
81
82#[derive(Serialize)]
83struct RegisterRequest<'a> {
85 token: &'a str,
87
88 key: Option<&'a str>,
90}
91
92#[derive(Serialize)]
93struct UpdateRequest<'a> {
94 token: &'a str,
95}
96
/// Operations against the push server's registration API.
///
/// `ConnectHttp` is the HTTP implementation; under `cfg(test)` mockall also
/// generates a `MockConnection` from this trait.
#[cfg_attr(test, mockall::automock)]
pub trait Connection: Sized {
    /// Creates a connection configured by `options`.
    fn connect(options: PushConfiguration) -> Self;

    /// Creates a new registration for the native token `registration_id`,
    /// optionally tied to `app_server_key`.
    ///
    /// Returns the server-assigned UAID, secret, first channel id and endpoint.
    fn register(
        &self,
        registration_id: &str,
        app_server_key: &Option<String>,
    ) -> error::Result<RegisterResponse>;

    /// Adds a new channel to the existing registration `uaid`, authenticated
    /// with `auth`.
    ///
    /// Returns the new channel id and its endpoint.
    fn subscribe(
        &self,
        uaid: &str,
        auth: &str,
        registration_id: &str,
        app_server_key: &Option<String>,
    ) -> error::Result<SubscribeResponse>;

    /// Removes the single channel `channel_id` from registration `uaid`.
    fn unsubscribe(&self, channel_id: &str, uaid: &str, auth: &str) -> error::Result<()>;

    /// Removes the registration `uaid` as a whole (all of its channels).
    fn unsubscribe_all(&self, uaid: &str, auth: &str) -> error::Result<()>;

    /// Replaces the native token stored for registration `uaid` with `new_token`.
    fn update(&self, new_token: &str, uaid: &str, auth: &str) -> error::Result<()>;

    /// Returns the channel ids the server holds for registration `uaid`.
    fn channel_list(&self, uaid: &str, auth: &str) -> error::Result<Vec<String>>;
}
166
/// [`Connection`] implementation that talks to the push server over HTTP using `viaduct`.
pub struct ConnectHttp {
    // Protocol, host, bridge type and sender id used to build every request URL.
    options: PushConfiguration,
}
171
172impl ConnectHttp {
173 fn auth_headers(&self, auth: &str) -> error::Result<Headers> {
174 let mut headers = Headers::new();
175 headers
176 .insert(header_names::AUTHORIZATION, &*format!("webpush {}", auth))
177 .map_err(|e| error::PushError::CommunicationError(format!("Header error: {:?}", e)))?;
178
179 Ok(headers)
180 }
181
182 fn check_response_error(&self, response: &viaduct::Response) -> error::Result<()> {
183 #[derive(Deserialize)]
186 struct ResponseError {
187 pub errno: Option<u32>,
188 pub message: String,
189 }
190 if response.is_server_error() {
191 let response_error = response.json::<ResponseError>()?;
192 return Err(CommunicationServerError(format!(
193 "General Server Error: {}",
194 response_error.message
195 )));
196 }
197 if response.is_client_error() {
198 let response_error = response.json::<ResponseError>()?;
199 if response.status == status_codes::CONFLICT {
200 return Err(AlreadyRegisteredError);
201 }
202 if response.status == status_codes::GONE
203 && matches!(response_error.errno, Some(UAID_NOT_FOUND_ERRNO))
204 {
205 return Err(UAIDNotRecognizedError(response_error.message));
206 }
207 return Err(CommunicationError(format!(
208 "Unhandled client error {:?}",
209 response
210 )));
211 }
212 Ok(())
213 }
214
215 fn format_unsubscribe_url(&self, uaid: &str) -> error::Result<String> {
216 Ok(format!(
217 "{}://{}/v1/{}/{}/registration/{}",
218 &self.options.http_protocol,
219 &self.options.server_host,
220 &self.options.bridge_type,
221 &self.options.sender_id,
222 &uaid,
223 ))
224 }
225
226 fn send_subscription_request<T>(
227 &self,
228 url: Url,
229 headers: Headers,
230 registration_id: &str,
231 app_server_key: &Option<String>,
232 ) -> error::Result<T>
233 where
234 T: for<'a> Deserialize<'a>,
235 {
236 let body = RegisterRequest {
237 token: registration_id,
238 key: app_server_key.as_ref().map(|s| s.as_str()),
239 };
240
241 let response = Request::post(url).headers(headers).json(&body).send()?;
242 self.check_response_error(&response)?;
243 Ok(response.json()?)
244 }
245}
246
247impl Connection for ConnectHttp {
248 fn connect(options: PushConfiguration) -> ConnectHttp {
249 ConnectHttp { options }
250 }
251
252 fn register(
253 &self,
254 registration_id: &str,
255 app_server_key: &Option<String>,
256 ) -> error::Result<RegisterResponse> {
257 let url = format!(
258 "{}://{}/v1/{}/{}/registration",
259 &self.options.http_protocol,
260 &self.options.server_host,
261 &self.options.bridge_type,
262 &self.options.sender_id
263 );
264
265 let headers = Headers::new();
266
267 self.send_subscription_request(Url::parse(&url)?, headers, registration_id, app_server_key)
268 }
269
270 fn subscribe(
271 &self,
272 uaid: &str,
273 auth: &str,
274 registration_id: &str,
275 app_server_key: &Option<String>,
276 ) -> error::Result<SubscribeResponse> {
277 let url = format!(
278 "{}://{}/v1/{}/{}/registration/{}/subscription",
279 &self.options.http_protocol,
280 &self.options.server_host,
281 &self.options.bridge_type,
282 &self.options.sender_id,
283 uaid,
284 );
285
286 let headers = self.auth_headers(auth)?;
287
288 self.send_subscription_request(Url::parse(&url)?, headers, registration_id, app_server_key)
289 }
290
291 fn unsubscribe(&self, channel_id: &str, uaid: &str, auth: &str) -> error::Result<()> {
292 let url = format!(
293 "{}/subscription/{}",
294 self.format_unsubscribe_url(uaid)?,
295 channel_id
296 );
297 let response = Request::delete(Url::parse(&url)?)
298 .headers(self.auth_headers(auth)?)
299 .send()?;
300 info!("unsubscribed from {}: {}", url, response.status);
301 self.check_response_error(&response)?;
302 Ok(())
303 }
304
305 fn unsubscribe_all(&self, uaid: &str, auth: &str) -> error::Result<()> {
306 let url = self.format_unsubscribe_url(uaid)?;
307 let response = Request::delete(Url::parse(&url)?)
308 .headers(self.auth_headers(auth)?)
309 .send()?;
310 info!("unsubscribed from all via {}: {}", url, response.status);
311 self.check_response_error(&response)?;
312 Ok(())
313 }
314
315 fn update(&self, new_token: &str, uaid: &str, auth: &str) -> error::Result<()> {
316 let options = self.options.clone();
317 let url = format!(
318 "{}://{}/v1/{}/{}/registration/{}",
319 &options.http_protocol,
320 &options.server_host,
321 &options.bridge_type,
322 &options.sender_id,
323 uaid
324 );
325 let body = UpdateRequest { token: new_token };
326 let response = Request::put(Url::parse(&url)?)
327 .json(&body)
328 .headers(self.auth_headers(auth)?)
329 .send()?;
330 info!("update via {}: {}", url, response.status);
331 self.check_response_error(&response)?;
332 Ok(())
333 }
334
335 fn channel_list(&self, uaid: &str, auth: &str) -> error::Result<Vec<String>> {
336 #[derive(Deserialize, Debug)]
337 struct Payload {
338 uaid: String,
339 #[serde(rename = "channelIDs")]
340 channel_ids: Vec<String>,
341 }
342
343 let options = self.options.clone();
344
345 let url = format!(
346 "{}://{}/v1/{}/{}/registration/{}",
347 &options.http_protocol,
348 &options.server_host,
349 &options.bridge_type,
350 &options.sender_id,
351 &uaid,
352 );
353 let response = match Request::get(Url::parse(&url)?)
354 .headers(self.auth_headers(auth)?)
355 .send()
356 {
357 Ok(v) => v,
358 Err(e) => {
359 return Err(CommunicationServerError(format!(
360 "Could not fetch channel list: {}",
361 e
362 )));
363 }
364 };
365 self.check_response_error(&response)?;
366 let payload: Payload = response.json()?;
367 if payload.uaid != uaid {
368 return Err(CommunicationServerError(
369 "Invalid Response from server".to_string(),
370 ));
371 }
372 Ok(payload
373 .channel_ids
374 .iter()
375 .map(|s| Store::normalize_uuid(s))
376 .collect())
377 }
378}
379
#[cfg(test)]
mod test {
    use crate::internal::config::Protocol;

    use super::*;

    use mockito::{mock, server_address};
    use serde_json::json;

    const DUMMY_CHID: &str = "deadbeef00000000decafbad00000000";
    const DUMMY_CHID2: &str = "decafbad00000000deadbeef00000000";

    const DUMMY_UAID: &str = "abad1dea00000000aabbccdd00000000";
    // Used only by the "unknown UAID" scenario so its mock cannot collide with others.
    const UNKNOWN_UAID: &str = "0000000000000000deadbeefabad1dea";

    const SENDER_ID: &str = "FakeSenderID";
    const SECRET: &str = "SuP3rS1kRet";

    #[test]
    fn test_communications() {
        viaduct_reqwest::use_reqwest_backend();
        let config = PushConfiguration {
            http_protocol: Protocol::Http,
            server_host: server_address().to_string(),
            sender_id: SENDER_ID.to_owned(),
            ..Default::default()
        };
        let auth_header = format!("webpush {}", SECRET);
        // register, then subscribe a second channel on the same UAID
        {
            let body = json!({
                "uaid": DUMMY_UAID,
                "channelID": DUMMY_CHID,
                "endpoint": "https://example.com/update",
                "senderid": SENDER_ID,
                "secret": SECRET,
            })
            .to_string();
            let ap_mock = mock("POST", &*format!("/v1/fcm/{}/registration", SENDER_ID))
                .with_status(200)
                .with_header("content-type", "application/json")
                .with_body(body)
                .create();
            let conn = ConnectHttp::connect(config.clone());
            let response = conn.register(SENDER_ID, &None).unwrap();
            ap_mock.assert();
            assert_eq!(response.uaid, DUMMY_UAID);
            assert_eq!(response.channel_id, DUMMY_CHID);
            assert_eq!(response.endpoint, "https://example.com/update");

            let body_2 = json!({
                "uaid": DUMMY_UAID,
                "channelID": DUMMY_CHID2,
                "endpoint": "https://example.com/otherendpoint",
                "senderid": SENDER_ID,
                "secret": SECRET,
            })
            .to_string();
            let ap_mock_2 = mock(
                "POST",
                &*format!(
                    "/v1/fcm/{}/registration/{}/subscription",
                    SENDER_ID, DUMMY_UAID
                ),
            )
            .match_header("authorization", auth_header.as_str())
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(body_2)
            .create();

            let response = conn
                .subscribe(DUMMY_UAID, SECRET, SENDER_ID, &None)
                .unwrap();
            ap_mock_2.assert();
            assert_eq!(response.endpoint, "https://example.com/otherendpoint");
        }
        // unsubscribe a single channel
        {
            let ap_mock = mock(
                "DELETE",
                &*format!(
                    "/v1/fcm/{}/registration/{}/subscription/{}",
                    SENDER_ID, DUMMY_UAID, DUMMY_CHID
                ),
            )
            .match_header("authorization", auth_header.as_str())
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("{}")
            .create();
            let conn = ConnectHttp::connect(config.clone());
            conn.unsubscribe(DUMMY_CHID, DUMMY_UAID, SECRET).unwrap();
            ap_mock.assert();
        }
        // unsubscribe the whole registration
        {
            let ap_mock = mock(
                "DELETE",
                &*format!("/v1/fcm/{}/registration/{}", SENDER_ID, DUMMY_UAID),
            )
            .match_header("authorization", auth_header.as_str())
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("{}")
            .create();
            let conn = ConnectHttp::connect(config.clone());
            conn.unsubscribe_all(DUMMY_UAID, SECRET).unwrap();
            ap_mock.assert();
        }
        // update the native token
        {
            let ap_mock = mock(
                "PUT",
                &*format!("/v1/fcm/{}/registration/{}", SENDER_ID, DUMMY_UAID),
            )
            .match_header("authorization", auth_header.as_str())
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("{}")
            .create();
            let conn = ConnectHttp::connect(config.clone());

            conn.update("NewTokenValue", DUMMY_UAID, SECRET).unwrap();
            ap_mock.assert();
        }
        // fetch the channel list
        {
            let body_cl_success = json!({
                "uaid": DUMMY_UAID,
                "channelIDs": [DUMMY_CHID],
            })
            .to_string();
            let ap_mock = mock(
                "GET",
                &*format!("/v1/fcm/{}/registration/{}", SENDER_ID, DUMMY_UAID),
            )
            .match_header("authorization", auth_header.as_str())
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(body_cl_success)
            .create();
            let conn = ConnectHttp::connect(config.clone());
            let response = conn.channel_list(DUMMY_UAID, SECRET).unwrap();
            ap_mock.assert();
            assert_eq!(response, vec![DUMMY_CHID.to_owned()]);
        }
        // a 410 carrying UAID_NOT_FOUND_ERRNO maps to UAIDNotRecognizedError
        {
            let body = json!({
                "code": status_codes::GONE,
                "errno": UAID_NOT_FOUND_ERRNO,
                "error": "",
                "message": "UAID not found",
            })
            .to_string();
            let ap_mock = mock(
                "DELETE",
                &*format!("/v1/fcm/{}/registration/{}", SENDER_ID, UNKNOWN_UAID),
            )
            .match_header("authorization", auth_header.as_str())
            .with_status(status_codes::GONE as usize)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create();
            let conn = ConnectHttp::connect(config.clone());
            let err = conn.unsubscribe_all(UNKNOWN_UAID, SECRET).unwrap_err();
            ap_mock.assert();
            assert!(matches!(err, error::PushError::UAIDNotRecognizedError(_)));
        }
        // a 409 on register maps to AlreadyRegisteredError
        {
            let body = json!({
                "code": status_codes::CONFLICT,
                "errno": 999u32,
                "error": "",
                "message": "Already registered"
            })
            .to_string();
            let ap_mock = mock("POST", &*format!("/v1/fcm/{}/registration", SENDER_ID))
                .with_status(status_codes::CONFLICT as usize)
                .with_header("content-type", "application/json")
                .with_body(body)
                .create();
            let conn = ConnectHttp::connect(config);
            let err = conn.register(SENDER_ID, &None).unwrap_err();
            ap_mock.assert();
            assert!(matches!(err, error::PushError::AlreadyRegisteredError));
        }
    }
}