From f692d5567182a5f2536c7ec056713fbf24c11613 Mon Sep 17 00:00:00 2001 From: RapidPoseidon Date: Wed, 10 Jun 2026 13:23:44 +0000 Subject: [PATCH] refactor(assets): centralize asset upload and compare truth translation Co-Authored-By: Claude Fable 5 Co-Authored-By: lino --- .../audience/audience_example_handler.py | 60 +++++----- .../benchmark/participant/participant.py | 8 +- .../datapoints/_asset_uploader.py | 38 ++++++- .../datapoints/_datapoint_uploader.py | 9 +- .../datapoints/_truth_translator.py | 61 ++++++++++ .../job/rapidata_job_manager.py | 2 - .../order/rapidata_order_manager.py | 2 - .../rapids/_validation_rapid_uploader.py | 104 +++--------------- .../validation/validation_set_manager.py | 20 ++++ 9 files changed, 165 insertions(+), 139 deletions(-) create mode 100644 src/rapidata/rapidata_client/datapoints/_truth_translator.py diff --git a/src/rapidata/rapidata_client/audience/audience_example_handler.py b/src/rapidata/rapidata_client/audience/audience_example_handler.py index ded438c66..5d40ace14 100644 --- a/src/rapidata/rapidata_client/audience/audience_example_handler.py +++ b/src/rapidata/rapidata_client/audience/audience_example_handler.py @@ -23,7 +23,9 @@ from rapidata.api_client.models.i_example_payload import IExamplePayload from rapidata.api_client.models.i_example_truth import IExampleTruth from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader -from rapidata.rapidata_client.datapoints._asset_mapper import AssetMapper +from rapidata.rapidata_client.datapoints._truth_translator import ( + translate_compare_truth, +) class AudienceExampleHandler: @@ -35,7 +37,6 @@ def __init__(self, openapi_service: OpenAPIService, audience_id: str): self._openapi_service = openapi_service self._audience_id = audience_id self._asset_uploader = AssetUploader(openapi_service) - self._asset_mapper = AssetMapper() def add_classification_example( self, @@ -72,10 +73,7 @@ def add_classification_example( if not all(truth in answer_options for truth in truth): raise ValueError("Truth must be part of the answer options") - if data_type == "media": - asset_input = self._asset_uploader.upload_and_map_asset(datapoint) - else: - asset_input = self._asset_mapper.create_text_input(datapoint) + asset_input = self._asset_uploader.build_asset_input(datapoint, data_type) payload = IExamplePayload( actual_instance=IExamplePayloadClassifyExamplePayload( @@ -138,36 +136,29 @@ def add_compare_example( AddExampleToAudienceEndpointInput, ) + if truth not in datapoint: + raise ValueError("Truth must be one of the datapoints") + + if len(datapoint) != 2: + raise ValueError("Compare rapid requires exactly two media paths") + payload = IExamplePayload( actual_instance=IExamplePayloadCompareExamplePayload( _t="CompareExamplePayload", criteria=instruction ) ) - uploaded_names: list[str] = [] - if data_type == "media": - uploaded_names = [self._asset_uploader.upload_asset(dp) for dp in datapoint] - asset_input = self._asset_mapper.create_existing_asset_input(uploaded_names) - else: - asset_input = self._asset_mapper.create_text_input(datapoint) - - if truth not in datapoint: - raise ValueError("Truth must be one of the datapoints") + asset_input, asset_to_uploaded = ( + self._asset_uploader.build_asset_input_with_names(datapoint, data_type) + ) - truth_index = datapoint.index(truth) - if data_type == "media": - winner_id = uploaded_names[truth_index] - else: - winner_id = truth + winner_id = asset_to_uploaded[truth] if data_type == "media" else truth model_truth = IExampleTruth( actual_instance=IExampleTruthCompareExampleTruth( _t="CompareExampleTruth", winnerId=winner_id ) ) - if len(datapoint) != 2: - raise ValueError("Compare rapid requires exactly two media paths") - self._openapi_service.audience.examples_api.audience_audience_id_example_post( audience_id=self._audience_id, add_example_to_audience_endpoint_input=AddExampleToAudienceEndpointInput( @@ -196,23 +187,30 @@ def _add_rapid_example(self, rapid: Rapid) -> None: AddExampleToAudienceEndpointInput, ) - # Handle asset uploading based on data type - if rapid.data_type == "media": - asset_input = self._asset_uploader.upload_and_map_asset(rapid.asset) - else: - asset_input = self._asset_mapper.create_text_input(rapid.asset) + asset_input, asset_to_uploaded = ( + self._asset_uploader.build_asset_input_with_names( + rapid.asset, rapid.data_type + ) + ) - # Handle media context if present context_asset = None if rapid.media_context: context_asset = self._asset_uploader.upload_and_map_asset(rapid.media_context) + # Compare truths reference original asset paths — rewrite them to the + # uploaded names before the wire conversion. + truth = ( + translate_compare_truth(rapid.truth, asset_to_uploaded) + if rapid.data_type == "media" + else rapid.truth + ) + # Convert IValidationTruthModel to IExampleTruth # Both types are structurally identical (same JSON schema), differing only in class names # The dict-based conversion is safe and preserves all data model_truth: IExampleTruth | None = None - if rapid.truth: - truth_dict = cast(dict[str, Any], rapid.truth.to_dict()) + if truth: + truth_dict = cast(dict[str, Any], truth.to_dict()) model_truth = IExampleTruth.from_dict(truth_dict) # Convert IRapidPayload to IExamplePayload diff --git a/src/rapidata/rapidata_client/benchmark/participant/participant.py b/src/rapidata/rapidata_client/benchmark/participant/participant.py index e7d78a0e4..6a1d53d85 100644 --- a/src/rapidata/rapidata_client/benchmark/participant/participant.py +++ b/src/rapidata/rapidata_client/benchmark/participant/participant.py @@ -13,7 +13,6 @@ from opentelemetry import context as otel_context from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader -from rapidata.rapidata_client.datapoints._asset_mapper import AssetMapper from rapidata.service.openapi_service import OpenAPIService @@ -92,12 +91,7 @@ def _process_single_sample_upload( last_exception = None for attempt in range(rapidata_config.upload.maxRetries): try: - if data_type == "text": - asset_input = AssetMapper.create_text_input(asset) - else: - asset_input = AssetMapper.create_existing_asset_input( - self._asset_uploader.upload_asset(asset) - ) + asset_input = self._asset_uploader.build_asset_input(asset, data_type) with suppress_rapidata_error_logging(): self._openapi_service.leaderboard.participant_api.participant_participant_id_sample_post( diff --git a/src/rapidata/rapidata_client/datapoints/_asset_uploader.py b/src/rapidata/rapidata_client/datapoints/_asset_uploader.py index d93ba937b..1700c71e8 100644 --- a/src/rapidata/rapidata_client/datapoints/_asset_uploader.py +++ b/src/rapidata/rapidata_client/datapoints/_asset_uploader.py @@ -3,7 +3,7 @@ import re import os import threading -from typing import Any +from typing import Any, Literal from rapidata.api_client.models.i_asset_input import IAssetInput from rapidata.service.openapi_service import OpenAPIService @@ -211,6 +211,42 @@ def upload_and_map_asset(self, asset: str | list[str]) -> IAssetInput: return AssetMapper.create_existing_asset_input(self.upload_asset(asset)) + def build_asset_input( + self, asset: str | list[str], data_type: Literal["media", "text"] + ) -> IAssetInput: + """Build the ``IAssetInput`` for an asset: upload media, wrap text as-is.""" + asset_input, _ = self.build_asset_input_with_names(asset, data_type) + return asset_input + + def build_asset_input_with_names( + self, asset: str | list[str], data_type: Literal["media", "text"] + ) -> tuple[IAssetInput, dict[str, str]]: + """Build the ``IAssetInput`` plus an original-asset → uploaded-name map. + + The map lets compare-truth translation rewrite truth references + (winner ids, correct combinations) from caller-supplied paths/URLs to + the uploaded names the API expects. It is empty for text assets, + which are sent verbatim and never uploaded. + """ + if data_type == "text": + return AssetMapper.create_text_input(asset), {} + + if isinstance(asset, list): + asset_to_uploaded = {a: self.upload_asset(a) for a in asset} + # Re-index through the original list so duplicate entries survive. + return ( + AssetMapper.create_existing_asset_input( + [asset_to_uploaded[a] for a in asset] + ), + asset_to_uploaded, + ) + + uploaded_name = self.upload_asset(asset) + return ( + AssetMapper.create_existing_asset_input(uploaded_name), + {asset: uploaded_name}, + ) + def clear_cache(self) -> None: """Clear both URL and file caches.""" self._get_file_cache().clear() diff --git a/src/rapidata/rapidata_client/datapoints/_datapoint_uploader.py b/src/rapidata/rapidata_client/datapoints/_datapoint_uploader.py index 1a3dee738..466d0b116 100644 --- a/src/rapidata/rapidata_client/datapoints/_datapoint_uploader.py +++ b/src/rapidata/rapidata_client/datapoints/_datapoint_uploader.py @@ -10,7 +10,6 @@ from rapidata.rapidata_client.datapoints._datapoint import Datapoint from rapidata.service.openapi_service import OpenAPIService from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader -from rapidata.rapidata_client.datapoints._asset_mapper import AssetMapper if TYPE_CHECKING: from rapidata.api_client.models.create_datapoint_endpoint_input import ( @@ -25,7 +24,6 @@ class DatapointUploader: def __init__(self, openapi_service: OpenAPIService): self.openapi_service = openapi_service self.asset_uploader = AssetUploader(openapi_service) - self.asset_mapper = AssetMapper() def upload_datapoint( self, datapoint: Datapoint, dataset_id: str, index: int @@ -34,10 +32,9 @@ def upload_datapoint( CreateDatapointEndpointInput, ) - if datapoint.data_type == "media": - uploaded_asset = self.asset_uploader.upload_and_map_asset(datapoint.asset) - else: - uploaded_asset = self.asset_mapper.create_text_input(datapoint.asset) + uploaded_asset = self.asset_uploader.build_asset_input( + datapoint.asset, datapoint.data_type + ) # If the datapoint belongs to a group, context is handled at group level has_group = datapoint.group is not None diff --git a/src/rapidata/rapidata_client/datapoints/_truth_translator.py b/src/rapidata/rapidata_client/datapoints/_truth_translator.py new file mode 100644 index 000000000..9aafa60db --- /dev/null +++ b/src/rapidata/rapidata_client/datapoints/_truth_translator.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from rapidata.api_client.models.i_validation_truth_model import IValidationTruthModel +from rapidata.api_client.models.i_validation_truth_model_compare_truth_model import ( + IValidationTruthModelCompareTruthModel, +) +from rapidata.api_client.models.i_validation_truth_model_multi_compare_truth_model import ( + IValidationTruthModelMultiCompareTruthModel, +) + + +def translate_compare_truth( + truth: IValidationTruthModel | None, + asset_to_uploaded: dict[str, str], +) -> IValidationTruthModel | None: + """Rewrite compare-truth asset references from original paths/URLs to uploaded names. + + Compare truths are built before upload, so ``winnerId`` / + ``correctCombinations`` reference assets by the caller-supplied path or + URL, while the API expects uploaded asset names. Only call this for media + rapids — text truths reference the text itself and must pass through + unchanged. Raises ``ValueError`` if a referenced asset is not in the map, + since sending the untranslated reference would silently break validation. + """ + if truth is None or truth.actual_instance is None: + return truth + + instance = truth.actual_instance + + if isinstance(instance, IValidationTruthModelCompareTruthModel): + winner_id = instance.winner_id + if winner_id not in asset_to_uploaded: + raise ValueError( + f"Compare truth winner '{winner_id}' is not one of the rapid's " + f"assets: {list(asset_to_uploaded)}" + ) + return IValidationTruthModel( + actual_instance=IValidationTruthModelCompareTruthModel( + _t="CompareTruth", winnerId=asset_to_uploaded[winner_id] + ) + ) + + if isinstance(instance, IValidationTruthModelMultiCompareTruthModel): + translated_combinations: list[list[str]] = [] + for combination in instance.correct_combinations: + translated_combination: list[str] = [] + for asset_id in combination: + if asset_id not in asset_to_uploaded: + raise ValueError( + f"Multi-compare truth asset '{asset_id}' is not one of " + f"the rapid's assets: {list(asset_to_uploaded)}" + ) + translated_combination.append(asset_to_uploaded[asset_id]) + translated_combinations.append(translated_combination) + return IValidationTruthModel( + actual_instance=IValidationTruthModelMultiCompareTruthModel( + _t="MultiCompareTruth", correctCombinations=translated_combinations + ) + ) + + return truth diff --git a/src/rapidata/rapidata_client/job/rapidata_job_manager.py b/src/rapidata/rapidata_client/job/rapidata_job_manager.py index 6e7981f71..712c839b8 100644 --- a/src/rapidata/rapidata_client/job/rapidata_job_manager.py +++ b/src/rapidata/rapidata_client/job/rapidata_job_manager.py @@ -11,7 +11,6 @@ from rapidata.rapidata_client.settings import RapidataSetting from rapidata.rapidata_client.job.rapidata_job_definition import RapidataJobDefinition from typing import Sequence, Literal, TYPE_CHECKING -from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader from rapidata.rapidata_client.dataset._rapidata_dataset import RapidataDataset from rapidata.rapidata_client.exceptions.failed_upload_exception import ( FailedUploadException, @@ -34,7 +33,6 @@ def __init__(self, openapi_service: OpenAPIService): self._openapi_service = openapi_service self.__priority: int | None = None - self._asset_uploader = AssetUploader(openapi_service) logger.debug("JobManager initialized") def _create_general_job_definition( diff --git a/src/rapidata/rapidata_client/order/rapidata_order_manager.py b/src/rapidata/rapidata_client/order/rapidata_order_manager.py index bde867666..c3f821a66 100644 --- a/src/rapidata/rapidata_client/order/rapidata_order_manager.py +++ b/src/rapidata/rapidata_client/order/rapidata_order_manager.py @@ -6,7 +6,6 @@ from rapidata.rapidata_client.config import logger from rapidata.rapidata_client.config.tracer import tracer -from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader from rapidata.rapidata_client.datapoints._datapoint import Datapoint from rapidata.rapidata_client.datapoints._datapoints_validator import ( DatapointsValidator, @@ -43,7 +42,6 @@ def __init__(self, openapi_service: OpenAPIService): self.__priority: int | None = None self.__sticky_config: StickyConfig | None = None - self._asset_uploader = AssetUploader(openapi_service) logger.debug("RapidataOrderManager initialized") def _create_general_order( diff --git a/src/rapidata/rapidata_client/validation/rapids/_validation_rapid_uploader.py b/src/rapidata/rapidata_client/validation/rapids/_validation_rapid_uploader.py index 47d541ffd..ac1844072 100644 --- a/src/rapidata/rapidata_client/validation/rapids/_validation_rapid_uploader.py +++ b/src/rapidata/rapidata_client/validation/rapids/_validation_rapid_uploader.py @@ -3,32 +3,30 @@ from rapidata.api_client.models.add_validation_rapid_endpoint_input import ( AddValidationRapidEndpointInput, ) -from rapidata.api_client.models.i_asset_input import IAssetInput from rapidata.api_client.models.i_rapid_payload_model import IRapidPayloadModel -from rapidata.api_client.models.i_validation_truth_model import IValidationTruthModel -from rapidata.api_client.models.i_validation_truth_model_compare_truth_model import ( - IValidationTruthModelCompareTruthModel, -) -from rapidata.api_client.models.i_validation_truth_model_multi_compare_truth_model import ( - IValidationTruthModelMultiCompareTruthModel, -) from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader -from rapidata.rapidata_client.datapoints._asset_mapper import AssetMapper +from rapidata.rapidata_client.datapoints._truth_translator import ( + translate_compare_truth, +) class ValidationRapidUploader: def __init__(self, openapi_service: OpenAPIService): self.openapi_service = openapi_service self.asset_uploader = AssetUploader(openapi_service) - self.asset_mapper = AssetMapper() def upload_rapid(self, rapid: Rapid, validation_set_id: str) -> None: - asset_to_uploaded: dict[str, str] = {} + uploaded_asset, asset_to_uploaded = ( + self.asset_uploader.build_asset_input_with_names( + rapid.asset, rapid.data_type + ) + ) - if rapid.data_type == "media": - uploaded_asset, asset_to_uploaded = self._upload_rapid_asset(rapid.asset) - else: - uploaded_asset = self.asset_mapper.create_text_input(rapid.asset) + truth = ( + translate_compare_truth(rapid.truth, asset_to_uploaded) + if rapid.data_type == "media" + else rapid.truth + ) context_asset = ( self.asset_uploader.upload_and_map_asset(rapid.media_context) @@ -43,7 +41,7 @@ def upload_rapid(self, rapid: Rapid, validation_set_id: str) -> None: payload=self._get_payload(rapid), context=rapid.context, contextAsset=context_asset, - truth=self._translate_compare_truth(rapid, asset_to_uploaded), + truth=truth, randomCorrectProbability=rapid.random_correct_probability, explanation=rapid.explanation, featureFlags=( @@ -54,80 +52,6 @@ def upload_rapid(self, rapid: Rapid, validation_set_id: str) -> None: ), ) - def _upload_rapid_asset( - self, asset: str | list[str] - ) -> tuple[IAssetInput, dict[str, str]]: - """Upload the rapid asset(s) and return the wrapped input plus a path→name map. - - The map lets ``_translate_compare_truth`` rewrite truth references - (winner ids, correct combinations) from caller-supplied paths to the - uploaded names the API expects. Truth only ever points at the rapid's - own assets, so this mapping isn't needed for ``media_context``. - """ - if isinstance(asset, list): - asset_to_uploaded = {a: self.asset_uploader.upload_asset(a) for a in asset} - return ( - self.asset_mapper.create_existing_asset_input( - list(asset_to_uploaded.values()) - ), - asset_to_uploaded, - ) - - uploaded_name = self.asset_uploader.upload_asset(asset) - return ( - self.asset_mapper.create_existing_asset_input(uploaded_name), - {asset: uploaded_name}, - ) - - def _translate_compare_truth( - self, rapid: Rapid, asset_to_uploaded: dict[str, str] - ) -> IValidationTruthModel | None: - """Translate compare rapid truth from original asset paths to uploaded names. - - For compare rapids with media assets, ``winnerId`` / - ``correctCombinations`` reference assets by their original path, but - the API expects uploaded names. - """ - if not rapid.truth or rapid.data_type != "media": - return rapid.truth - - if not rapid.truth.actual_instance: - return rapid.truth - - if isinstance( - rapid.truth.actual_instance, IValidationTruthModelCompareTruthModel - ): - compare_truth = rapid.truth.actual_instance - original_winner_id = compare_truth.winner_id - - if original_winner_id in asset_to_uploaded: - return IValidationTruthModel( - actual_instance=IValidationTruthModelCompareTruthModel( - _t="CompareTruth", - winnerId=asset_to_uploaded[original_winner_id], - ) - ) - - elif isinstance( - rapid.truth.actual_instance, IValidationTruthModelMultiCompareTruthModel - ): - multi_compare_truth = rapid.truth.actual_instance - translated_combinations = [ - [ - asset_to_uploaded.get(asset_id, asset_id) - for asset_id in combination - ] - for combination in multi_compare_truth.correct_combinations - ] - - return IValidationTruthModel( - actual_instance=IValidationTruthModelMultiCompareTruthModel( - _t="MultiCompareTruth", correctCombinations=translated_combinations - ) - ) - - return rapid.truth - def _get_payload(self, rapid: Rapid) -> IRapidPayloadModel: if isinstance(rapid.payload, dict): return IRapidPayloadModel.from_dict(rapid.payload) diff --git a/src/rapidata/rapidata_client/validation/validation_set_manager.py b/src/rapidata/rapidata_client/validation/validation_set_manager.py index e7e1c0f66..1e95937ba 100644 --- a/src/rapidata/rapidata_client/validation/validation_set_manager.py +++ b/src/rapidata/rapidata_client/validation/validation_set_manager.py @@ -15,6 +15,9 @@ AudienceAudienceIdJobsGetJobIdParameter, ) from rapidata.service.openapi_service import OpenAPIService +from rapidata.rapidata_client.datapoints._asset_upload_orchestrator import ( + AssetUploadOrchestrator, +) from rapidata.rapidata_client.validation.rapids.rapids_manager import RapidsManager from rapidata.rapidata_client.validation.rapids.box import Box @@ -584,6 +587,23 @@ def _submit( dimensions=dimensions, openapi_service=self._openapi_service, ) + # Pre-upload all media assets (batched URLs, parallel files) so the per-rapid + # uploads become cache hits; failures surface via failed_rapids below. + assets_to_upload: set[str] = set() + for rapid in rapids: + if rapid.data_type == "media": + if isinstance(rapid.asset, list): + assets_to_upload.update(rapid.asset) + else: + assets_to_upload.add(rapid.asset) + if rapid.media_context: + assets_to_upload.update(rapid.media_context) + + if assets_to_upload: + AssetUploadOrchestrator(self._openapi_service).upload_all_assets( + assets_to_upload + ) + with tracer.start_as_current_span("Adding rapids to validation set"): logger.debug("Adding rapids to validation set") failed_rapids = []