1extern crate rusqlite;
6
7use parking_lot::Mutex;
8use std::collections::HashMap;
9
/// Errors returned by [`OhttpSession`] and [`OhttpTestServer`].
///
/// NOTE(review): this enum is presumably mirrored by the `as_ohttp_client` UDL pulled in via
/// `uniffi::include_scaffolding!` at the bottom of this file — keep the two in sync.
#[derive(Debug, thiserror::Error)]
pub enum OhttpError {
    /// Not constructed in this file; presumably reserved for code that fetches the gateway's
    /// key configuration — confirm with callers.
    #[error("Failed to fetch encryption key")]
    KeyFetchFailed,

    /// The encoded key configuration passed to `OhttpSession::new` could not be parsed.
    #[error("OHTTP key config is malformed")]
    MalformedKeyConfig,

    /// The key configuration parsed, but the `ohttp` crate reported its algorithms as
    /// unsupported.
    #[error("Unsupported OHTTP encryption algorithm")]
    UnsupportedKeyConfig,

    /// An `OhttpSession` was used out of order, e.g. `decapsulate` before `encapsulate`, or
    /// either method called a second time.
    #[error("OhttpSession is in invalid state")]
    InvalidSession,

    /// Not constructed in this file; presumably reserved for transport failures reported by
    /// the caller — confirm.
    #[error("Network errors communicating with Relay / Gateway")]
    RelayFailed,

    /// BHTTP serialization or OHTTP encryption of an outgoing message failed.
    #[error("Cannot encode message as BHTTP/OHTTP")]
    CannotEncodeMessage,

    /// An incoming message could not be decrypted or parsed, or a header name/value was not
    /// valid UTF-8.
    #[error("Cannot decode OHTTP/BHTTP message")]
    MalformedMessage,

    /// The same header name (compared byte-for-byte) appeared more than once in a message.
    #[error("Duplicate HTTP response headers")]
    DuplicateHeaders,
}
36
/// Progress of the single request/response exchange owned by an [`OhttpSession`].
///
/// `OhttpSession::new` starts in `Request`, and a successful `encapsulate` stores `Response`.
/// Contexts are moved out with `std::mem::take`, which leaves `Invalid` behind.
#[derive(Default)]
enum ExchangeState {
    /// No usable context. This is the `Default`, i.e. what `std::mem::take` leaves behind.
    #[default]
    Invalid,
    /// Client context used to encrypt the outgoing request.
    Request(ohttp::ClientRequest),
    /// Client context used to decrypt the reply to the request that was encrypted.
    Response(ohttp::ClientResponse),
}
44
/// Client side of one OHTTP request/response exchange.
///
/// The exchange state sits behind a `Mutex` so that `encapsulate`/`decapsulate` can advance it
/// through `&self`.
pub struct OhttpSession {
    state: Mutex<ExchangeState>,
}
48
/// An HTTP response: produced by `OhttpSession::decapsulate` on the client, or passed to
/// `OhttpTestServer::respond` to be encrypted.
pub struct OhttpResponse {
    /// HTTP status code, e.g. `200`.
    status_code: u16,
    /// Header fields keyed by name; names are kept exactly as sent (no case folding).
    headers: HashMap<String, String>,
    /// Response body bytes.
    payload: Vec<u8>,
}
54
55fn headers_to_map(message: &bhttp::Message) -> Result<HashMap<String, String>, OhttpError> {
58 let mut headers = HashMap::new();
59
60 for field in message.header().iter() {
61 if headers
62 .insert(
63 std::str::from_utf8(field.name())
64 .map_err(|_| OhttpError::MalformedMessage)?
65 .into(),
66 std::str::from_utf8(field.value())
67 .map_err(|_| OhttpError::MalformedMessage)?
68 .into(),
69 )
70 .is_some()
71 {
72 return Err(OhttpError::DuplicateHeaders);
73 }
74 }
75
76 Ok(headers)
77}
78
79impl OhttpSession {
80 pub fn new(config: &[u8]) -> Result<Self, OhttpError> {
82 ohttp::init();
83
84 let request = ohttp::ClientRequest::from_encoded_config(config).map_err(|e| match e {
85 ohttp::Error::Unsupported => OhttpError::UnsupportedKeyConfig,
86 _ => OhttpError::MalformedKeyConfig,
87 })?;
88
89 let state = Mutex::new(ExchangeState::Request(request));
90 Ok(OhttpSession { state })
91 }
92
93 pub fn encapsulate(
96 &self,
97 method: &str,
98 scheme: &str,
99 server: &str,
100 endpoint: &str,
101 mut headers: HashMap<String, String>,
102 payload: &[u8],
103 ) -> Result<Vec<u8>, OhttpError> {
104 let mut message =
105 bhttp::Message::request(method.into(), scheme.into(), server.into(), endpoint.into());
106
107 for (k, v) in headers.drain() {
108 message.put_header(k, v);
109 }
110
111 message.write_content(payload);
112
113 let mut encoded = vec![];
114 message
115 .write_bhttp(bhttp::Mode::KnownLength, &mut encoded)
116 .map_err(|_| OhttpError::CannotEncodeMessage)?;
117
118 let mut state = self.state.lock();
119 let request = match std::mem::take(&mut *state) {
120 ExchangeState::Request(request) => request,
121 _ => return Err(OhttpError::InvalidSession),
122 };
123 let (capsule, response) = request
124 .encapsulate(&encoded)
125 .map_err(|_| OhttpError::CannotEncodeMessage)?;
126 *state = ExchangeState::Response(response);
127
128 Ok(capsule)
129 }
130
131 pub fn decapsulate(&self, encoded: &[u8]) -> Result<OhttpResponse, OhttpError> {
134 let mut state = self.state.lock();
135 let decoder = match std::mem::take(&mut *state) {
136 ExchangeState::Response(response) => response,
137 _ => return Err(OhttpError::InvalidSession),
138 };
139 let binary = decoder
140 .decapsulate(encoded)
141 .map_err(|_| OhttpError::MalformedMessage)?;
142
143 let mut cursor = std::io::Cursor::new(binary);
144 let message =
145 bhttp::Message::read_bhttp(&mut cursor).map_err(|_| OhttpError::MalformedMessage)?;
146
147 let headers = headers_to_map(&message)?;
148
149 Ok(OhttpResponse {
150 status_code: match message.control() {
151 bhttp::ControlData::Response(sc) => (*sc).into(),
152 _ => return Err(OhttpError::InvalidSession),
153 },
154 headers,
155 payload: message.content().into(),
156 })
157 }
158}
159
/// In-process OHTTP gateway, used to exercise [`OhttpSession`] without a relay or network.
pub struct OhttpTestServer {
    /// Server built from the generated key configuration; decrypts incoming requests.
    server: Mutex<ohttp::Server>,
    /// Response context from the most recent `receive`, taken by `respond`.
    state: Mutex<Option<ohttp::ServerResponse>>,
    /// Encoded key configuration handed to clients by `get_config`.
    config: Vec<u8>,
}
165
/// A request decrypted and decoded by `OhttpTestServer::receive`.
///
/// The first four fields are decoded from raw bytes with `String::from_utf8_lossy`, so invalid
/// UTF-8 is replaced rather than rejected.
pub struct TestServerRequest {
    /// HTTP method, e.g. `"GET"`.
    method: String,
    /// URL scheme, e.g. `"https"`.
    scheme: String,
    /// The request's authority (the target server).
    server: String,
    /// Request path, e.g. `"/api"`.
    endpoint: String,
    /// Header fields keyed by name.
    headers: HashMap<String, String>,
    /// Request body bytes.
    payload: Vec<u8>,
}
174
175impl OhttpTestServer {
176 fn new() -> Self {
179 ohttp::init();
180
181 let key = ohttp::KeyConfig::new(
182 0x01,
183 ohttp::hpke::Kem::X25519Sha256,
184 vec![ohttp::SymmetricSuite::new(
185 ohttp::hpke::Kdf::HkdfSha256,
186 ohttp::hpke::Aead::Aes128Gcm,
187 )],
188 )
189 .unwrap();
190
191 let config = key.encode().unwrap();
192 let server = ohttp::Server::new(key).unwrap();
193
194 OhttpTestServer {
195 server: Mutex::new(server),
196 state: Mutex::new(Option::None),
197 config,
198 }
199 }
200
201 fn get_config(&self) -> Vec<u8> {
203 self.config.clone()
204 }
205
206 fn receive(&self, message: &[u8]) -> Result<TestServerRequest, OhttpError> {
210 let (encoded, response) = self
211 .server
212 .lock()
213 .decapsulate(message)
214 .map_err(|_| OhttpError::MalformedMessage)?;
215 let mut cursor = std::io::Cursor::new(encoded);
216 let message =
217 bhttp::Message::read_bhttp(&mut cursor).map_err(|_| OhttpError::MalformedMessage)?;
218
219 *self.state.lock() = Some(response);
220
221 let headers = headers_to_map(&message)?;
222
223 match message.control() {
224 bhttp::ControlData::Request {
225 method,
226 scheme,
227 authority,
228 path,
229 } => Ok(TestServerRequest {
230 method: String::from_utf8_lossy(method).into(),
231 scheme: String::from_utf8_lossy(scheme).into(),
232 server: String::from_utf8_lossy(authority).into(),
233 endpoint: String::from_utf8_lossy(path).into(),
234 headers,
235 payload: message.content().into(),
236 }),
237 _ => Err(OhttpError::MalformedMessage),
238 }
239 }
240
241 fn respond(&self, response: OhttpResponse) -> Result<Vec<u8>, OhttpError> {
243 let state = self.state.lock().take().unwrap();
244
245 let mut message =
246 bhttp::Message::response(bhttp::StatusCode::try_from(response.status_code).unwrap());
247 message.write_content(&response.payload);
248
249 for (k, v) in response.headers {
250 message.put_header(k, v);
251 }
252
253 let mut encoded = vec![];
254 message
255 .write_bhttp(bhttp::Mode::KnownLength, &mut encoded)
256 .map_err(|_| OhttpError::CannotEncodeMessage)?;
257
258 state
259 .encapsulate(&encoded)
260 .map_err(|_| OhttpError::CannotEncodeMessage)
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// Headers used by the round-trip tests, in both directions.
    fn sample_headers() -> HashMap<String, String> {
        HashMap::from([
            ("Content-Type".into(), "application/octet-stream".into()),
            ("X-Header".into(), "value".into()),
        ])
    }

    /// Full client -> gateway -> client round trip, checking every decoded field.
    #[test]
    fn test_smoke() {
        let server = OhttpTestServer::new();
        let config = server.get_config();

        let body: Vec<u8> = vec![0x00, 0x01, 0x02];
        let header = sample_headers();

        let session = OhttpSession::new(&config).unwrap();
        let mut message = session
            .encapsulate("GET", "https", "example.com", "/api", header.clone(), &body)
            .unwrap();

        let request = server.receive(&message).unwrap();
        assert_eq!(request.method, "GET");
        assert_eq!(request.scheme, "https");
        assert_eq!(request.server, "example.com");
        assert_eq!(request.endpoint, "/api");
        assert_eq!(request.headers, header);
        assert_eq!(request.payload, body);

        message = server
            .respond(OhttpResponse {
                status_code: 200,
                headers: header.clone(),
                payload: body.clone(),
            })
            .unwrap();

        let response = session.decapsulate(&message).unwrap();
        assert_eq!(response.status_code, 200);
        assert_eq!(response.headers, header);
        assert_eq!(response.payload, body);
    }

    /// An empty byte string is not a valid key configuration.
    #[test]
    fn test_empty_config_is_malformed() {
        assert!(matches!(
            OhttpSession::new(&[]),
            Err(OhttpError::MalformedKeyConfig)
        ));
    }

    /// A fresh session has nothing to decrypt yet.
    #[test]
    fn test_decapsulate_before_encapsulate_is_rejected() {
        let server = OhttpTestServer::new();
        let session = OhttpSession::new(&server.get_config()).unwrap();

        assert!(matches!(
            session.decapsulate(&[]),
            Err(OhttpError::InvalidSession)
        ));
    }

    /// A session encrypts exactly one request.
    #[test]
    fn test_second_encapsulate_is_rejected() {
        let server = OhttpTestServer::new();
        let session = OhttpSession::new(&server.get_config()).unwrap();

        session
            .encapsulate("GET", "https", "example.com", "/", HashMap::new(), &[])
            .unwrap();
        assert!(matches!(
            session.encapsulate("GET", "https", "example.com", "/", HashMap::new(), &[]),
            Err(OhttpError::InvalidSession)
        ));
    }

    /// The response context is single-use: decrypting twice is a session error.
    #[test]
    fn test_second_decapsulate_is_rejected() {
        let server = OhttpTestServer::new();
        let session = OhttpSession::new(&server.get_config()).unwrap();

        let request = session
            .encapsulate("POST", "https", "example.com", "/api", HashMap::new(), &[0x2a])
            .unwrap();
        server.receive(&request).unwrap();
        let reply = server
            .respond(OhttpResponse {
                status_code: 204,
                headers: HashMap::new(),
                payload: vec![],
            })
            .unwrap();

        let response = session.decapsulate(&reply).unwrap();
        assert_eq!(response.status_code, 204);
        assert!(matches!(
            session.decapsulate(&reply),
            Err(OhttpError::InvalidSession)
        ));
    }
}
305
// Include the FFI glue uniffi generates for the `as_ohttp_client` interface definition
// (presumably `src/as_ohttp_client.udl` — confirm against the build script).
uniffi::include_scaffolding!("as_ohttp_client");