Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 29 additions & 31 deletions src/rapidata/rapidata_client/audience/audience_example_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from rapidata.api_client.models.i_example_payload import IExamplePayload
from rapidata.api_client.models.i_example_truth import IExampleTruth
from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader
from rapidata.rapidata_client.datapoints._asset_mapper import AssetMapper
from rapidata.rapidata_client.datapoints._truth_translator import (
translate_compare_truth,
)


class AudienceExampleHandler:
Expand All @@ -35,7 +37,6 @@ def __init__(self, openapi_service: OpenAPIService, audience_id: str):
self._openapi_service = openapi_service
self._audience_id = audience_id
self._asset_uploader = AssetUploader(openapi_service)
self._asset_mapper = AssetMapper()

def add_classification_example(
self,
Expand Down Expand Up @@ -72,10 +73,7 @@ def add_classification_example(
if not all(truth in answer_options for truth in truth):
raise ValueError("Truth must be part of the answer options")

if data_type == "media":
asset_input = self._asset_uploader.upload_and_map_asset(datapoint)
else:
asset_input = self._asset_mapper.create_text_input(datapoint)
asset_input = self._asset_uploader.build_asset_input(datapoint, data_type)

payload = IExamplePayload(
actual_instance=IExamplePayloadClassifyExamplePayload(
Expand Down Expand Up @@ -138,36 +136,29 @@ def add_compare_example(
AddExampleToAudienceEndpointInput,
)

if truth not in datapoint:
raise ValueError("Truth must be one of the datapoints")

if len(datapoint) != 2:
raise ValueError("Compare rapid requires exactly two media paths")

payload = IExamplePayload(
actual_instance=IExamplePayloadCompareExamplePayload(
_t="CompareExamplePayload", criteria=instruction
)
)

uploaded_names: list[str] = []
if data_type == "media":
uploaded_names = [self._asset_uploader.upload_asset(dp) for dp in datapoint]
asset_input = self._asset_mapper.create_existing_asset_input(uploaded_names)
else:
asset_input = self._asset_mapper.create_text_input(datapoint)

if truth not in datapoint:
raise ValueError("Truth must be one of the datapoints")
asset_input, asset_to_uploaded = (
self._asset_uploader.build_asset_input_with_names(datapoint, data_type)
)

truth_index = datapoint.index(truth)
if data_type == "media":
winner_id = uploaded_names[truth_index]
else:
winner_id = truth
winner_id = asset_to_uploaded[truth] if data_type == "media" else truth
model_truth = IExampleTruth(
actual_instance=IExampleTruthCompareExampleTruth(
_t="CompareExampleTruth", winnerId=winner_id
)
)

if len(datapoint) != 2:
raise ValueError("Compare rapid requires exactly two media paths")

self._openapi_service.audience.examples_api.audience_audience_id_example_post(
audience_id=self._audience_id,
add_example_to_audience_endpoint_input=AddExampleToAudienceEndpointInput(
Expand Down Expand Up @@ -196,23 +187,30 @@ def _add_rapid_example(self, rapid: Rapid) -> None:
AddExampleToAudienceEndpointInput,
)

# Handle asset uploading based on data type
if rapid.data_type == "media":
asset_input = self._asset_uploader.upload_and_map_asset(rapid.asset)
else:
asset_input = self._asset_mapper.create_text_input(rapid.asset)
asset_input, asset_to_uploaded = (
self._asset_uploader.build_asset_input_with_names(
rapid.asset, rapid.data_type
)
)

# Handle media context if present
context_asset = None
if rapid.media_context:
context_asset = self._asset_uploader.upload_and_map_asset(rapid.media_context)

# Compare truths reference original asset paths — rewrite them to the
# uploaded names before the wire conversion.
truth = (
translate_compare_truth(rapid.truth, asset_to_uploaded)
if rapid.data_type == "media"
else rapid.truth
)

# Convert IValidationTruthModel to IExampleTruth
# Both types are structurally identical (same JSON schema), differing only in class names
# The dict-based conversion is safe and preserves all data
model_truth: IExampleTruth | None = None
if rapid.truth:
truth_dict = cast(dict[str, Any], rapid.truth.to_dict())
if truth:
truth_dict = cast(dict[str, Any], truth.to_dict())
model_truth = IExampleTruth.from_dict(truth_dict)

# Convert IRapidPayload to IExamplePayload
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from opentelemetry import context as otel_context
from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader
from rapidata.rapidata_client.datapoints._asset_mapper import AssetMapper


from rapidata.service.openapi_service import OpenAPIService
Expand Down Expand Up @@ -92,12 +91,7 @@ def _process_single_sample_upload(
last_exception = None
for attempt in range(rapidata_config.upload.maxRetries):
try:
if data_type == "text":
asset_input = AssetMapper.create_text_input(asset)
else:
asset_input = AssetMapper.create_existing_asset_input(
self._asset_uploader.upload_asset(asset)
)
asset_input = self._asset_uploader.build_asset_input(asset, data_type)

with suppress_rapidata_error_logging():
self._openapi_service.leaderboard.participant_api.participant_participant_id_sample_post(
Expand Down
38 changes: 37 additions & 1 deletion src/rapidata/rapidata_client/datapoints/_asset_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
import os
import threading
from typing import Any
from typing import Any, Literal

from rapidata.api_client.models.i_asset_input import IAssetInput
from rapidata.service.openapi_service import OpenAPIService
Expand Down Expand Up @@ -211,6 +211,42 @@ def upload_and_map_asset(self, asset: str | list[str]) -> IAssetInput:

return AssetMapper.create_existing_asset_input(self.upload_asset(asset))

def build_asset_input(
    self, asset: str | list[str], data_type: Literal["media", "text"]
) -> IAssetInput:
    """Return just the ``IAssetInput`` for ``asset``.

    Thin convenience wrapper around ``build_asset_input_with_names`` for
    callers that do not need the original-asset → uploaded-name mapping:
    the same arguments are forwarded unchanged and the mapping half of the
    returned pair is discarded.

    Args:
        asset: A single asset or a list of assets.
        data_type: ``"media"`` or ``"text"``; passed through as-is.

    Returns:
        The first element of the pair produced by
        ``build_asset_input_with_names``.
    """
    built_input, _name_map = self.build_asset_input_with_names(asset, data_type)
    return built_input

def build_asset_input_with_names(
    self, asset: str | list[str], data_type: Literal["media", "text"]
) -> tuple[IAssetInput, dict[str, str]]:
    """Build the ``IAssetInput`` together with an original → uploaded-name map.

    The map exists so that truth references expressed in caller-supplied
    paths/URLs (winner ids, correct combinations) can later be rewritten
    to the names returned by ``upload_asset``.

    Args:
        asset: A single asset or a list of assets.
        data_type: ``"text"`` wraps the asset via
            ``AssetMapper.create_text_input`` without uploading anything;
            any other value takes the upload path.

    Returns:
        ``(asset_input, asset_to_uploaded)``. The mapping is empty for
        text assets, since nothing is uploaded for them.
    """
    # Text is wrapped directly; there are no uploaded names to report.
    if data_type == "text":
        return AssetMapper.create_text_input(asset), {}

    # Single (non-list) asset: one upload, one map entry.
    if not isinstance(asset, list):
        single_name = self.upload_asset(asset)
        return (
            AssetMapper.create_existing_asset_input(single_name),
            {asset: single_name},
        )

    # List of assets: upload every entry in order, keyed by the original value.
    name_by_asset = {}
    for original in asset:
        name_by_asset[original] = self.upload_asset(original)

    # Walk the original list again (not the dict) so repeated entries keep
    # their positions in the resulting input.
    ordered_names = [name_by_asset[original] for original in asset]
    return AssetMapper.create_existing_asset_input(ordered_names), name_by_asset

def clear_cache(self) -> None:
"""Clear both URL and file caches."""
self._get_file_cache().clear()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from rapidata.rapidata_client.datapoints._datapoint import Datapoint
from rapidata.service.openapi_service import OpenAPIService
from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader
from rapidata.rapidata_client.datapoints._asset_mapper import AssetMapper

if TYPE_CHECKING:
from rapidata.api_client.models.create_datapoint_endpoint_input import (
Expand All @@ -25,7 +24,6 @@ class DatapointUploader:
def __init__(self, openapi_service: OpenAPIService):
self.openapi_service = openapi_service
self.asset_uploader = AssetUploader(openapi_service)
self.asset_mapper = AssetMapper()

def upload_datapoint(
self, datapoint: Datapoint, dataset_id: str, index: int
Expand All @@ -34,10 +32,9 @@ def upload_datapoint(
CreateDatapointEndpointInput,
)

if datapoint.data_type == "media":
uploaded_asset = self.asset_uploader.upload_and_map_asset(datapoint.asset)
else:
uploaded_asset = self.asset_mapper.create_text_input(datapoint.asset)
uploaded_asset = self.asset_uploader.build_asset_input(
datapoint.asset, datapoint.data_type
)

# If the datapoint belongs to a group, context is handled at group level
has_group = datapoint.group is not None
Expand Down
61 changes: 61 additions & 0 deletions src/rapidata/rapidata_client/datapoints/_truth_translator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from __future__ import annotations

from rapidata.api_client.models.i_validation_truth_model import IValidationTruthModel
from rapidata.api_client.models.i_validation_truth_model_compare_truth_model import (
IValidationTruthModelCompareTruthModel,
)
from rapidata.api_client.models.i_validation_truth_model_multi_compare_truth_model import (
IValidationTruthModelMultiCompareTruthModel,
)


def translate_compare_truth(
    truth: IValidationTruthModel | None,
    asset_to_uploaded: dict[str, str],
) -> IValidationTruthModel | None:
    """Map the asset references inside a compare truth to their uploaded names.

    A compare truth's ``winnerId`` and a multi-compare truth's
    ``correctCombinations`` are looked up in ``asset_to_uploaded`` and a new
    truth model carrying the mapped values is returned. Intended for media
    rapids only: text truths refer to the text itself and should not be
    passed through this function.

    Args:
        truth: The truth to translate; may be ``None``.
        asset_to_uploaded: Mapping from original asset reference to the
            uploaded name.

    Returns:
        ``truth`` itself when it is ``None``, has no ``actual_instance``, or
        wraps a type other than the two compare models; otherwise a freshly
        built ``IValidationTruthModel`` with translated references.

    Raises:
        ValueError: If a referenced asset is missing from
            ``asset_to_uploaded``; forwarding the untranslated reference
            would silently break validation.
    """
    payload = None if truth is None else truth.actual_instance
    if payload is None:
        return truth

    if isinstance(payload, IValidationTruthModelCompareTruthModel):
        winner_id = payload.winner_id
        if winner_id not in asset_to_uploaded:
            raise ValueError(
                f"Compare truth winner '{winner_id}' is not one of the rapid's "
                f"assets: {list(asset_to_uploaded)}"
            )
        translated_winner = IValidationTruthModelCompareTruthModel(
            _t="CompareTruth", winnerId=asset_to_uploaded[winner_id]
        )
        return IValidationTruthModel(actual_instance=translated_winner)

    if isinstance(payload, IValidationTruthModelMultiCompareTruthModel):

        def _uploaded_name(asset_id: str) -> str:
            # Fail on the first unknown reference, in iteration order.
            if asset_id not in asset_to_uploaded:
                raise ValueError(
                    f"Multi-compare truth asset '{asset_id}' is not one of "
                    f"the rapid's assets: {list(asset_to_uploaded)}"
                )
            return asset_to_uploaded[asset_id]

        translated_combinations: list[list[str]] = [
            [_uploaded_name(asset_id) for asset_id in combination]
            for combination in payload.correct_combinations
        ]
        translated_multi = IValidationTruthModelMultiCompareTruthModel(
            _t="MultiCompareTruth", correctCombinations=translated_combinations
        )
        return IValidationTruthModel(actual_instance=translated_multi)

    # Any other truth type carries no asset references to rewrite.
    return truth
2 changes: 0 additions & 2 deletions src/rapidata/rapidata_client/job/rapidata_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from rapidata.rapidata_client.settings import RapidataSetting
from rapidata.rapidata_client.job.rapidata_job_definition import RapidataJobDefinition
from typing import Sequence, Literal, TYPE_CHECKING
from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader
from rapidata.rapidata_client.dataset._rapidata_dataset import RapidataDataset
from rapidata.rapidata_client.exceptions.failed_upload_exception import (
FailedUploadException,
Expand All @@ -34,7 +33,6 @@ def __init__(self, openapi_service: OpenAPIService):
self._openapi_service = openapi_service

self.__priority: int | None = None
self._asset_uploader = AssetUploader(openapi_service)
logger.debug("JobManager initialized")

def _create_general_job_definition(
Expand Down
2 changes: 0 additions & 2 deletions src/rapidata/rapidata_client/order/rapidata_order_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from rapidata.rapidata_client.config import logger
from rapidata.rapidata_client.config.tracer import tracer
from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader
from rapidata.rapidata_client.datapoints._datapoint import Datapoint
from rapidata.rapidata_client.datapoints._datapoints_validator import (
DatapointsValidator,
Expand Down Expand Up @@ -43,7 +42,6 @@ def __init__(self, openapi_service: OpenAPIService):

self.__priority: int | None = None
self.__sticky_config: StickyConfig | None = None
self._asset_uploader = AssetUploader(openapi_service)
logger.debug("RapidataOrderManager initialized")

def _create_general_order(
Expand Down
Loading
Loading