viaduct/
lib.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4
5#![allow(unknown_lints)]
6#![warn(rust_2018_idioms)]
7
8use url::Url;
9#[macro_use]
10mod headers;
11
12mod backend;
13mod client;
14pub mod error;
15mod new_backend;
16#[cfg(feature = "ohttp")]
17pub mod ohttp;
18#[cfg(feature = "ohttp")]
19mod ohttp_client;
20pub mod settings;
21pub use error::*;
22// reexport logging helpers.
23pub use error_support::{debug, error, info, trace, warn};
24
25pub use backend::{note_backend, set_backend, Backend as OldBackend};
26pub use client::{Client, ClientSettings};
27pub use headers::{consts as header_names, Header, HeaderName, Headers, InvalidHeaderName};
28pub use new_backend::{init_backend, Backend};
29#[cfg(feature = "ohttp")]
30pub use ohttp::{clear_ohttp_channels, configure_ohttp_channel, list_ohttp_channels, OhttpConfig};
31pub use settings::{allow_android_emulator_loopback, GLOBAL_SETTINGS};
32
// Message types pulled in verbatim from
// `mozilla.appservices.httpconfig.protobuf.rs` via `include!`.
// NOTE(review): the file name suggests protobuf-generated code — confirm
// against the build setup. The clippy allow below silences
// `derive_partial_eq_without_eq` for everything inside the included file.
#[allow(clippy::derive_partial_eq_without_eq)]
pub(crate) mod msg_types {
    include!("mozilla.appservices.httpconfig.protobuf.rs");
}
37
/// HTTP Methods.
///
/// The supported methods are limited to what's supported by
/// android-components.
///
/// Use [`Method::as_str`] (or the `Display` impl) to get the upper-case verb
/// for a variant.
#[derive(Clone, Debug, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, uniffi::Enum)]
#[repr(u8)]
pub enum Method {
    /// The `GET` method.
    Get,
    /// The `HEAD` method.
    Head,
    /// The `POST` method.
    Post,
    /// The `PUT` method.
    Put,
    /// The `DELETE` method.
    Delete,
    /// The `CONNECT` method.
    Connect,
    /// The `OPTIONS` method.
    Options,
    /// The `TRACE` method.
    Trace,
    /// The `PATCH` method.
    Patch,
}
54
55impl Method {
56    pub fn as_str(self) -> &'static str {
57        match self {
58            Method::Get => "GET",
59            Method::Head => "HEAD",
60            Method::Post => "POST",
61            Method::Put => "PUT",
62            Method::Delete => "DELETE",
63            Method::Connect => "CONNECT",
64            Method::Options => "OPTIONS",
65            Method::Trace => "TRACE",
66            Method::Patch => "PATCH",
67        }
68    }
69}
70
71impl std::fmt::Display for Method {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        f.write_str(self.as_str())
74    }
75}
76
/// An HTTP request: method, URL, headers and an optional body.
///
/// Build one with [`Request::new`] (or a per-method shortcut such as
/// [`Request::get`]), adjust it with the by-value builder methods, and
/// dispatch it with [`Request::send`].
#[must_use = "`Request`'s \"builder\" functions take by move, not by `&mut self`"]
#[derive(Clone, uniffi::Record)]
pub struct Request {
    /// The HTTP method to use.
    pub method: Method,
    /// The URL the request is addressed to.
    pub url: Url,
    /// The headers to send with the request.
    pub headers: Headers,
    /// The request body, if any. `None` until `body()` or `json()` sets it.
    pub body: Option<Vec<u8>>,
}
85
86impl Request {
87    /// Construct a new request to the given `url` using the given `method`.
88    /// Note that the request is not made until `send()` is called.
89    pub fn new(method: Method, url: Url) -> Self {
90        Self {
91            method,
92            url,
93            headers: Headers::new(),
94            body: None,
95        }
96    }
97
98    pub fn send(self) -> Result<Response, ViaductError> {
99        crate::backend::send(self)
100    }
101
102    /// Alias for `Request::new(Method::Get, url)`, for convenience.
103    pub fn get(url: Url) -> Self {
104        Self::new(Method::Get, url)
105    }
106
107    /// Alias for `Request::new(Method::Patch, url)`, for convenience.
108    pub fn patch(url: Url) -> Self {
109        Self::new(Method::Patch, url)
110    }
111
112    /// Alias for `Request::new(Method::Post, url)`, for convenience.
113    pub fn post(url: Url) -> Self {
114        Self::new(Method::Post, url)
115    }
116
117    /// Alias for `Request::new(Method::Put, url)`, for convenience.
118    pub fn put(url: Url) -> Self {
119        Self::new(Method::Put, url)
120    }
121
122    /// Alias for `Request::new(Method::Delete, url)`, for convenience.
123    pub fn delete(url: Url) -> Self {
124        Self::new(Method::Delete, url)
125    }
126
127    /// Append the provided query parameters to the URL
128    ///
129    /// ## Example
130    /// ```
131    /// # use viaduct::{Request, header_names};
132    /// # use url::Url;
133    /// let some_url = url::Url::parse("https://www.example.com/xyz").unwrap();
134    ///
135    /// let req = Request::post(some_url).query(&[("a", "1234"), ("b", "qwerty")]);
136    /// assert_eq!(req.url.as_str(), "https://www.example.com/xyz?a=1234&b=qwerty");
137    ///
138    /// // This appends to the query query instead of replacing `a`.
139    /// let req = req.query(&[("a", "5678")]);
140    /// assert_eq!(req.url.as_str(), "https://www.example.com/xyz?a=1234&b=qwerty&a=5678");
141    /// ```
142    pub fn query(mut self, pairs: &[(&str, &str)]) -> Self {
143        let mut append_to = self.url.query_pairs_mut();
144        for (k, v) in pairs {
145            append_to.append_pair(k, v);
146        }
147        drop(append_to);
148        self
149    }
150
151    /// Set the query string of the URL. Note that `req.set_query(None)` will
152    /// clear the query.
153    ///
154    /// See also `Request::query` which appends a slice of query pairs, which is
155    /// typically more ergonomic when usable.
156    ///
157    /// ## Example
158    /// ```
159    /// # use viaduct::{Request, header_names};
160    /// # use url::Url;
161    /// let some_url = url::Url::parse("https://www.example.com/xyz").unwrap();
162    ///
163    /// let req = Request::post(some_url).set_query("a=b&c=d");
164    /// assert_eq!(req.url.as_str(), "https://www.example.com/xyz?a=b&c=d");
165    ///
166    /// let req = req.set_query(None);
167    /// assert_eq!(req.url.as_str(), "https://www.example.com/xyz");
168    /// ```
169    pub fn set_query<'a, Q: Into<Option<&'a str>>>(mut self, query: Q) -> Self {
170        self.url.set_query(query.into());
171        self
172    }
173
174    /// Add all the provided headers to the list of headers to send with this
175    /// request.
176    pub fn headers<I>(mut self, to_add: I) -> Self
177    where
178        I: IntoIterator<Item = Header>,
179    {
180        self.headers.extend(to_add);
181        self
182    }
183
184    /// Add the provided header to the list of headers to send with this request.
185    ///
186    /// This returns `Err` if `val` contains characters that may not appear in
187    /// the body of a header.
188    ///
189    /// ## Example
190    /// ```
191    /// # use viaduct::{Request, header_names};
192    /// # use url::Url;
193    /// # fn main() -> Result<(), viaduct::ViaductError> {
194    /// # let some_url = url::Url::parse("https://www.example.com").unwrap();
195    /// Request::post(some_url)
196    ///     .header(header_names::CONTENT_TYPE, "application/json")?
197    ///     .header("My-Header", "Some special value")?;
198    /// // ...
199    /// # Ok(())
200    /// # }
201    /// ```
202    pub fn header<Name, Val>(mut self, name: Name, val: Val) -> Result<Self, crate::ViaductError>
203    where
204        Name: Into<HeaderName> + PartialEq<HeaderName>,
205        Val: Into<String> + AsRef<str>,
206    {
207        self.headers.insert(name, val)?;
208        Ok(self)
209    }
210
211    /// Set this request's body.
212    pub fn body(mut self, body: impl Into<Vec<u8>>) -> Self {
213        self.body = Some(body.into());
214        self
215    }
216
217    /// Set body to the result of serializing `val`, and, unless it has already
218    /// been set, set the Content-Type header to "application/json".
219    ///
220    /// Note: This panics if serde_json::to_vec fails. This can only happen
221    /// in a couple cases:
222    ///
223    /// 1. Trying to serialize a map with non-string keys.
224    /// 2. We wrote a custom serializer that fails.
225    ///
226    /// Neither of these are things we do. If they happen, it seems better for
227    /// this to fail hard with an easy to track down panic, than for e.g. `sync`
228    /// to fail with a JSON parse error (which we'd probably attribute to
229    /// corrupt data on the server, or something).
230    pub fn json<T: ?Sized + serde::Serialize>(mut self, val: &T) -> Self {
231        self.body =
232            Some(serde_json::to_vec(val).expect("Rust component bug: serde_json::to_vec failure"));
233        self.headers
234            .insert_if_missing(header_names::CONTENT_TYPE, "application/json")
235            .unwrap(); // We know this has to be valid.
236        self
237    }
238}
239
240// Hand-written `Debug` impl for nicer logging
241impl std::fmt::Debug for Request {
242    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243        f.debug_struct("Request")
244            .field("method", &self.method)
245            .field("url", &self.url.to_string())
246            .field("headers", &self.headers)
247            .field(
248                "body",
249                &self.body.as_ref().map(|body| String::from_utf8_lossy(body)),
250            )
251            .finish()
252    }
253}
254
/// A response from the server.
///
/// Besides the status, headers and body, it records the method and URL of the
/// request it answers, which `require_success` copies into its error value.
#[derive(Clone, uniffi::Record)]
pub struct Response {
    /// The method used to request this response.
    pub request_method: Method,
    /// The URL of this response.
    pub url: Url,
    /// The HTTP Status code of this response. See `status_codes` for named
    /// constants and range helpers.
    pub status: u16,
    /// The headers returned with this response.
    pub headers: Headers,
    /// The body of the response, as raw bytes. Use `json()` or `text()` to
    /// decode it.
    pub body: Vec<u8>,
}
269
270impl Response {
271    /// Parse the body as JSON.
272    pub fn json<'a, T>(&'a self) -> Result<T, serde_json::Error>
273    where
274        T: serde::Deserialize<'a>,
275    {
276        serde_json::from_slice(&self.body)
277    }
278
279    /// Get the body as a string. Assumes UTF-8 encoding. Any non-utf8 bytes
280    /// are replaced with the replacement character.
281    pub fn text(&self) -> std::borrow::Cow<'_, str> {
282        String::from_utf8_lossy(&self.body)
283    }
284
285    /// Returns true if the status code is in the interval `[200, 300)`.
286    #[inline]
287    pub fn is_success(&self) -> bool {
288        status_codes::is_success_code(self.status)
289    }
290
291    /// Returns true if the status code is in the interval `[500, 600)`.
292    #[inline]
293    pub fn is_server_error(&self) -> bool {
294        status_codes::is_server_error_code(self.status)
295    }
296
297    /// Returns true if the status code is in the interval `[400, 500)`.
298    #[inline]
299    pub fn is_client_error(&self) -> bool {
300        status_codes::is_client_error_code(self.status)
301    }
302
303    /// Returns an [`UnexpectedStatus`] error if `self.is_success()` is false,
304    /// otherwise returns `Ok(self)`.
305    #[inline]
306    pub fn require_success(self) -> Result<Self, UnexpectedStatus> {
307        if self.is_success() {
308            Ok(self)
309        } else {
310            Err(UnexpectedStatus {
311                method: self.request_method,
312                // XXX We probably should try and sanitize this. Replace the user id
313                // if it's a sync token server URL, for example.
314                url: self.url,
315                status: self.status,
316            })
317        }
318    }
319}
320
321// Hand-written `Debug` impl for nicer logging
322impl std::fmt::Debug for Response {
323    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
324        f.debug_struct("Response")
325            .field("request_method", &self.request_method)
326            .field("url", &self.url.to_string())
327            .field("status", &self.status)
328            .field("headers", &self.headers)
329            .field("body", &String::from_utf8_lossy(&self.body))
330            .finish()
331    }
332}
333
/// Named constants for HTTP status codes, plus helpers that classify a code
/// as success, client error or server error.
pub mod status_codes {

    /// Returns `true` when `c` is a 2xx status.
    #[inline]
    pub fn is_success_code(c: u16) -> bool {
        matches!(c, 200..=299)
    }

    /// Returns `true` when `c` is a 4xx status.
    #[inline]
    pub fn is_client_error_code(c: u16) -> bool {
        matches!(c, 400..=499)
    }

    /// Returns `true` when `c` is a 5xx status.
    #[inline]
    pub fn is_server_error_code(c: u16) -> bool {
        matches!(c, 500..=599)
    }

    // Expands every `NAME = code` entry into `pub const NAME: u16 = code;`.
    macro_rules! define_status_codes {
        ($($NAME:ident = $val:expr),* $(,)?) => {
            $(pub const $NAME: u16 = $val;)*
        };
    }

    // From https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html
    define_status_codes! {
        // 1xx
        CONTINUE = 100,
        SWITCHING_PROTOCOLS = 101,
        // 2xx
        OK = 200,
        CREATED = 201,
        ACCEPTED = 202,
        NONAUTHORITATIVE_INFORMATION = 203,
        NO_CONTENT = 204,
        RESET_CONTENT = 205,
        PARTIAL_CONTENT = 206,
        // 3xx
        MULTIPLE_CHOICES = 300,
        MOVED_PERMANENTLY = 301,
        FOUND = 302,
        SEE_OTHER = 303,
        NOT_MODIFIED = 304,
        USE_PROXY = 305,
        // no 306
        TEMPORARY_REDIRECT = 307,
        // 4xx
        BAD_REQUEST = 400,
        UNAUTHORIZED = 401,
        PAYMENT_REQUIRED = 402,
        FORBIDDEN = 403,
        NOT_FOUND = 404,
        METHOD_NOT_ALLOWED = 405,
        NOT_ACCEPTABLE = 406,
        PROXY_AUTHENTICATION_REQUIRED = 407,
        REQUEST_TIMEOUT = 408,
        CONFLICT = 409,
        GONE = 410,
        LENGTH_REQUIRED = 411,
        PRECONDITION_FAILED = 412,
        REQUEST_ENTITY_TOO_LARGE = 413,
        REQUEST_URI_TOO_LONG = 414,
        UNSUPPORTED_MEDIA_TYPE = 415,
        REQUESTED_RANGE_NOT_SATISFIABLE = 416,
        EXPECTATION_FAILED = 417,
        TOO_MANY_REQUESTS = 429,
        // 5xx
        INTERNAL_SERVER_ERROR = 500,
        NOT_IMPLEMENTED = 501,
        BAD_GATEWAY = 502,
        SERVICE_UNAVAILABLE = 503,
        GATEWAY_TIMEOUT = 504,
        HTTP_VERSION_NOT_SUPPORTED = 505,
    }
}
410
411pub fn parse_url(url: &str) -> Result<Url, ViaductError> {
412    Ok(Url::parse(url)?)
413}
414
// Rename `Url` to `ViaductUrl` to avoid name conflicts on Swift
pub type ViaductUrl = Url;

// Pass `ViaductUrl` across the FFI as a `String`: lifting parses the string
// (a parse failure is propagated through `?`), lowering converts the URL back
// with `into()`.
uniffi::custom_type!(ViaductUrl, String, {
    remote,
    try_lift: |val| Ok(ViaductUrl::parse(&val)?),
    lower: |obj| obj.into(),
});

// Pass `Headers` across the FFI as a `HashMap<String, String>`: lifting builds
// one `Header` per (name, value) entry with `Header::new`, bails out on the
// first entry that fails, and converts the collected `Vec<Header>` into
// `Headers`; lowering converts with `into()`.
uniffi::custom_type!(Headers, std::collections::HashMap<String, String>, {
    remote,
    try_lift: |map| {
        Ok(map.into_iter()
            .map(|(name, value)| Header::new(name, value))
            .collect::<Result<Vec<Header>>>()?
            .into()
        )
    },
    lower: |headers| headers.into(),
});

// Invoke UniFFI's scaffolding setup for this crate, using the name "viaduct".
uniffi::setup_scaffolding!("viaduct");
437
/// Send `request` over the named OHTTP channel.
///
/// The request is handed to `crate::ohttp::process_ohttp_request` together
/// with default [`ClientSettings`]; that call encrypts the request and routes
/// it through the OHTTP relay/gateway configured for the channel.
///
/// # Arguments
/// * `request` - The request to send
/// * `channel` - The name of the OHTTP channel to use (e.g., "merino")
///
/// # Example (Kotlin)
/// ```kotlin
/// val response = sendOhttpRequest(
///     Request(
///         method = Method.GET,
///         url = "https://example.com/api",
///         headers = mapOf("Accept" to "application/json"),
///         body = null
///     ),
///     "merino"
/// )
/// ```
#[cfg(feature = "ohttp")]
#[uniffi::export]
pub async fn send_ohttp_request(request: Request, channel: String) -> Result<Response> {
    crate::ohttp::process_ohttp_request(request, &channel, crate::ClientSettings::default()).await
}