as_ohttp_client/
lib.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
4
5extern crate rusqlite;
6
7use parking_lot::Mutex;
8use std::collections::HashMap;
9
10#[derive(Debug, thiserror::Error)]
11pub enum OhttpError {
12    #[error("Failed to fetch encryption key")]
13    KeyFetchFailed,
14
15    #[error("OHTTP key config is malformed")]
16    MalformedKeyConfig,
17
18    #[error("Unsupported OHTTP encryption algorithm")]
19    UnsupportedKeyConfig,
20
21    #[error("OhttpSession is in invalid state")]
22    InvalidSession,
23
24    #[error("Network errors communicating with Relay / Gateway")]
25    RelayFailed,
26
27    #[error("Cannot encode message as BHTTP/OHTTP")]
28    CannotEncodeMessage,
29
30    #[error("Cannot decode OHTTP/BHTTP message")]
31    MalformedMessage,
32
33    #[error("Duplicate HTTP response headers")]
34    DuplicateHeaders,
35}
36
37#[derive(Default)]
38enum ExchangeState {
39    #[default]
40    Invalid,
41    Request(ohttp::ClientRequest),
42    Response(ohttp::ClientResponse),
43}
44
45pub struct OhttpSession {
46    state: Mutex<ExchangeState>,
47}
48
/// A cleartext HTTP response: status code, headers and body.
///
/// Returned by `OhttpSession::decapsulate` and accepted by the test server's
/// `respond` method.
pub struct OhttpResponse {
    /// HTTP status code.
    status_code: u16,
    /// Response headers keyed by header name.
    headers: HashMap<String, String>,
    /// Response body bytes.
    payload: Vec<u8>,
}
54
55/// Transform the headers from a BHTTP message into a HashMap for use from Swift
56/// later. If there are duplicate errors, we currently raise an error.
57fn headers_to_map(message: &bhttp::Message) -> Result<HashMap<String, String>, OhttpError> {
58    let mut headers = HashMap::new();
59
60    for field in message.header().iter() {
61        if headers
62            .insert(
63                std::str::from_utf8(field.name())
64                    .map_err(|_| OhttpError::MalformedMessage)?
65                    .into(),
66                std::str::from_utf8(field.value())
67                    .map_err(|_| OhttpError::MalformedMessage)?
68                    .into(),
69            )
70            .is_some()
71        {
72            return Err(OhttpError::DuplicateHeaders);
73        }
74    }
75
76    Ok(headers)
77}
78
79impl OhttpSession {
80    /// Create a new encryption session for use with specific key configuration
81    pub fn new(config: &[u8]) -> Result<Self, OhttpError> {
82        ohttp::init();
83
84        let request = ohttp::ClientRequest::from_encoded_config(config).map_err(|e| match e {
85            ohttp::Error::Unsupported => OhttpError::UnsupportedKeyConfig,
86            _ => OhttpError::MalformedKeyConfig,
87        })?;
88
89        let state = Mutex::new(ExchangeState::Request(request));
90        Ok(OhttpSession { state })
91    }
92
93    /// Encode an HTTP request in Binary HTTP format and then encrypt it into an
94    /// Oblivious HTTP request message.
95    pub fn encapsulate(
96        &self,
97        method: &str,
98        scheme: &str,
99        server: &str,
100        endpoint: &str,
101        mut headers: HashMap<String, String>,
102        payload: &[u8],
103    ) -> Result<Vec<u8>, OhttpError> {
104        let mut message =
105            bhttp::Message::request(method.into(), scheme.into(), server.into(), endpoint.into());
106
107        for (k, v) in headers.drain() {
108            message.put_header(k, v);
109        }
110
111        message.write_content(payload);
112
113        let mut encoded = vec![];
114        message
115            .write_bhttp(bhttp::Mode::KnownLength, &mut encoded)
116            .map_err(|_| OhttpError::CannotEncodeMessage)?;
117
118        let mut state = self.state.lock();
119        let request = match std::mem::take(&mut *state) {
120            ExchangeState::Request(request) => request,
121            _ => return Err(OhttpError::InvalidSession),
122        };
123        let (capsule, response) = request
124            .encapsulate(&encoded)
125            .map_err(|_| OhttpError::CannotEncodeMessage)?;
126        *state = ExchangeState::Response(response);
127
128        Ok(capsule)
129    }
130
131    /// Decode an OHTTP response returned in response to a request encoded on
132    /// this session.
133    pub fn decapsulate(&self, encoded: &[u8]) -> Result<OhttpResponse, OhttpError> {
134        let mut state = self.state.lock();
135        let decoder = match std::mem::take(&mut *state) {
136            ExchangeState::Response(response) => response,
137            _ => return Err(OhttpError::InvalidSession),
138        };
139        let binary = decoder
140            .decapsulate(encoded)
141            .map_err(|_| OhttpError::MalformedMessage)?;
142
143        let mut cursor = std::io::Cursor::new(binary);
144        let message =
145            bhttp::Message::read_bhttp(&mut cursor).map_err(|_| OhttpError::MalformedMessage)?;
146
147        let headers = headers_to_map(&message)?;
148
149        Ok(OhttpResponse {
150            status_code: match message.control() {
151                bhttp::ControlData::Response(sc) => (*sc).into(),
152                _ => return Err(OhttpError::InvalidSession),
153            },
154            headers,
155            payload: message.content().into(),
156        })
157    }
158}
159
160pub struct OhttpTestServer {
161    server: Mutex<ohttp::Server>,
162    state: Mutex<Option<ohttp::ServerResponse>>,
163    config: Vec<u8>,
164}
165
/// A decrypted request as seen by [`OhttpTestServer`].
pub struct TestServerRequest {
    /// HTTP method, e.g. `GET`.
    method: String,
    /// URL scheme, e.g. `https`.
    scheme: String,
    /// Request authority taken from the BHTTP control data.
    server: String,
    /// Request path taken from the BHTTP control data.
    endpoint: String,
    /// Request headers keyed by header name.
    headers: HashMap<String, String>,
    /// Request body bytes.
    payload: Vec<u8>,
}
174
175impl OhttpTestServer {
176    /// Create a simple OHTTP server to decrypt and respond to OHTTP messages in
177    /// testing. The key is randomly generated.
178    fn new() -> Self {
179        ohttp::init();
180
181        let key = ohttp::KeyConfig::new(
182            0x01,
183            ohttp::hpke::Kem::X25519Sha256,
184            vec![ohttp::SymmetricSuite::new(
185                ohttp::hpke::Kdf::HkdfSha256,
186                ohttp::hpke::Aead::Aes128Gcm,
187            )],
188        )
189        .unwrap();
190
191        let config = key.encode().unwrap();
192        let server = ohttp::Server::new(key).unwrap();
193
194        OhttpTestServer {
195            server: Mutex::new(server),
196            state: Mutex::new(Option::None),
197            config,
198        }
199    }
200
201    /// Return a copy of the key config for clients to use.
202    fn get_config(&self) -> Vec<u8> {
203        self.config.clone()
204    }
205
206    /// Decode an OHTTP request message and return the cleartext contents. This
207    /// also updates the internal server state so that a response message can be
208    /// generated.
209    fn receive(&self, message: &[u8]) -> Result<TestServerRequest, OhttpError> {
210        let (encoded, response) = self
211            .server
212            .lock()
213            .decapsulate(message)
214            .map_err(|_| OhttpError::MalformedMessage)?;
215        let mut cursor = std::io::Cursor::new(encoded);
216        let message =
217            bhttp::Message::read_bhttp(&mut cursor).map_err(|_| OhttpError::MalformedMessage)?;
218
219        *self.state.lock() = Some(response);
220
221        let headers = headers_to_map(&message)?;
222
223        match message.control() {
224            bhttp::ControlData::Request {
225                method,
226                scheme,
227                authority,
228                path,
229            } => Ok(TestServerRequest {
230                method: String::from_utf8_lossy(method).into(),
231                scheme: String::from_utf8_lossy(scheme).into(),
232                server: String::from_utf8_lossy(authority).into(),
233                endpoint: String::from_utf8_lossy(path).into(),
234                headers,
235                payload: message.content().into(),
236            }),
237            _ => Err(OhttpError::MalformedMessage),
238        }
239    }
240
241    /// Encode an OHTTP response keyed to the last message received.
242    fn respond(&self, response: OhttpResponse) -> Result<Vec<u8>, OhttpError> {
243        let state = self.state.lock().take().unwrap();
244
245        let mut message =
246            bhttp::Message::response(bhttp::StatusCode::try_from(response.status_code).unwrap());
247        message.write_content(&response.payload);
248
249        for (k, v) in response.headers {
250            message.put_header(k, v);
251        }
252
253        let mut encoded = vec![];
254        message
255            .write_bhttp(bhttp::Mode::KnownLength, &mut encoded)
256            .map_err(|_| OhttpError::CannotEncodeMessage)?;
257
258        state
259            .encapsulate(&encoded)
260            .map_err(|_| OhttpError::CannotEncodeMessage)
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a request and a response through a client session and the
    /// in-process test server, checking that every field survives intact.
    #[test]
    fn test_smoke() {
        let server = OhttpTestServer::new();
        let config = server.get_config();

        let body: Vec<u8> = vec![0x00, 0x01, 0x02];
        let header = HashMap::from([
            ("Content-Type".into(), "application/octet-stream".into()),
            ("X-Header".into(), "value".into()),
        ]);

        let session = OhttpSession::new(&config).unwrap();
        let request_capsule = session
            .encapsulate("GET", "https", "example.com", "/api", header.clone(), &body)
            .unwrap();

        let request = server.receive(&request_capsule).unwrap();
        assert_eq!(request.method, "GET");
        assert_eq!(request.scheme, "https");
        assert_eq!(request.server, "example.com");
        assert_eq!(request.endpoint, "/api");
        assert_eq!(request.headers, header);
        assert_eq!(request.payload, body);

        let response_capsule = server
            .respond(OhttpResponse {
                status_code: 200,
                headers: header.clone(),
                payload: body.clone(),
            })
            .unwrap();

        let response = session.decapsulate(&response_capsule).unwrap();
        assert_eq!(response.status_code, 200);
        assert_eq!(response.headers, header);
        assert_eq!(response.payload, body);
    }

    /// A fresh session has not sent anything, so there is nothing to decrypt.
    #[test]
    fn test_decapsulate_before_encapsulate() {
        let server = OhttpTestServer::new();
        let session = OhttpSession::new(&server.get_config()).unwrap();
        assert!(matches!(
            session.decapsulate(&[]),
            Err(OhttpError::InvalidSession)
        ));
    }

    /// Bytes that are not a valid OHTTP request are reported as malformed.
    #[test]
    fn test_receive_garbage() {
        let server = OhttpTestServer::new();
        assert!(matches!(
            server.receive(&[0xff; 16]),
            Err(OhttpError::MalformedMessage)
        ));
    }
}
305
// Include the UniFFI-generated scaffolding for the `as_ohttp_client` component.
// NOTE(review): presumably generated at build time from a matching
// `as_ohttp_client` UDL file — confirm in build.rs.
uniffi::include_scaffolding!("as_ohttp_client");