nimbus_cli/output/server.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at https://mozilla.org/MPL/2.0/.
4
5use anyhow::Result;
6use axum::http::StatusCode;
7use serde::{Deserialize, Serialize};
8use std::sync::RwLock;
9use std::{
10    collections::HashMap,
11    net::{IpAddr, SocketAddr},
12    sync::Arc,
13};
14
15use crate::config;
16
17use anyhow::anyhow;
18use axum::{
19    extract::{Path, State},
20    http,
21    response::{Html, IntoResponse},
22    routing::{get, post, IntoMakeService},
23    Json, Router, Server,
24};
25use hyper::server::conn::AddrIncoming;
26use serde_json::Value;
27use tower::layer::util::Stack;
28use tower_http::set_header::SetResponseHeaderLayer;
29use tower_livereload::{LiveReloadLayer, Reloader};
30
31fn create_server(
32    livereload: LiveReloadLayer,
33    state: Db,
34) -> Result<Server<AddrIncoming, IntoMakeService<Router>>, anyhow::Error> {
35    let app = create_app(livereload, state);
36
37    let addr = get_address()?;
38    eprintln!("Copy the address http://{}/ into your mobile browser", addr);
39
40    let server = Server::try_bind(&addr)?.serve(app.into_make_service());
41
42    Ok(server)
43}
44
45fn create_app(livereload: LiveReloadLayer, state: Db) -> Router {
46    Router::new()
47        .route("/", get(index))
48        .route("/style.css", get(style))
49        .route("/script.js", get(script))
50        .route("/post", post(post_handler))
51        .route("/buckets/:bucket/collections/:collection/records", get(rs))
52        .route(
53            "/v1/buckets/:bucket/collections/:collection/records",
54            get(rs),
55        )
56        .layer(livereload)
57        .layer(no_cache_layer())
58        .with_state(state)
59}
60
61fn create_state(livereload: &LiveReloadLayer) -> Db {
62    let reloader = livereload.reloader();
63    Arc::new(RwLock::new(InMemoryDb::new(reloader)))
64}
65
66#[tokio::main]
67pub(crate) async fn start_server() -> Result<bool> {
68    let livereload = LiveReloadLayer::new();
69    let state = create_state(&livereload);
70    let server = create_server(livereload, state)?;
71    server.await?;
72    Ok(true)
73}
74
75pub(crate) fn post_deeplink(
76    platform: &str,
77    deeplink: &str,
78    experiments: Option<&Value>,
79) -> Result<bool> {
80    let payload = StartAppPostPayload::new(platform, deeplink, experiments);
81    let addr = get_address()?;
82    let _ret = post_payload(&payload, &addr.to_string())?;
83    Ok(true)
84}
85
86type Db = Arc<RwLock<InMemoryDb>>;
87
88pub(crate) fn get_address() -> Result<SocketAddr> {
89    let host = config::server_host();
90    let port = config::server_port();
91
92    let port = port
93        .parse::<u16>()
94        .map_err(|_| anyhow!("NIMBUS_CLI_SERVER_PORT must be numeric"))?;
95    let host = host
96        .parse::<IpAddr>()
97        .map_err(|_| anyhow!("NIMBUS_CLI_SERVER_HOST must be an IP address"))?;
98
99    Ok((host, port).into())
100}
101
102async fn index(State(db): State<Db>) -> Html<String> {
103    let mut html =
104        include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/assets/index.html")).to_string();
105    let li_template = include_str!(concat!(
106        env!("CARGO_MANIFEST_DIR"),
107        "/assets/li-template.html"
108    ));
109
110    let state = db.write().unwrap();
111    for p in ["android", "ios", "web"] {
112        let ppat = format!("{{{p}}}");
113        match state.url(p) {
114            Some(url) => {
115                let li = li_template.replace("{platform}", p).replace("{url}", url);
116                html = html.replace(&ppat, &li);
117            }
118            _ => {
119                html = html.replace(&ppat, "");
120            }
121        }
122    }
123
124    Html(html)
125}
126
127async fn style(State(_): State<Db>) -> &'static str {
128    include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/assets/style.css"))
129}
130
131async fn script(State(_): State<Db>) -> &'static str {
132    include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/assets/script.js"))
133}
134
135async fn rs(
136    State(db): State<Db>,
137    Path((_bucket, _collection)): Path<(String, String)>,
138) -> impl IntoResponse {
139    let state = db.write().unwrap();
140
141    let latest = state.latest();
142    if let Some(latest) = latest {
143        if let Some(e) = &latest.experiments {
144            (StatusCode::OK, Json(e.clone()))
145        } else {
146            // The server's latest content has no experiments; e.g.
147            // nimbus-cli open --pbpaste
148            (StatusCode::NOT_MODIFIED, Json(Value::Null))
149        }
150    } else {
151        // The server is up and running, but the first invocation of a --pbpaste
152        // has not come in yet.
153        (StatusCode::SERVICE_UNAVAILABLE, Json(Value::Null))
154    }
155}
156
157async fn post_handler(
158    State(db): State<Db>,
159    Json(payload): Json<StartAppPostPayload>,
160) -> impl IntoResponse {
161    eprintln!("Updating {platform} URL", platform = payload.platform);
162    let mut state = db.write().unwrap();
163    state.update(payload);
164    // This will be converted into a JSON response
165    // with a status code of `201 Created`
166    (StatusCode::CREATED, Json(()))
167}
168
169#[derive(Deserialize, Serialize)]
170struct StartAppPostPayload {
171    platform: String,
172    url: String,
173    experiments: Option<Value>,
174}
175
176impl StartAppPostPayload {
177    fn new(platform: &str, url: &str, experiments: Option<&Value>) -> Self {
178        Self {
179            platform: platform.to_string(),
180            url: url.to_string(),
181            experiments: experiments.cloned(),
182        }
183    }
184}
185
186fn post_payload<T: Serialize>(payload: &T, addr: &str) -> Result<String> {
187    let url = format!("http://{addr}/post");
188    let body = serde_json::to_string(payload)?;
189    let req = viaduct::Request::post(viaduct::parse_url(&url)?)
190        .header("Content-type", "application/json; charset=UTF-8")?
191        .header("accept", "application/json")?
192        .body(body);
193    let resp = req.send()?;
194
195    Ok(resp.text().to_string())
196}
197
198struct InMemoryDb {
199    reloader: Reloader,
200    payloads: HashMap<String, StartAppPostPayload>,
201    latest: Option<String>,
202}
203
204impl InMemoryDb {
205    fn new(reloader: Reloader) -> Self {
206        Self {
207            reloader,
208            payloads: Default::default(),
209            latest: None,
210        }
211    }
212
213    fn url(&self, platform: &str) -> Option<&str> {
214        Some(self.payloads.get(platform)?.url.as_str())
215    }
216
217    fn update(&mut self, payload: StartAppPostPayload) {
218        self.latest = Some(payload.platform.clone());
219        self.payloads.insert(payload.platform.clone(), payload);
220        self.reloader.reload();
221    }
222
223    fn latest(&self) -> Option<&StartAppPostPayload> {
224        let key = self.latest.as_ref()?;
225        self.payloads.get(key)
226    }
227}
228
229type Srhl = SetResponseHeaderLayer<http::HeaderValue>;
230
231fn no_cache_layer() -> Stack<Srhl, Stack<Srhl, Srhl>> {
232    Stack::new(
233        SetResponseHeaderLayer::overriding(
234            http::header::CACHE_CONTROL,
235            http::HeaderValue::from_static("no-cache, no-store, must-revalidate"),
236        ),
237        Stack::new(
238            SetResponseHeaderLayer::overriding(
239                http::header::PRAGMA,
240                http::HeaderValue::from_static("no-cache"),
241            ),
242            SetResponseHeaderLayer::overriding(
243                http::header::EXPIRES,
244                http::HeaderValue::from_static("0"),
245            ),
246        ),
247    )
248}
249
#[cfg(test)]
mod tests {
    use hyper::{Body, Method, Request, Response};
    use serde_json::json;
    use std::net::TcpListener;
    use tokio::sync::oneshot::Sender;

    use super::*;

    /// Start a server for one test on an OS-assigned loopback port.
    ///
    /// Returns three values:
    /// - the server's shared state;
    /// - a sender that triggers graceful shutdown;
    /// - the port the server listens on.
    ///
    /// Binding port 0 lets the OS pick a free port, so tests do not fail when a
    /// hard-coded port is already in use.
    fn start_test_server() -> Result<(Db, Sender<()>, u16)> {
        let livereload = LiveReloadLayer::new();
        let state = create_state(&livereload);

        let app = create_app(livereload, state.clone());
        let listener = TcpListener::bind("127.0.0.1:0")?;
        let port = listener.local_addr()?.port();
        let (tx, rx) = tokio::sync::oneshot::channel::<()>();
        tokio::spawn(async move {
            Server::from_tcp(listener)
                .unwrap()
                .serve(app.into_make_service())
                .with_graceful_shutdown(async {
                    rx.await.ok();
                })
                .await
                .unwrap();
        });

        Ok((state, tx, port))
    }

    /// GET `endpoint` from the test server on `port` and return the body as a string.
    async fn get(port: u16, endpoint: &str) -> Result<String> {
        let url = format!("http://127.0.0.1:{port}{endpoint}");

        let client = hyper::Client::new();
        let response = client
            .request(Request::builder().uri(url).body(Body::empty()).unwrap())
            .await
            .unwrap();

        let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
        let s = std::str::from_utf8(&body)?;

        Ok(s.to_string())
    }

    /// POST `payload` as JSON to the `/post` endpoint of the test server on `port`.
    async fn post_payload<T: Serialize>(payload: &T, port: u16) -> Result<Response<Body>> {
        let url = format!("http://127.0.0.1:{port}/post");
        let body = serde_json::to_string(payload)?;
        let request = Request::builder()
            .method(Method::POST)
            .uri(url)
            .header("accept", "application/json")
            .header("Content-type", "application/json; charset=UTF-8")
            .body(Body::from(body))
            .unwrap();
        let client = hyper::Client::new();
        Ok(client.request(request).await?)
    }

    #[tokio::test]
    async fn test_smoke_test() -> Result<()> {
        let (_db, tx, port) = start_test_server()?;

        let s = get(port, "/").await?;
        assert!(s.contains("<html>"));

        let _ = tx.send(());
        Ok(())
    }

    #[tokio::test]
    async fn test_posting_platform_url() -> Result<()> {
        let (db, tx, port) = start_test_server()?;

        let platform = "android";
        let deeplink = "fenix-dev-test://open-now";

        let payload = StartAppPostPayload::new(platform, deeplink, None);
        let response = post_payload(&payload, port).await?;
        assert_eq!(response.status(), StatusCode::CREATED);

        // Check the internal state; the read guard is dropped at the end of the block.
        {
            let state = db.read().unwrap();
            assert_eq!(state.url(platform), Some(deeplink));
        }

        let _ = tx.send(());
        Ok(())
    }

    #[tokio::test]
    async fn test_posting_platform_url_from_index_page() -> Result<()> {
        let (_, tx, port) = start_test_server()?;

        let platform = "android";
        let deeplink = "fenix-dev-test://open-now";

        let payload = StartAppPostPayload::new(platform, deeplink, None);
        let _ = post_payload(&payload, port).await?;

        // Check the index.html page
        let s = get(port, "/").await?;
        assert!(s.contains(deeplink));

        let _ = tx.send(());
        Ok(())
    }

    #[tokio::test]
    async fn test_posting_value_to_fake_remote_settings() -> Result<()> {
        let (_, tx, port) = start_test_server()?;

        let platform = "android";
        let deeplink = "fenix-dev-test://open-now";
        let value = json!({
            "int": 1,
            "boolean": true,
            "object": {},
            "array": [],
            "null": null,
        });
        let payload = StartAppPostPayload::new(platform, deeplink, Some(&value));
        let _ = post_payload(&payload, port).await?;

        // Check the fake Remote Settings page
        let s = get(port, "/v1/buckets/BUCKET/collections/COLLECTION/records").await?;
        assert_eq!(s, serde_json::to_string(&value)?);

        let s = get(port, "/buckets/BUCKET/collections/COLLECTION/records").await?;
        assert_eq!(s, serde_json::to_string(&value)?);

        let _ = tx.send(());
        Ok(())
    }

    #[tokio::test]
    async fn test_getting_null_values_from_fake_remote_settings() -> Result<()> {
        let (_, tx, port) = start_test_server()?;

        // Part 1: get from remote settings page before anything has been posted yet.
        let s = get(port, "/v1/buckets/BUCKET/collections/COLLECTION/records").await?;
        assert_eq!(s, "null".to_string());

        // Part 2: Post a payload, but not with any experiments.
        let platform = "android";
        let deeplink = "fenix-dev-test://open-now";

        let payload = StartAppPostPayload::new(platform, deeplink, None);
        let _ = post_payload(&payload, port).await?;

        // Check the fake Remote Settings page. It should be empty, since no experiments
        // payload was posted.
        let s = get(port, "/v1/buckets/BUCKET/collections/COLLECTION/records").await?;
        assert_eq!(s, "".to_string());

        let _ = tx.send(());
        Ok(())
    }
}