1use crate::{
7 error::{
8 check_http_status_for_error, CallbackRequestError, FetchAdsError, RecordClickError,
9 RecordImpressionError, ReportAdError,
10 },
11 models::{AdRequest, AdResponse},
12};
13use url::Url;
14use uuid::Uuid;
15use viaduct::Request;
16
/// Base URL that `DefaultMARSClient::new` stores as the client's endpoint.
const DEFAULT_MARS_API_ENDPOINT: &str = "https://ads.mozilla.org/v1";
18
19#[cfg_attr(test, mockall::automock)]
20pub trait MARSClient: Sync + Send {
21 fn fetch_ads(&self, request: &AdRequest) -> Result<AdResponse, FetchAdsError>;
22 fn record_impression(
23 &self,
24 url_callback_string: Option<String>,
25 ) -> Result<(), RecordImpressionError>;
26 fn record_click(&self, url_callback_string: Option<String>) -> Result<(), RecordClickError>;
27 fn report_ad(&self, url_callback_string: Option<String>) -> Result<(), ReportAdError>;
28 fn get_context_id(&self) -> &str;
29 fn cycle_context_id(&mut self) -> String;
30 fn get_mars_endpoint(&self) -> &str;
31}
32
/// HTTP-backed client state: the current context id plus the API endpoint to call.
pub struct DefaultMARSClient {
    /// Context id currently in use; replaced by `cycle_context_id`.
    context_id: String,
    /// Base URL that ad requests are sent to (`{endpoint}/ads`).
    endpoint: String,
}
37
38impl DefaultMARSClient {
39 pub fn new(context_id: String) -> Self {
40 Self {
41 context_id,
42 endpoint: DEFAULT_MARS_API_ENDPOINT.to_string(),
43 }
44 }
45
46 #[cfg(test)]
47 pub fn new_with_endpoint(context_id: String, endpoint: String) -> Self {
48 Self {
49 context_id,
50 endpoint,
51 }
52 }
53
54 fn make_callback_request(&self, url_callback_string: &str) -> Result<(), CallbackRequestError> {
55 let url = Url::parse(url_callback_string)?;
56 let request = Request::get(url);
57 let response = request.send()?;
58 check_http_status_for_error(&response).map_err(Into::into)
59 }
60}
61
62impl MARSClient for DefaultMARSClient {
63 fn get_context_id(&self) -> &str {
64 &self.context_id
65 }
66
67 fn get_mars_endpoint(&self) -> &str {
68 &self.endpoint
69 }
70
71 fn cycle_context_id(&mut self) -> String {
74 let old_context_id = self.context_id.clone();
75 self.context_id = Uuid::new_v4().to_string();
76 old_context_id
77 }
78
79 fn fetch_ads(&self, ad_request: &AdRequest) -> Result<AdResponse, FetchAdsError> {
80 let endpoint = self.get_mars_endpoint();
81 let url = Url::parse(&format!("{endpoint}/ads"))?;
82 let request = Request::post(url).json(ad_request);
83 let response = request.send()?;
84
85 check_http_status_for_error(&response)?;
86
87 let response_json: AdResponse = response.json()?;
88
89 Ok(response_json)
90 }
91
92 fn record_impression(
93 &self,
94 url_callback_string: Option<String>,
95 ) -> Result<(), RecordImpressionError> {
96 match url_callback_string {
97 Some(callback) => self.make_callback_request(&callback).map_err(Into::into),
98 None => Err(CallbackRequestError::MissingCallback {
99 message: "Impression callback url empty.".to_string(),
100 }
101 .into()),
102 }
103 }
104
105 fn record_click(&self, url_callback_string: Option<String>) -> Result<(), RecordClickError> {
106 match url_callback_string {
107 Some(callback) => self.make_callback_request(&callback).map_err(Into::into),
108 None => Err(CallbackRequestError::MissingCallback {
109 message: "Click callback url empty.".to_string(),
110 }
111 .into()),
112 }
113 }
114
115 fn report_ad(&self, url_callback_string: Option<String>) -> Result<(), ReportAdError> {
116 match url_callback_string {
117 Some(callback) => self.make_callback_request(&callback).map_err(Into::into),
118 None => Err(CallbackRequestError::MissingCallback {
119 message: "Report callback url empty.".to_string(),
120 }
121 .into()),
122 }
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        models::AdPlacementRequest,
        test_utils::{create_test_client, get_example_happy_ad_response, TEST_CONTEXT_ID},
    };
    use mockito::mock;

    /// Builds an absolute URL on the mockito server for the given `path` (leading `/` included).
    fn mock_server_url_for(path: &str) -> String {
        format!("{}{}", mockito::server_url(), path)
    }

    /// A test client reports the shared test context id.
    #[test]
    fn test_get_context_id() {
        let client = create_test_client(mockito::server_url());
        assert_eq!(client.get_context_id(), TEST_CONTEXT_ID);
    }

    /// Cycling returns the previous id and leaves a different one in place.
    #[test]
    fn test_cycle_context_id() {
        let mut client = create_test_client(mockito::server_url());
        let previous_id = client.cycle_context_id();
        assert_eq!(previous_id, TEST_CONTEXT_ID);
        assert_ne!(client.get_context_id(), TEST_CONTEXT_ID);
    }

    /// A missing impression callback is an error.
    #[test]
    fn test_record_impression_with_empty_callback_should_fail() {
        let client = create_test_client(mockito::server_url());
        assert!(client.record_impression(None).is_err());
    }

    /// A missing click callback is an error.
    #[test]
    fn test_record_click_with_empty_callback_should_fail() {
        let client = create_test_client(mockito::server_url());
        assert!(client.record_click(None).is_err());
    }

    /// A missing report callback is an error.
    #[test]
    fn test_record_report_with_empty_callback_should_fail() {
        let client = create_test_client(mockito::server_url());
        assert!(client.report_ad(None).is_err());
    }

    /// An impression callback answered with 200 succeeds.
    #[test]
    fn test_record_impression_with_valid_url_should_succeed() {
        viaduct_reqwest::use_reqwest_backend();
        let _m = mock("GET", "/impression_callback_url")
            .with_status(200)
            .create();
        let client = create_test_client(mockito::server_url());
        let callback = mock_server_url_for("/impression_callback_url");
        assert!(client.record_impression(Some(callback)).is_ok());
    }

    /// A click callback answered with 200 succeeds.
    #[test]
    fn test_record_click_with_valid_url_should_succeed() {
        viaduct_reqwest::use_reqwest_backend();
        let _m = mock("GET", "/click_callback_url").with_status(200).create();

        let client = create_test_client(mockito::server_url());
        let callback = mock_server_url_for("/click_callback_url");
        assert!(client.record_click(Some(callback)).is_ok());
    }

    /// A report callback answered with 200 succeeds.
    #[test]
    fn test_report_ad_with_valid_url_should_succeed() {
        viaduct_reqwest::use_reqwest_backend();
        let _m = mock("GET", "/report_ad_callback_url")
            .with_status(200)
            .create();

        let client = create_test_client(mockito::server_url());
        let callback = mock_server_url_for("/report_ad_callback_url");
        assert!(client.report_ad(Some(callback)).is_ok());
    }

    /// A JSON POST to `/ads` returns the body served by the mock, decoded as an `AdResponse`.
    #[test]
    fn test_fetch_ads_success() {
        viaduct_reqwest::use_reqwest_backend();
        let expected_response = get_example_happy_ad_response();
        let response_body = serde_json::to_string(&expected_response).unwrap();

        let _m = mock("POST", "/ads")
            .match_header("content-type", "application/json")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(response_body)
            .create();

        let client = create_test_client(mockito::server_url());

        // Both placements differ only by name.
        let placement = |name: &str| AdPlacementRequest {
            placement: name.to_string(),
            count: 1,
            content: None,
        };
        let ad_request = AdRequest {
            context_id: client.get_context_id().to_string(),
            placements: vec![
                placement("example_placement_1"),
                placement("example_placement_2"),
            ],
        };

        let fetched = client.fetch_ads(&ad_request);
        assert!(fetched.is_ok());
        assert_eq!(expected_response, fetched.unwrap());
    }
}