nimbus_cli/output/
server.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at https://mozilla.org/MPL/2.0/.
4
5use anyhow::Result;
6use reqwest::StatusCode;
7use serde::{Deserialize, Serialize};
8use std::sync::RwLock;
9use std::{
10    collections::HashMap,
11    net::{IpAddr, SocketAddr},
12    sync::Arc,
13};
14
15use crate::config;
16
17use anyhow::anyhow;
18use axum::{
19    extract::{Path, State},
20    http,
21    response::{Html, IntoResponse},
22    routing::{get, post, IntoMakeService},
23    Json, Router, Server,
24};
25use hyper::server::conn::AddrIncoming;
26use serde_json::Value;
27use tower::layer::util::Stack;
28use tower_http::set_header::SetResponseHeaderLayer;
29use tower_livereload::{LiveReloadLayer, Reloader};
30
31fn create_server(
32    livereload: LiveReloadLayer,
33    state: Db,
34) -> Result<Server<AddrIncoming, IntoMakeService<Router>>, anyhow::Error> {
35    let app = create_app(livereload, state);
36
37    let addr = get_address()?;
38    eprintln!("Copy the address http://{}/ into your mobile browser", addr);
39
40    let server = Server::try_bind(&addr)?.serve(app.into_make_service());
41
42    Ok(server)
43}
44
45fn create_app(livereload: LiveReloadLayer, state: Db) -> Router {
46    Router::new()
47        .route("/", get(index))
48        .route("/style.css", get(style))
49        .route("/script.js", get(script))
50        .route("/post", post(post_handler))
51        .route("/buckets/:bucket/collections/:collection/records", get(rs))
52        .route(
53            "/v1/buckets/:bucket/collections/:collection/records",
54            get(rs),
55        )
56        .layer(livereload)
57        .layer(no_cache_layer())
58        .with_state(state)
59}
60
61fn create_state(livereload: &LiveReloadLayer) -> Db {
62    let reloader = livereload.reloader();
63    Arc::new(RwLock::new(InMemoryDb::new(reloader)))
64}
65
66#[tokio::main]
67pub(crate) async fn start_server() -> Result<bool> {
68    let livereload = LiveReloadLayer::new();
69    let state = create_state(&livereload);
70    let server = create_server(livereload, state)?;
71    server.await?;
72    Ok(true)
73}
74
75pub(crate) fn post_deeplink(
76    platform: &str,
77    deeplink: &str,
78    experiments: Option<&Value>,
79) -> Result<bool> {
80    let payload = StartAppPostPayload::new(platform, deeplink, experiments);
81    let addr = get_address()?;
82    let _ret = post_payload(&payload, &addr.to_string())?;
83    Ok(true)
84}
85
86type Db = Arc<RwLock<InMemoryDb>>;
87
88pub(crate) fn get_address() -> Result<SocketAddr> {
89    let host = config::server_host();
90    let port = config::server_port();
91
92    let port = port
93        .parse::<u16>()
94        .map_err(|_| anyhow!("NIMBUS_CLI_SERVER_PORT must be numeric"))?;
95    let host = host
96        .parse::<IpAddr>()
97        .map_err(|_| anyhow!("NIMBUS_CLI_SERVER_HOST must be an IP address"))?;
98
99    Ok((host, port).into())
100}
101
102async fn index(State(db): State<Db>) -> Html<String> {
103    let mut html =
104        include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/assets/index.html")).to_string();
105    let li_template = include_str!(concat!(
106        env!("CARGO_MANIFEST_DIR"),
107        "/assets/li-template.html"
108    ));
109
110    let state = db.write().unwrap();
111    for p in ["android", "ios", "web"] {
112        let ppat = format!("{{{p}}}");
113        match state.url(p) {
114            Some(url) => {
115                let li = li_template.replace("{platform}", p).replace("{url}", url);
116                html = html.replace(&ppat, &li);
117            }
118            _ => {
119                html = html.replace(&ppat, "");
120            }
121        }
122    }
123
124    Html(html)
125}
126
127async fn style(State(_): State<Db>) -> &'static str {
128    include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/assets/style.css"))
129}
130
131async fn script(State(_): State<Db>) -> &'static str {
132    include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/assets/script.js"))
133}
134
135async fn rs(
136    State(db): State<Db>,
137    Path((_bucket, _collection)): Path<(String, String)>,
138) -> impl IntoResponse {
139    let state = db.write().unwrap();
140
141    let latest = state.latest();
142    if let Some(latest) = latest {
143        if let Some(e) = &latest.experiments {
144            (StatusCode::OK, Json(e.clone()))
145        } else {
146            // The server's latest content has no experiments; e.g.
147            // nimbus-cli open --pbpaste
148            (StatusCode::NOT_MODIFIED, Json(Value::Null))
149        }
150    } else {
151        // The server is up and running, but the first invocation of a --pbpaste
152        // has not come in yet.
153        (StatusCode::SERVICE_UNAVAILABLE, Json(Value::Null))
154    }
155}
156
157async fn post_handler(
158    State(db): State<Db>,
159    Json(payload): Json<StartAppPostPayload>,
160) -> impl IntoResponse {
161    eprintln!("Updating {platform} URL", platform = payload.platform);
162    let mut state = db.write().unwrap();
163    state.update(payload);
164    // This will be converted into a JSON response
165    // with a status code of `201 Created`
166    (StatusCode::CREATED, Json(()))
167}
168
169#[derive(Deserialize, Serialize)]
170struct StartAppPostPayload {
171    platform: String,
172    url: String,
173    experiments: Option<Value>,
174}
175
176impl StartAppPostPayload {
177    fn new(platform: &str, url: &str, experiments: Option<&Value>) -> Self {
178        Self {
179            platform: platform.to_string(),
180            url: url.to_string(),
181            experiments: experiments.cloned(),
182        }
183    }
184}
185
186fn post_payload<T: Serialize>(payload: &T, addr: &str) -> Result<String> {
187    let url = format!("http://{addr}/post");
188    let body = serde_json::to_string(payload)?;
189    let req = reqwest::blocking::Client::new()
190        .post(url)
191        .header("Content-type", "application/json; charset=UTF-8")
192        .header("accept", "application/json")
193        .body(body);
194    let resp = req.send()?;
195
196    Ok(resp.text()?)
197}
198
199struct InMemoryDb {
200    reloader: Reloader,
201    payloads: HashMap<String, StartAppPostPayload>,
202    latest: Option<String>,
203}
204
205impl InMemoryDb {
206    fn new(reloader: Reloader) -> Self {
207        Self {
208            reloader,
209            payloads: Default::default(),
210            latest: None,
211        }
212    }
213
214    fn url(&self, platform: &str) -> Option<&str> {
215        Some(self.payloads.get(platform)?.url.as_str())
216    }
217
218    fn update(&mut self, payload: StartAppPostPayload) {
219        self.latest = Some(payload.platform.clone());
220        self.payloads.insert(payload.platform.clone(), payload);
221        self.reloader.reload();
222    }
223
224    fn latest(&self) -> Option<&StartAppPostPayload> {
225        let key = self.latest.as_ref()?;
226        self.payloads.get(key)
227    }
228}
229
230type Srhl = SetResponseHeaderLayer<http::HeaderValue>;
231
232fn no_cache_layer() -> Stack<Srhl, Stack<Srhl, Srhl>> {
233    Stack::new(
234        SetResponseHeaderLayer::overriding(
235            http::header::CACHE_CONTROL,
236            http::HeaderValue::from_static("no-cache, no-store, must-revalidate"),
237        ),
238        Stack::new(
239            SetResponseHeaderLayer::overriding(
240                http::header::PRAGMA,
241                http::HeaderValue::from_static("no-cache"),
242            ),
243            SetResponseHeaderLayer::overriding(
244                http::header::EXPIRES,
245                http::HeaderValue::from_static("0"),
246            ),
247        ),
248    )
249}
250
#[cfg(test)]
mod tests {
    use hyper::{Body, Method, Request, Response};
    use serde_json::json;
    use std::net::TcpListener;
    use tokio::sync::oneshot::Sender;

    use super::*;

    /// Spawns the app on an OS-assigned loopback port.
    ///
    /// Returns three things:
    /// - the shared state, so tests can inspect it;
    /// - a sender that triggers graceful shutdown;
    /// - the port actually bound.
    fn start_test_server() -> Result<(Db, Sender<()>, u16)> {
        let livereload = LiveReloadLayer::new();
        let state = create_state(&livereload);

        let app = create_app(livereload, state.clone());
        // Port 0 lets the OS pick a free port, so the tests neither collide with each
        // other nor fail when an unrelated service already owns a hard-coded port.
        let listener = TcpListener::bind("127.0.0.1:0")?;
        let port = listener.local_addr()?.port();
        let (tx, rx) = tokio::sync::oneshot::channel::<()>();
        tokio::spawn(async move {
            Server::from_tcp(listener)
                .unwrap()
                .serve(app.into_make_service())
                .with_graceful_shutdown(async {
                    rx.await.ok();
                })
                .await
                .unwrap();
        });

        Ok((state, tx, port))
    }

    /// Sends a GET for `endpoint` to the test server and returns the body as a string.
    async fn get(port: u16, endpoint: &str) -> Result<String> {
        let url = format!("http://127.0.0.1:{port}{endpoint}");

        let client = hyper::Client::new();
        let response = client
            .request(Request::builder().uri(url).body(Body::empty())?)
            .await?;

        let body = hyper::body::to_bytes(response.into_body()).await?;
        Ok(std::str::from_utf8(&body)?.to_string())
    }

    /// POSTs `payload` as JSON to `http://{addr}/post` and returns the raw response.
    async fn post_payload<T: Serialize>(payload: &T, addr: &str) -> Result<Response<Body>> {
        let url = format!("http://{addr}/post");
        let body = serde_json::to_string(payload)?;
        let request = Request::builder()
            .method(Method::POST)
            .uri(url)
            .header("accept", "application/json")
            .header("Content-type", "application/json; charset=UTF-8")
            .body(Body::from(body))?;
        let client = hyper::Client::new();
        Ok(client.request(request).await?)
    }

    /// POSTs `payload` to the test server on `port` and asserts it was accepted.
    async fn post_to(port: u16, payload: &StartAppPostPayload) -> Result<()> {
        let response = post_payload(payload, &format!("127.0.0.1:{port}")).await?;
        assert_eq!(response.status(), StatusCode::CREATED);
        Ok(())
    }

    #[tokio::test]
    async fn test_smoke_test() -> Result<()> {
        let (_db, tx, port) = start_test_server()?;

        let s = get(port, "/").await?;
        assert!(s.contains("<html>"));

        let _ = tx.send(());
        Ok(())
    }

    #[tokio::test]
    async fn test_posting_platform_url() -> Result<()> {
        let (db, tx, port) = start_test_server()?;

        let platform = "android";
        let deeplink = "fenix-dev-test://open-now";

        let payload = StartAppPostPayload::new(platform, deeplink, None);
        post_to(port, &payload).await?;

        // Check the internal state; scoped so the lock is released before shutdown.
        {
            let state = db.read().unwrap();
            assert_eq!(state.url(platform), Some(deeplink));
        }

        let _ = tx.send(());
        Ok(())
    }

    #[tokio::test]
    async fn test_posting_platform_url_from_index_page() -> Result<()> {
        let (_, tx, port) = start_test_server()?;

        let platform = "android";
        let deeplink = "fenix-dev-test://open-now";

        let payload = StartAppPostPayload::new(platform, deeplink, None);
        post_to(port, &payload).await?;

        // Check the index.html page
        let s = get(port, "/").await?;
        assert!(s.contains(deeplink));

        let _ = tx.send(());
        Ok(())
    }

    #[tokio::test]
    async fn test_posting_value_to_fake_remote_settings() -> Result<()> {
        let (_, tx, port) = start_test_server()?;

        let platform = "android";
        let deeplink = "fenix-dev-test://open-now";
        let value = json!({
            "int": 1,
            "boolean": true,
            "object": {},
            "array": [],
            "null": null,
        });
        let payload = StartAppPostPayload::new(platform, deeplink, Some(&value));
        post_to(port, &payload).await?;

        // Check the fake Remote Settings page
        let s = get(port, "/v1/buckets/BUCKET/collections/COLLECTION/records").await?;
        assert_eq!(s, serde_json::to_string(&value)?);

        let s = get(port, "/buckets/BUCKET/collections/COLLECTION/records").await?;
        assert_eq!(s, serde_json::to_string(&value)?);

        let _ = tx.send(());
        Ok(())
    }

    #[tokio::test]
    async fn test_getting_null_values_from_fake_remote_settings() -> Result<()> {
        let (_, tx, port) = start_test_server()?;

        // Part 1: get from remote settings page before anything has been posted yet.
        let s = get(port, "/v1/buckets/BUCKET/collections/COLLECTION/records").await?;
        assert_eq!(s, "null".to_string());

        // Part 2: Post a payload, but not with any experiments.
        let platform = "android";
        let deeplink = "fenix-dev-test://open-now";

        let payload = StartAppPostPayload::new(platform, deeplink, None);
        post_to(port, &payload).await?;

        // Check the fake Remote Settings page, should be empty, since an experiments payload
        // wasn't posted
        let s = get(port, "/v1/buckets/BUCKET/collections/COLLECTION/records").await?;
        assert_eq!(s, "".to_string());

        let _ = tx.send(());
        Ok(())
    }
}