ads_client/
mars.rs

/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 */
5
6use crate::{
7    error::{
8        check_http_status_for_error, CallbackRequestError, FetchAdsError, RecordClickError,
9        RecordImpressionError, ReportAdError,
10    },
11    models::{AdRequest, AdResponse},
12};
13use url::Url;
14use uuid::Uuid;
15use viaduct::Request;
16
/// Base URL of the MARS ads API used by `DefaultMARSClient::new`.
///
/// Request paths (e.g. `/ads`) are appended directly to this value, so it carries no
/// trailing slash.
const DEFAULT_MARS_API_ENDPOINT: &str = "https://ads.mozilla.org/v1";
18
19#[cfg_attr(test, mockall::automock)]
20pub trait MARSClient: Sync + Send {
21    fn fetch_ads(&self, request: &AdRequest) -> Result<AdResponse, FetchAdsError>;
22    fn record_impression(
23        &self,
24        url_callback_string: Option<String>,
25    ) -> Result<(), RecordImpressionError>;
26    fn record_click(&self, url_callback_string: Option<String>) -> Result<(), RecordClickError>;
27    fn report_ad(&self, url_callback_string: Option<String>) -> Result<(), ReportAdError>;
28    fn get_context_id(&self) -> &str;
29    fn cycle_context_id(&mut self) -> String;
30    fn get_mars_endpoint(&self) -> &str;
31}
32
/// Default `MARSClient` implementation backed by HTTP requests to the MARS API.
pub struct DefaultMARSClient {
    /// Base URL of the MARS API; `fetch_ads` appends `/ads` to it.
    endpoint: String,
    /// Context id exposed via `get_context_id` and replaced by `cycle_context_id`.
    context_id: String,
}
37
38impl DefaultMARSClient {
39    pub fn new(context_id: String) -> Self {
40        Self {
41            context_id,
42            endpoint: DEFAULT_MARS_API_ENDPOINT.to_string(),
43        }
44    }
45
46    #[cfg(test)]
47    pub fn new_with_endpoint(context_id: String, endpoint: String) -> Self {
48        Self {
49            context_id,
50            endpoint,
51        }
52    }
53
54    fn make_callback_request(&self, url_callback_string: &str) -> Result<(), CallbackRequestError> {
55        let url = Url::parse(url_callback_string)?;
56        let request = Request::get(url);
57        let response = request.send()?;
58        check_http_status_for_error(&response).map_err(Into::into)
59    }
60}
61
62impl MARSClient for DefaultMARSClient {
63    fn get_context_id(&self) -> &str {
64        &self.context_id
65    }
66
67    fn get_mars_endpoint(&self) -> &str {
68        &self.endpoint
69    }
70
71    /// Updates the client's context_id to the passed value and returns the previous context_id
72    /// TODO: Context_id functions should swap over to use the proper context_id component
73    fn cycle_context_id(&mut self) -> String {
74        let old_context_id = self.context_id.clone();
75        self.context_id = Uuid::new_v4().to_string();
76        old_context_id
77    }
78
79    fn fetch_ads(&self, ad_request: &AdRequest) -> Result<AdResponse, FetchAdsError> {
80        let endpoint = self.get_mars_endpoint();
81        let url = Url::parse(&format!("{endpoint}/ads"))?;
82        let request = Request::post(url).json(ad_request);
83        let response = request.send()?;
84
85        check_http_status_for_error(&response)?;
86
87        let response_json: AdResponse = response.json()?;
88
89        Ok(response_json)
90    }
91
92    fn record_impression(
93        &self,
94        url_callback_string: Option<String>,
95    ) -> Result<(), RecordImpressionError> {
96        match url_callback_string {
97            Some(callback) => self.make_callback_request(&callback).map_err(Into::into),
98            None => Err(CallbackRequestError::MissingCallback {
99                message: "Impression callback url empty.".to_string(),
100            }
101            .into()),
102        }
103    }
104
105    fn record_click(&self, url_callback_string: Option<String>) -> Result<(), RecordClickError> {
106        match url_callback_string {
107            Some(callback) => self.make_callback_request(&callback).map_err(Into::into),
108            None => Err(CallbackRequestError::MissingCallback {
109                message: "Click callback url empty.".to_string(),
110            }
111            .into()),
112        }
113    }
114
115    fn report_ad(&self, url_callback_string: Option<String>) -> Result<(), ReportAdError> {
116        match url_callback_string {
117            Some(callback) => self.make_callback_request(&callback).map_err(Into::into),
118            None => Err(CallbackRequestError::MissingCallback {
119                message: "Report callback url empty.".to_string(),
120            }
121            .into()),
122        }
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        models::AdPlacementRequest,
        test_utils::{create_test_client, get_example_happy_ad_response, TEST_CONTEXT_ID},
    };
    use mockito::mock;

    /// Absolute URL of `path` on the shared mockito server.
    fn mock_url(path: &str) -> String {
        format!("{}{}", mockito::server_url(), path)
    }

    #[test]
    fn test_get_context_id() {
        let client = create_test_client(mockito::server_url());
        let expected = TEST_CONTEXT_ID.to_string();
        assert_eq!(client.get_context_id(), expected);
    }

    #[test]
    fn test_cycle_context_id() {
        let mut client = create_test_client(mockito::server_url());
        let previous = client.cycle_context_id();
        assert_eq!(previous, TEST_CONTEXT_ID);
        assert_ne!(client.get_context_id(), TEST_CONTEXT_ID);
    }

    #[test]
    fn test_record_impression_with_empty_callback_should_fail() {
        let client = create_test_client(mockito::server_url());
        assert!(client.record_impression(None).is_err());
    }

    #[test]
    fn test_record_click_with_empty_callback_should_fail() {
        let client = create_test_client(mockito::server_url());
        assert!(client.record_click(None).is_err());
    }

    #[test]
    fn test_record_report_with_empty_callback_should_fail() {
        let client = create_test_client(mockito::server_url());
        assert!(client.report_ad(None).is_err());
    }

    #[test]
    fn test_record_impression_with_valid_url_should_succeed() {
        viaduct_reqwest::use_reqwest_backend();
        let path = "/impression_callback_url";
        // Keep the mock bound for the whole test so the endpoint stays registered.
        let _mock = mock("GET", path).with_status(200).create();

        let client = create_test_client(mockito::server_url());
        assert!(client.record_impression(Some(mock_url(path))).is_ok());
    }

    #[test]
    fn test_record_click_with_valid_url_should_succeed() {
        viaduct_reqwest::use_reqwest_backend();
        let path = "/click_callback_url";
        let _mock = mock("GET", path).with_status(200).create();

        let client = create_test_client(mockito::server_url());
        assert!(client.record_click(Some(mock_url(path))).is_ok());
    }

    #[test]
    fn test_report_ad_with_valid_url_should_succeed() {
        viaduct_reqwest::use_reqwest_backend();
        let path = "/report_ad_callback_url";
        let _mock = mock("GET", path).with_status(200).create();

        let client = create_test_client(mockito::server_url());
        assert!(client.report_ad(Some(mock_url(path))).is_ok());
    }

    #[test]
    fn test_fetch_ads_success() {
        viaduct_reqwest::use_reqwest_backend();
        let expected_response = get_example_happy_ad_response();
        let response_body = serde_json::to_string(&expected_response).unwrap();

        let _mock = mock("POST", "/ads")
            .match_header("content-type", "application/json")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(response_body)
            .create();

        let client = create_test_client(mockito::server_url());

        let placements = ["example_placement_1", "example_placement_2"]
            .iter()
            .map(|&name| AdPlacementRequest {
                placement: name.to_string(),
                count: 1,
                content: None,
            })
            .collect();
        let ad_request = AdRequest {
            context_id: client.get_context_id().to_string(),
            placements,
        };

        let result = client.fetch_ads(&ad_request);
        assert!(result.is_ok());
        assert_eq!(expected_response, result.unwrap());
    }
}