diff --git a/components/ads-client/src/client.rs b/components/ads-client/src/client.rs index e9a2ab0c72..5c818e2a4c 100644 --- a/components/ads-client/src/client.rs +++ b/components/ads-client/src/client.rs @@ -130,6 +130,29 @@ where // if let Some(request_hash) = pop_request_hash_from_url(&mut impression_url) { // let _ = self.client.invalidate_cache_by_hash(&request_hash); // } + + // TODO: Add count call with _cap_key for impression capping logic + let impression_url = if let Some((_, _cap_key)) = impression_url + .query_pairs() + .find(|(key, _)| key == "cap_key") + { + let mut new_url = impression_url.clone(); + new_url + .query_pairs_mut() + .clear() + .extend_pairs( + impression_url + .query_pairs() + .collect::>() + .iter() + .filter(|(key, _)| key != "cap_key"), + ) + .finish(); + new_url + } else { + impression_url + }; + self.client .record_impression(impression_url, ohttp) .inspect_err(|e| { @@ -392,6 +415,33 @@ mod tests { m.assert(); } + #[test] + fn test_record_impression_removes_cap_key() { + viaduct_dev::init_backend_dev(); + let mars_client = MARSClient::new(Environment::Test, None, MozAdsTelemetryWrapper::noop()); + let ads_client = new_with_mars_client(mars_client); + + let base_url = mockito::server_url(); + let path_and_query = "/impression?kept=example"; + let callback_url = Url::parse(&format!("{}{}", base_url, path_and_query)).unwrap(); + + let mock = mockito::mock("GET", path_and_query) + .with_status(200) + .create(); + + ads_client.record_impression(callback_url, false).unwrap(); + + mock.assert(); + + let callback_url_with_cap_key = + Url::parse(&format!("{}{}&cap_key=test", base_url, path_and_query)).unwrap(); + ads_client + .record_impression(callback_url_with_cap_key, false) + .unwrap(); + + mock.expect(2).assert(); + } + #[test] #[ignore = "Cache invalidation temporarily disabled - will be re-enabled behind Nimbus experiment"] fn test_record_click_invalidates_cache() { diff --git a/components/ads-client/src/mars/ad_response.rs b/components/ads-client/src/mars/ad_response.rs index b9b4b20cd4..5e057b7606 100644 --- a/components/ads-client/src/mars/ad_response.rs +++ b/components/ads-client/src/mars/ad_response.rs @@ -47,15 +47,17 @@ impl AdResponse { let hash_str = request_hash.to_string(); for (placement_id, ads) in self.data.iter_mut() { for (position, ad) in ads.iter_mut().enumerate() { + let cap_key = ad.cap_key(); let callbacks = ad.callbacks_mut(); callbacks .click .query_pairs_mut() .append_pair("request_hash", &hash_str); - callbacks - .impression - .query_pairs_mut() - .append_pair("request_hash", &hash_str); + let mut impression_callback_query = callbacks.impression.query_pairs_mut(); + impression_callback_query.append_pair("request_hash", &hash_str); + if let Some(cap_key) = cap_key { + impression_callback_query.append_pair("cap_key", &cap_key); + } if let Some(report_url) = callbacks.report.as_mut() { report_url .query_pairs_mut() @@ -161,6 +163,9 @@ pub struct AdCallbacks { pub trait AdResponseValue: DeserializeOwned { fn callbacks_mut(&mut self) -> &mut AdCallbacks; + fn cap_key(&self) -> Option { + None + } } impl AdResponseValue for AdImage { @@ -173,6 +178,10 @@ impl AdResponseValue for AdSpoc { fn callbacks_mut(&mut self) -> &mut AdCallbacks { &mut self.callbacks } + + fn cap_key(&self) -> Option { + Some(self.caps.cap_key.clone()) + } } impl AdResponseValue for AdTile { @@ -514,6 +523,86 @@ mod tests { assert!(report_1.contains("position=1")); } + #[test] + fn test_enrich_callbacks_spoc_impressions_have_cap_key() { + let mut response = AdResponse { + data: HashMap::from([( + "tile1".into(), + vec![ + AdSpoc { + block_key: "block_key1".into(), + callbacks: AdCallbacks { + click: Url::parse("https://example.com/click1").unwrap(), + impression: Url::parse("https://example.com/impression1").unwrap(), + report: Some(Url::parse("https://example.com/report1").unwrap()), + }, + caps: SpocFrequencyCaps { + cap_key: "cap_key1".into(), + day: 100, + }, + domain: "1.example.com".into(), + excerpt: "excerpt1".into(), + format: "format1".into(), + image_url: "https://example.com/image1.png".into(), + url: "https://example.com/ad1".into(), + ranking: SpocRanking { + priority: 1, + personalization_models: None, + item_score: 1.0, + }, + sponsor: "sponsor1".into(), + sponsored_by_override: None, + title: "title1".into(), + }, + AdSpoc { + block_key: "block_key2".into(), + callbacks: AdCallbacks { + click: Url::parse("https://example.com/click2").unwrap(), + impression: Url::parse("https://example.com/impression2").unwrap(), + report: Some(Url::parse("https://example.com/report2").unwrap()), + }, + caps: SpocFrequencyCaps { + cap_key: "cap_key2".into(), + day: 200, + }, + domain: "2.example.com".into(), + excerpt: "excerpt2".into(), + format: "format2".into(), + image_url: "https://example.com/image2.png".into(), + url: "https://example.com/ad2".into(), + ranking: SpocRanking { + priority: 2, + personalization_models: None, + item_score: 2.0, + }, + sponsor: "sponsor2".into(), + sponsored_by_override: None, + title: "title2".into(), + }, + ], + )]), + }; + + let request_hash = RequestHash::from("abc123def456"); + response.enrich_callbacks(&request_hash); + + let ads = &response.data["tile1"]; + + assert!(ads[0] + .callbacks + .impression + .query() + .unwrap_or("") + .contains("cap_key=cap_key1")); + + assert!(ads[1] + .callbacks + .impression + .query() + .unwrap_or("") + .contains("cap_key=cap_key2")); + } + #[test] fn test_enrich_callbacks_skips_ads_without_report_url() { let mut response = AdResponse {