viaduct/
headers.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4pub use name::{HeaderName, InvalidHeaderName};
5use std::collections::HashMap;
6use std::iter::FromIterator;
7use std::str::FromStr;
8mod name;
9
10/// A single header. Headers have a name (case insensitive) and a value. The
11/// character set for header and values are both restrictive.
12/// - Names must only contain a-zA-Z0-9 and and ('!' | '#' | '$' | '%' | '&' |
13///   '\'' | '*' | '+' | '-' | '.' | '^' | '_' | '`' | '|' | '~') characters
14///   (the field-name token production defined at
15///   https://tools.ietf.org/html/rfc7230#section-3.2).
16///   For request headers, we expect these to all be specified statically,
17///   and so we panic if you provide an invalid one. (For response headers, we
18///   ignore headers with invalid names, but emit a warning).
19///
20///   Header names are case insensitive, and we have several pre-defined ones in
21///   the [`header_names`] module.
22///
23/// - Values may only contain printable ascii characters, and may not contain
24///   \r or \n. Strictly speaking, HTTP is more flexible for header values,
25///   however we don't need to support binary header values, and so we do not.
26///
27/// Note that typically you should not interact with this directly, and instead
28/// use the methods on [`Request`] or [`Headers`] to manipulate these.
29#[derive(Clone, Debug, PartialEq, PartialOrd, Hash, Eq, Ord)]
30pub struct Header {
31    pub name: HeaderName,
32    pub value: String,
33}
34
// Returns `s` with leading/trailing whitespace removed. When there is nothing
// to trim, `s` is converted with `Into<String>` so an owned `String` argument
// is handed back without copying.
fn trim_string<S: AsRef<str> + Into<String>>(s: S) -> String {
    let original_len = s.as_ref().len();
    let trimmed = s.as_ref().trim();
    if trimmed.len() == original_len {
        return s.into();
    }
    trimmed.to_owned()
}
45
/// Returns `true` when every byte of `value` is a tab or printable ASCII
/// (space through `~`). Control bytes such as `\r`/`\n`, DEL, and any
/// non-ASCII (multi-byte UTF-8) character are rejected. Empty is valid.
fn is_valid_header_value(value: &str) -> bool {
    value
        .bytes()
        .all(|byte| matches!(byte, b'\t' | b' '..=b'~'))
}
49
50impl Header {
51    pub fn new<Name, Value>(name: Name, value: Value) -> Result<Self, crate::ViaductError>
52    where
53        Name: Into<HeaderName>,
54        Value: AsRef<str> + Into<String>,
55    {
56        let name = name.into();
57        let value = trim_string(value);
58        if !is_valid_header_value(&value) {
59            return Err(crate::ViaductError::RequestHeaderError(name.to_string()));
60        }
61        Ok(Self { name, value })
62    }
63
64    pub fn new_unchecked<Value>(name: HeaderName, value: Value) -> Self
65    where
66        Value: AsRef<str> + Into<String>,
67    {
68        Self {
69            name,
70            value: value.into(),
71        }
72    }
73
74    #[inline]
75    pub fn name(&self) -> &HeaderName {
76        &self.name
77    }
78
79    #[inline]
80    pub fn value(&self) -> &str {
81        &self.value
82    }
83
84    #[inline]
85    fn set_value<V: AsRef<str>>(&mut self, s: V) -> Result<(), crate::ViaductError> {
86        let value = s.as_ref();
87        if !is_valid_header_value(value) {
88            Err(crate::ViaductError::RequestHeaderError(
89                self.name.to_string(),
90            ))
91        } else {
92            self.value.clear();
93            self.value.push_str(s.as_ref().trim());
94            Ok(())
95        }
96    }
97}
98
99impl std::fmt::Display for Header {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        write!(f, "{}: {}", self.name, self.value)
102    }
103}
104
105/// A list of headers.
106#[derive(Clone, Debug, PartialEq, Eq, Default)]
107pub struct Headers {
108    headers: Vec<Header>,
109}
110
111impl Headers {
112    /// Initialize an empty list of headers.
113    #[inline]
114    pub fn new() -> Self {
115        Default::default()
116    }
117
118    /// Create headers from a HashMap of name-value pairs
119    ///
120    /// # Errors
121    /// Returns an error if any header name or value is invalid
122    pub fn try_from_hashmap(map: HashMap<String, String>) -> Result<Self, crate::ViaductError> {
123        let mut headers = Headers::new();
124        for (name, value) in map {
125            headers.insert(name, value)?;
126        }
127        Ok(headers)
128    }
129
130    /// Initialize an empty list of headers backed by a vector with the provided
131    /// capacity.
132    pub fn with_capacity(c: usize) -> Self {
133        Self {
134            headers: Vec::with_capacity(c),
135        }
136    }
137
138    /// Convert this list of headers to a Vec<Header>
139    #[inline]
140    pub fn into_vec(self) -> Vec<Header> {
141        self.headers
142    }
143
144    /// Returns the number of headers.
145    #[inline]
146    pub fn len(&self) -> usize {
147        self.headers.len()
148    }
149
150    /// Returns true if `len()` is zero.
151    #[inline]
152    pub fn is_empty(&self) -> bool {
153        self.headers.is_empty()
154    }
155    /// Clear this set of headers.
156    #[inline]
157    pub fn clear(&mut self) {
158        self.headers.clear();
159    }
160
161    /// Insert or update a new header.
162    ///
163    /// This returns an error if you attempt to specify a header with an
164    /// invalid value (values must be printable ASCII and may not contain
165    /// \r or \n)
166    ///
167    /// ## Example
168    /// ```
169    /// # use viaduct::Headers;
170    /// # fn main() -> Result<(), viaduct::ViaductError> {
171    /// let mut h = Headers::new();
172    /// h.insert("My-Cool-Header", "example")?;
173    /// assert_eq!(h.get("My-Cool-Header"), Some("example"));
174    ///
175    /// // Note: names are sensitive
176    /// assert_eq!(h.get("my-cool-header"), Some("example"));
177    ///
178    /// // Also note, constants for headers are in `viaduct::header_names`, and
179    /// // you can chain the result of this function.
180    /// h.insert(viaduct::header_names::CONTENT_TYPE, "something...")?
181    ///  .insert("Something-Else", "etc")?;
182    /// # Ok(())
183    /// # }
184    /// ```
185    pub fn insert<N, V>(&mut self, name: N, value: V) -> Result<&mut Self, crate::ViaductError>
186    where
187        N: Into<HeaderName> + PartialEq<HeaderName>,
188        V: Into<String> + AsRef<str>,
189    {
190        if let Some(entry) = self.headers.iter_mut().find(|h| name == h.name) {
191            entry.set_value(value)?;
192        } else {
193            self.headers.push(Header::new(name, value)?);
194        }
195        Ok(self)
196    }
197
198    /// Insert the provided header unless a header is already specified.
199    /// Mostly used internally, e.g. to set "Content-Type: application/json"
200    /// in `Request::json()` unless it has been set specifically.
201    pub fn insert_if_missing<N, V>(
202        &mut self,
203        name: N,
204        value: V,
205    ) -> Result<&mut Self, crate::ViaductError>
206    where
207        N: Into<HeaderName> + PartialEq<HeaderName>,
208        V: Into<String> + AsRef<str>,
209    {
210        if !self.headers.iter_mut().any(|h| name == h.name) {
211            self.headers.push(Header::new(name, value)?);
212        }
213        Ok(self)
214    }
215
216    /// Insert or update a header directly. Typically you will want to use
217    /// `insert` over this, as it performs less work if the header needs
218    /// updating instead of insertion.
219    pub fn insert_header(&mut self, new: Header) -> &mut Self {
220        if let Some(entry) = self.headers.iter_mut().find(|h| h.name == new.name) {
221            entry.value = new.value;
222        } else {
223            self.headers.push(new);
224        }
225        self
226    }
227
228    /// Add all the headers in the provided iterator to this list of headers.
229    pub fn extend<I>(&mut self, iter: I) -> &mut Self
230    where
231        I: IntoIterator<Item = Header>,
232    {
233        let it = iter.into_iter();
234        self.headers.reserve(it.size_hint().0);
235        for h in it {
236            self.insert_header(h);
237        }
238        self
239    }
240
241    /// Add all the headers in the provided iterator, unless any of them are Err.
242    pub fn try_extend<I, E>(&mut self, iter: I) -> Result<&mut Self, E>
243    where
244        I: IntoIterator<Item = Result<Header, E>>,
245    {
246        // Not the most efficient but avoids leaving us in an unspecified state
247        // if one returns Err.
248        self.extend(iter.into_iter().collect::<Result<Vec<_>, E>>()?);
249        Ok(self)
250    }
251
252    /// Get the header object with the requested name. Usually, you will
253    /// want to use `get()` or `get_as::<T>()` instead.
254    pub fn get_header<S>(&self, name: S) -> Option<&Header>
255    where
256        S: PartialEq<HeaderName>,
257    {
258        self.headers.iter().find(|h| name == h.name)
259    }
260
261    /// Get the value of the header with the provided name.
262    ///
263    /// See also `get_as`.
264    ///
265    /// ## Example
266    /// ```
267    /// # use viaduct::{Headers, header_names::CONTENT_TYPE};
268    /// # fn main() -> Result<(), viaduct::ViaductError> {
269    /// let mut h = Headers::new();
270    /// h.insert(CONTENT_TYPE, "application/json")?;
271    /// assert_eq!(h.get(CONTENT_TYPE), Some("application/json"));
272    /// assert_eq!(h.get("Something-Else"), None);
273    /// # Ok(())
274    /// # }
275    /// ```
276    pub fn get<S>(&self, name: S) -> Option<&str>
277    where
278        S: PartialEq<HeaderName>,
279    {
280        self.get_header(name).map(|h| h.value.as_str())
281    }
282
283    /// Get the value of the header with the provided name, and
284    /// attempt to parse it using [`std::str::FromStr`].
285    ///
286    /// - If the header is missing, it returns None.
287    /// - If the header is present but parsing failed, returns
288    ///   `Some(Err(<error returned by parsing>))`.
289    /// - Otherwise, returns `Some(Ok(result))`.
290    ///
291    /// Note that if `Option<Result<T, E>>` is inconvenient for you,
292    /// and you wish this returned `Result<Option<T>, E>`, you may use
293    /// the built-in `transpose()` method to convert between them.
294    ///
295    /// ```
296    /// # use viaduct::Headers;
297    /// # fn main() -> Result<(), viaduct::ViaductError> {
298    /// let mut h = Headers::new();
299    /// h.insert("Example", "1234")?.insert("Illegal", "abcd")?;
300    /// let v: Option<Result<i64, _>> = h.get_as("Example");
301    /// assert_eq!(v, Some(Ok(1234)));
302    /// assert_eq!(h.get_as::<i64, _>("Example"), Some(Ok(1234)));
303    /// assert_eq!(h.get_as::<i64, _>("Illegal"), Some("abcd".parse::<i64>()));
304    /// assert_eq!(h.get_as::<i64, _>("Something-Else"), None);
305    /// # Ok(())
306    /// # }
307    /// ```
308    pub fn get_as<T, S>(&self, name: S) -> Option<Result<T, <T as FromStr>::Err>>
309    where
310        T: FromStr,
311        S: PartialEq<HeaderName>,
312    {
313        self.get(name).map(str::parse)
314    }
315    /// Get the value of the header with the provided name, and
316    /// attempt to parse it using [`std::str::FromStr`].
317    ///
318    /// This is a variant of `get_as` that returns None on error,
319    /// intended to be used for cases where missing and invalid
320    /// headers should be treated the same. (With `get_as` this
321    /// requires `h.get_as(...).and_then(|r| r.ok())`, which is
322    /// somewhat opaque.
323    pub fn try_get<T, S>(&self, name: S) -> Option<T>
324    where
325        T: FromStr,
326        S: PartialEq<HeaderName>,
327    {
328        self.get(name).and_then(|val| val.parse::<T>().ok())
329    }
330
331    /// Get an iterator over the headers in no particular order.
332    ///
333    /// Note that we also implement IntoIterator.
334    pub fn iter(&self) -> <&Headers as IntoIterator>::IntoIter {
335        self.into_iter()
336    }
337}
338
339impl std::iter::IntoIterator for Headers {
340    type IntoIter = <Vec<Header> as IntoIterator>::IntoIter;
341    type Item = Header;
342    fn into_iter(self) -> Self::IntoIter {
343        self.headers.into_iter()
344    }
345}
346
347impl<'a> std::iter::IntoIterator for &'a Headers {
348    type IntoIter = <&'a [Header] as IntoIterator>::IntoIter;
349    type Item = &'a Header;
350    fn into_iter(self) -> Self::IntoIter {
351        self.headers[..].iter()
352    }
353}
354
355impl FromIterator<Header> for Headers {
356    fn from_iter<T>(iter: T) -> Self
357    where
358        T: IntoIterator<Item = Header>,
359    {
360        let mut v = iter.into_iter().collect::<Vec<Header>>();
361        v.sort_by(|a, b| a.name.cmp(&b.name));
362        v.reverse();
363        v.dedup_by(|a, b| a.name == b.name);
364        v.into()
365    }
366}
367
368impl From<Vec<Header>> for Headers {
369    fn from(headers: Vec<Header>) -> Self {
370        Self { headers }
371    }
372}
373
374#[allow(clippy::implicit_hasher)] // https://github.com/rust-lang/rust-clippy/issues/3899
375impl From<Headers> for HashMap<String, String> {
376    fn from(headers: Headers) -> HashMap<String, String> {
377        headers
378            .into_iter()
379            .map(|h| (String::from(h.name), h.value))
380            .collect()
381    }
382}
383
384pub mod consts {
385    use super::name::HeaderName;
386    macro_rules! def_header_consts {
387        ($(($NAME:ident, $string:literal)),* $(,)?) => {
388            $(pub const $NAME: HeaderName = HeaderName(std::borrow::Cow::Borrowed($string));)*
389        };
390    }
391
392    macro_rules! headers {
393        ($(($NAME:ident, $string:literal)),* $(,)?) => {
394            def_header_consts!($(($NAME, $string)),*);
395            // Unused except for tests.
396            const _ALL: &[&str] = &[$($string),*];
397        };
398    }
399
400    // Predefined header names, for convenience.
401    // Feel free to add to these.
402    headers!(
403        (ACCEPT_ENCODING, "accept-encoding"),
404        (ACCEPT, "accept"),
405        (AUTHORIZATION, "authorization"),
406        (CACHE_CONTROL, "cache-control"),
407        (CONTENT_TYPE, "content-type"),
408        (ETAG, "etag"),
409        (IF_NONE_MATCH, "if-none-match"),
410        (USER_AGENT, "user-agent"),
411        // non-standard, but it's convenient to have these.
412        (RETRY_AFTER, "retry-after"),
413        (X_IF_UNMODIFIED_SINCE, "x-if-unmodified-since"),
414        (X_KEYID, "x-keyid"),
415        (X_LAST_MODIFIED, "x-last-modified"),
416        (X_TIMESTAMP, "x-timestamp"),
417        (X_WEAVE_NEXT_OFFSET, "x-weave-next-offset"),
418        (X_WEAVE_RECORDS, "x-weave-records"),
419        (X_WEAVE_TIMESTAMP, "x-weave-timestamp"),
420        (X_WEAVE_BACKOFF, "x-weave-backoff"),
421    );
422
423    #[test]
424    fn test_predefined() {
425        for &name in _ALL {
426            assert!(
427                HeaderName::new(name).is_ok(),
428                "Invalid header name in predefined header constants: {}",
429                name
430            );
431            assert_eq!(
432                name.to_ascii_lowercase(),
433                name,
434                "Non-lowercase name in predefined header constants: {}",
435                name
436            );
437        }
438    }
439}