From 68d6f6e0ceb74b92fc14ae07037b452345f4c40d Mon Sep 17 00:00:00 2001 From: Copilot <223556219+Copilot@users.noreply.github.com> Date: Sun, 7 Jun 2026 18:02:17 -0700 Subject: [PATCH 1/2] Remove deprecations scheduled for v0.15.0 Now that v0.14.0 has shipped, drop every API marked removed_in="0.15.0": * MessagePiece fields originator, scorer_identifier, scores, and targeted_harm_categories (plus their Pydantic kwarg warnings) * AttackResult.attack_identifier property + promotion validator * AttackResultEntry.attack_identifier column and PromptMemoryEntry.targeted_harm_categories column (with Alembic migration f1a2b3c4d5e6) * MemoryInterface.export_conversations, SQLiteMemory.export_conversations, SQLiteMemory.export_all_tables, MemoryExporter shim + module * MemoryInterface.get_attack_results kwargs attack_class and targeted_harm_categories, and the per-backend harm-category condition helpers (callers should use the labels filter going forward) * ScenarioStrategy.normalize_strategies, ContentHarms/ContentHarmsStrategy aliases, AzureSpeechAudioToTextConverter.recognize_audio, HuggingFaceEndpointTarget (deleted entirely) * AzureMLChatTarget(message_normalizer=...), OpenAIImageTarget style param + DALL-E sizes/qualities + URL fallback, use_entra_auth on AzureSpeech*Converter and Audio*Scorer helpers, positional / x_pos / y_pos kwargs on AddTextImageConverter and AddImageTextConverter * ChatMessage.to_json/from_json, EmbeddingResponse.to_json, ScorerMetrics.from_json deprecated aliases Downstream rewires: SelectorScope and compute_technique_stats no longer accept targeted_harm_categories (per the user-approved feature loss); conversation_manager queries memory for prepended-conversation scores instead of relying on the removed MessagePiece.scores field; get_prompt_scores resolves scores against ScoreEntry directly. Tests for deprecated paths are deleted; callers using those paths are migrated. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/analytics/technique_analysis.py | 4 - .../attack/component/conversation_manager.py | 27 +- pyrit/memory/__init__.py | 17 - ...2b3c4d5e6_drop_v0_14_deprecated_columns.py | 48 +++ pyrit/memory/azure_sql_memory.py | 37 -- pyrit/memory/memory_exporter.py | 122 ------- pyrit/memory/memory_interface.py | 151 +------- pyrit/memory/memory_models.py | 31 +- pyrit/memory/sqlite_memory.py | 155 +------- pyrit/models/chat_message.py | 38 -- pyrit/models/embeddings.py | 17 - pyrit/models/messages/message_piece.py | 55 +-- pyrit/models/results/attack_result.py | 51 +-- .../add_image_text_converter.py | 48 +-- .../add_text_image_converter.py | 20 +- .../azure_speech_audio_to_text_converter.py | 54 +-- .../azure_speech_text_to_audio_converter.py | 22 -- pyrit/prompt_target/__init__.py | 2 - pyrit/prompt_target/azure_ml_chat_target.py | 50 +-- .../hugging_face_endpoint_target.py | 208 ----------- .../openai/openai_image_target.py | 75 +--- pyrit/scenario/__init__.py | 2 +- pyrit/scenario/core/scenario_strategy.py | 26 -- .../adaptive/selectors/epsilon_greedy.py | 1 - .../adaptive/selectors/technique_selector.py | 9 +- pyrit/scenario/scenarios/airt/__init__.py | 20 +- .../scenario/scenarios/airt/content_harms.py | 52 --- pyrit/score/audio_transcript_scorer.py | 17 - pyrit/score/conversation_scorer.py | 1 - .../float_scale/audio_float_scale_scorer.py | 8 - .../score/scorer_evaluation/scorer_metrics.py | 23 -- .../true_false/audio_true_false_scorer.py | 8 - tests/unit/analytics/test_result_analysis.py | 6 +- .../unit/analytics/test_technique_analysis.py | 10 - tests/unit/backend/test_mappers.py | 66 ++-- .../component/test_conversation_manager.py | 96 ++--- .../test_interface_attack_results.py | 145 +------- .../memory_interface/test_interface_export.py | 216 ----------- .../test_interface_prompts.py | 16 +- tests/unit/memory/test_memory_exporter.py | 106 ------ tests/unit/models/test_attack_result.py | 160 
+-------- tests/unit/models/test_chat_message.py | 15 - tests/unit/models/test_embedding_response.py | 6 - tests/unit/models/test_message_piece.py | 170 +-------- .../test_add_image_text_converter.py | 39 -- .../test_add_text_image_converter.py | 18 - .../test_azure_speech_converter.py | 8 - .../test_azure_speech_text_converter.py | 57 --- .../test_hugging_face_endpoint_target.py | 334 ------------------ .../prompt_target/target/test_image_target.py | 144 -------- .../test_normalize_async_integration.py | 57 --- .../target/test_target_capabilities.py | 2 - .../unit/scenario/airt/test_rapid_response.py | 46 --- .../scenarios/adaptive/test_epsilon_greedy.py | 12 - .../scenarios/adaptive/test_selector_scope.py | 17 +- tests/unit/score/test_audio_scorer.py | 10 +- tests/unit/score/test_scorer_metrics.py | 20 -- 57 files changed, 218 insertions(+), 2957 deletions(-) create mode 100644 pyrit/memory/alembic/versions/f1a2b3c4d5e6_drop_v0_14_deprecated_columns.py delete mode 100644 pyrit/memory/memory_exporter.py delete mode 100644 pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py delete mode 100644 pyrit/scenario/scenarios/airt/content_harms.py delete mode 100644 tests/unit/memory/memory_interface/test_interface_export.py delete mode 100644 tests/unit/memory/test_memory_exporter.py delete mode 100644 tests/unit/prompt_target/target/test_hugging_face_endpoint_target.py diff --git a/pyrit/analytics/technique_analysis.py b/pyrit/analytics/technique_analysis.py index a091d29821..c44c032f4f 100644 --- a/pyrit/analytics/technique_analysis.py +++ b/pyrit/analytics/technique_analysis.py @@ -21,7 +21,6 @@ def compute_technique_stats( *, technique_eval_hashes: Sequence[str], scenario_result_id: str | None = None, - targeted_harm_categories: Sequence[str] | None = None, memory: MemoryInterface | None = None, ) -> dict[str, AttackStats]: """ @@ -40,8 +39,6 @@ def compute_technique_stats( Returned dict is keyed by these. 
scenario_result_id (str | None): Restrict to a single scenario run. Defaults to ``None`` (aggregate across all runs). - targeted_harm_categories (Sequence[str] | None): Restrict to results - whose prompts targeted these harm categories. Defaults to ``None``. memory (MemoryInterface | None): Memory backend to query. Defaults to ``CentralMemory.get_memory_instance()``. @@ -57,7 +54,6 @@ def compute_technique_stats( results = memory.get_attack_results( atomic_attack_eval_hashes=list(technique_eval_hashes), scenario_result_id=scenario_result_id, - targeted_harm_categories=targeted_harm_categories, ) requested = set(technique_eval_hashes) diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index b7d554775c..9dcb8e2826 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -568,17 +568,22 @@ async def _process_prepended_for_chat_target_async( if hasattr(context, "executed_turns"): context.executed_turns = state.turn_count # type: ignore[ty:invalid-assignment] - # Extract scores on final prepended assistant message if it exists and are relavent - # Multi-part messages (e.g., text + image) may have scores on multiple pieces - # only extract true_false scores with score_value=False. This allows attacks - # to use the score's rationale for feedback without re-scoring. - for piece in final_prepended_message.message_pieces: - for score in piece.scores: - if score.score_type == "true_false" and score.get_value() is False: - state.last_assistant_message_scores.append(score) - # context.last_score gets the first matching score for single-score use cases. - if hasattr(context, "last_score") and context.last_score is None: - context.last_score = score # type: ignore[ty:invalid-assignment] + # Extract scores on final prepended assistant message if it exists and are relevant. 
+ # The prepended pieces were re-keyed with new ids when added to memory, so look + # them up by conversation_id and keep all assistant-role pieces. Only extract + # true_false scores with score_value=False so attacks can use the rationale for + # feedback without re-scoring. + memory_pieces = self._memory.get_message_pieces(conversation_id=conversation_id) + assistant_piece_ids = [str(piece.id) for piece in memory_pieces if piece.api_role == "assistant"] + existing_scores = ( + self._memory.get_prompt_scores(prompt_ids=assistant_piece_ids) if assistant_piece_ids else [] + ) + for score in existing_scores: + if score.score_type == "true_false" and score.get_value() is False: + state.last_assistant_message_scores.append(score) + # context.last_score gets the first matching score for single-score use cases. + if hasattr(context, "last_score") and context.last_score is None: + context.last_score = score # type: ignore[ty:invalid-assignment] return state diff --git a/pyrit/memory/__init__.py b/pyrit/memory/__init__.py index 2579dc1334..78acdc5dd8 100644 --- a/pyrit/memory/__init__.py +++ b/pyrit/memory/__init__.py @@ -7,8 +7,6 @@ This package defines the core `MemoryInterface` and concrete implementations for different storage backends. 
""" -from typing import Any - from pyrit.memory.azure_sql_memory import AzureSQLMemory from pyrit.memory.central_memory import CentralMemory from pyrit.memory.memory_embedding import MemoryEmbedding @@ -24,21 +22,6 @@ "EmbeddingDataEntry", "MemoryInterface", "MemoryEmbedding", - "MemoryExporter", "PromptMemoryEntry", "SeedEntry", ] - - -def __getattr__(name: str) -> Any: - if name == "MemoryExporter": - from pyrit.common.deprecation import print_deprecation_message - from pyrit.memory.memory_exporter import MemoryExporter - - print_deprecation_message( - old_item="pyrit.memory.MemoryExporter", - new_item="the pyrit.output module or direct serialization", - removed_in="0.15.0", - ) - return MemoryExporter - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pyrit/memory/alembic/versions/f1a2b3c4d5e6_drop_v0_14_deprecated_columns.py b/pyrit/memory/alembic/versions/f1a2b3c4d5e6_drop_v0_14_deprecated_columns.py new file mode 100644 index 0000000000..d650649300 --- /dev/null +++ b/pyrit/memory/alembic/versions/f1a2b3c4d5e6_drop_v0_14_deprecated_columns.py @@ -0,0 +1,48 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Drop deprecated columns scheduled for removal in v0.15.0. + +* ``AttackResultEntries.attack_identifier`` (superseded by + ``atomic_attack_identifier``). +* ``PromptMemoryEntries.targeted_harm_categories`` (callers should use the + attack-level ``labels`` column with ``{"harm_category": [...]}`` instead). + +Revision ID: f1a2b3c4d5e6 +Revises: 9c8b7a6d5e4f +Create Date: 2026-06-05 14:39:00.000000 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "f1a2b3c4d5e6" +down_revision: str | None = "9c8b7a6d5e4f" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Apply this schema upgrade.""" + # SQLite does not support DROP COLUMN on a table with constraints in older + # versions; use batch_alter_table so the operation is portable across both + # SQLite and Azure SQL. + with op.batch_alter_table("AttackResultEntries") as batch_op: + batch_op.drop_column("attack_identifier") + with op.batch_alter_table("PromptMemoryEntries") as batch_op: + batch_op.drop_column("targeted_harm_categories") + + +def downgrade() -> None: + """Revert this schema upgrade.""" + # Re-add the columns as nullable so legacy code can still write to them + # (the NOT NULL constraint on attack_identifier is intentionally relaxed on + # downgrade since we have no way to backfill the original value). + with op.batch_alter_table("PromptMemoryEntries") as batch_op: + batch_op.add_column(sa.Column("targeted_harm_categories", sa.JSON(), nullable=True)) + with op.batch_alter_table("AttackResultEntries") as batch_op: + batch_op.add_column(sa.Column("attack_identifier", sa.JSON(), nullable=True)) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 6723ae2842..4fafdbf169 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -448,43 +448,6 @@ def _get_condition_json_array_match( combined = joiner.join(conditions) return text(f"""ISJSON("{table_name}".{column_name}) = 1 AND ({combined})""").bindparams(**bindparams_dict) - def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any: - """ - Get the SQL Azure implementation for filtering AttackResults by targeted harm categories. - - Uses JSON_QUERY() function specific to SQL Azure to check if categories exist in the JSON array. 
- - Args: - targeted_harm_categories (Sequence[str]): List of harm category strings to filter by. - - Returns: - Any: SQLAlchemy exists subquery condition with bound parameters. - """ - # For SQL Azure, we need to use JSON_QUERY to check if a value exists in a JSON array - # OPENJSON can parse the array and we check if the category exists - # Using parameterized queries for safety - harm_conditions = [] - bindparams_dict = {} - for i, category in enumerate(targeted_harm_categories): - param_name = f"harm_cat_{i}" - # Check if the JSON array contains the category value - harm_conditions.append( - f"EXISTS(SELECT 1 FROM OPENJSON(targeted_harm_categories) WHERE value = :{param_name})" - ) - bindparams_dict[param_name] = category - - combined_conditions = " AND ".join(harm_conditions) - - return exists().where( - and_( - PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, - PromptMemoryEntry.targeted_harm_categories.isnot(None), - PromptMemoryEntry.targeted_harm_categories != "", - PromptMemoryEntry.targeted_harm_categories != "[]", - text(f"ISJSON(targeted_harm_categories) = 1 AND {combined_conditions}").bindparams(**bindparams_dict), - ) - ) - def _get_attack_result_label_condition(self, *, labels: dict[str, str | Sequence[str]]) -> Any: """ Azure SQL implementation for filtering AttackResults by labels. diff --git a/pyrit/memory/memory_exporter.py b/pyrit/memory/memory_exporter.py deleted file mode 100644 index 54e61505b3..0000000000 --- a/pyrit/memory/memory_exporter.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import csv -import json -from pathlib import Path -from typing import Optional - -from pyrit.models import MessagePiece - - -class MemoryExporter: - """ - Handles the export of data to various formats, currently supporting only JSON format. - This class utilizes the strategy design pattern to select the appropriate export format. 
- """ - - def __init__(self) -> None: - """ - Initialize the MemoryExporter. - - Sets up the available export formats using the strategy design pattern. - """ - # Using strategy design pattern for export functionality. - self.export_strategies = { - "json": self.export_to_json, - "csv": self.export_to_csv, - "md": self.export_to_markdown, - # Future formats can be added here - } - - def export_data( - self, data: list[MessagePiece], *, file_path: Optional[Path] = None, export_type: str = "json" - ) -> None: - """ - Export the provided data to a file in the specified format. - - Args: - data (list[MessagePiece]): The data to be exported, as a list of MessagePiece instances. - file_path (str): The full path, including the file name, where the data will be exported. - export_type (str, Optional): The format for exporting data. Defaults to "json". - - Raises: - ValueError: If no file_path is provided or if the specified export format is not supported. - """ - if not file_path: - raise ValueError("Please provide a valid file path for exporting data.") - - export_func = self.export_strategies.get(export_type) - if export_func: - export_func(data, file_path) - else: - raise ValueError(f"Unsupported export format: {export_type}") - - def export_to_json(self, data: list[MessagePiece], file_path: Optional[Path] = None) -> None: - """ - Export the provided data to a JSON file at the specified file path. - Each item in the data list, representing a row from the table, - is converted to a dictionary before being written to the file. - - Args: - data (list[MessagePiece]): The data to be exported, as a list of MessagePiece instances. - file_path (Path): The full path, including the file name, where the data will be exported. - - Raises: - ValueError: If no file_path is provided. 
- """ - if not file_path: - raise ValueError("Please provide a valid file path for exporting data.") - if not data: - raise ValueError("No data to export.") - export_data = [piece.model_dump(mode="json") for piece in data] - with open(file_path, "w") as f: - json.dump(export_data, f, indent=4) - - def export_to_csv(self, data: list[MessagePiece], file_path: Optional[Path] = None) -> None: - """ - Export the provided data to a CSV file at the specified file path. - Each item in the data list, representing a row from the table, - is converted to a dictionary before being written to the file. - - Args: - data (list[MessagePiece]): The data to be exported, as a list of MessagePiece instances. - file_path (Path): The full path, including the file name, where the data will be exported. - - Raises: - ValueError: If no file_path is provided. - """ - if not file_path: - raise ValueError("Please provide a valid file path for exporting data.") - if not data: - raise ValueError("No data to export.") - export_data = [piece.model_dump(mode="json") for piece in data] - fieldnames = list(export_data[0].keys()) - with open(file_path, "w", newline="") as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - writer.writerows(export_data) - - def export_to_markdown(self, data: list[MessagePiece], file_path: Optional[Path] = None) -> None: - """ - Export the provided data to a Markdown file at the specified file path. - Each item in the data list is converted to a dictionary and formatted as a table. - - Args: - data (list[MessagePiece]): The data to be exported, as a list of MessagePiece instances. - file_path (Path): The full path, including the file name, where the data will be exported. - - Raises: - ValueError: If no file_path is provided or if there is no data to export. 
- """ - if not file_path: - raise ValueError("Please provide a valid file path for exporting data.") - if not data: - raise ValueError("No data to export.") - export_data = [piece.model_dump(mode="json") for piece in data] - fieldnames = list(export_data[0].keys()) - with open(file_path, "w", newline="") as f: - f.write(f"| {' | '.join(fieldnames)} |\n") - f.write(f"| {' | '.join(['---'] * len(fieldnames))} |\n") - for row in export_data: - f.write(f"| {' | '.join(str(row[field]) for field in fieldnames)} |\n") diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 26448f5b6c..092d11547a 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -10,7 +10,6 @@ from collections.abc import MutableSequence, Sequence from contextlib import closing from datetime import datetime, timezone -from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union from sqlalchemy import MetaData, and_, not_, or_ @@ -21,9 +20,6 @@ if TYPE_CHECKING: from pyrit.memory.memory_embedding import MemoryEmbedding -from pyrit.common.deprecation import print_deprecation_message -from pyrit.common.path import DB_DATA_PATH -from pyrit.memory.memory_exporter import MemoryExporter from pyrit.memory.memory_models import ( AttackResultEntry, Base, @@ -103,8 +99,6 @@ def __init__(self, embedding_model: Optional[Any] = None) -> None: but also includes overhead. """ self.memory_embedding = embedding_model - # Initialize the MemoryExporter instance - self.exporter = MemoryExporter() self._init_storage_io() # Ensure cleanup at process exit @@ -593,19 +587,6 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict update_fields (dict): A dictionary of field names and their new values. 
""" - @abc.abstractmethod - def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any: - """ - Return a database-specific condition for filtering AttackResults by targeted harm categories - in the associated PromptMemoryEntry records. - - Args: - targeted_harm_categories: List of harm categories that must ALL be present. - - Returns: - Database-specific SQLAlchemy condition. - """ - @abc.abstractmethod def _get_attack_result_label_condition(self, *, labels: dict[str, str | Sequence[str]]) -> Any: """ @@ -826,21 +807,19 @@ def get_prompt_scores( converted_value_sha256=converted_value_sha256, ) - # Deduplicate message pieces by original_prompt_id to avoid duplicate scores - # since duplicated pieces share scores with their originals - seen_original_ids = set() - unique_pieces = [] - for piece in message_pieces: - if piece.original_prompt_id not in seen_original_ids: - seen_original_ids.add(piece.original_prompt_id) - unique_pieces.append(piece) - - scores = [] - for piece in unique_pieces: - if piece.scores: - scores.extend(piece.scores) + # Deduplicate by original_prompt_id since duplicated pieces share scores + # with their originals. 
+ original_ids = {piece.original_prompt_id for piece in message_pieces if piece.original_prompt_id is not None} + if not original_ids: + return [] - return list(scores) + score_entries = self._execute_batched_query( + ScoreEntry, + batch_column=ScoreEntry.prompt_request_response_id, + batch_values=list(original_ids), + other_conditions=[], + ) + return [entry.get_score() for entry in score_entries] def get_conversation(self, *, conversation_id: str) -> MutableSequence[Message]: """ @@ -1561,76 +1540,6 @@ def get_seed_groups( return seed_groups - def export_conversations( - self, - *, - attack_id: Optional[str | uuid.UUID] = None, - conversation_id: Optional[str | uuid.UUID] = None, - prompt_ids: Optional[Sequence[str] | Sequence[uuid.UUID]] = None, - labels: Optional[dict[str, str]] = None, - sent_after: Optional[datetime] = None, - sent_before: Optional[datetime] = None, - original_values: Optional[Sequence[str]] = None, - converted_values: Optional[Sequence[str]] = None, - data_type: Optional[str] = None, - not_data_type: Optional[str] = None, - converted_value_sha256: Optional[Sequence[str]] = None, - file_path: Optional[Path] = None, - export_type: str = "json", - ) -> Path: - """ - Export conversation data with the given inputs to a specified file. - Defaults to all conversations if no filters are provided. - - Args: - attack_id (Optional[str | uuid.UUID], optional): The ID of the attack. Defaults to None. - conversation_id (Optional[str | uuid.UUID], optional): The ID of the conversation. Defaults to None. - prompt_ids (Optional[Sequence[str] | Sequence[uuid.UUID]], optional): A list of prompt IDs. - Defaults to None. - labels (Optional[dict[str, str]], optional): A dictionary of labels. Defaults to None. - sent_after (Optional[datetime], optional): Filter for prompts sent after this datetime. Defaults to None. - sent_before (Optional[datetime], optional): Filter for prompts sent before this datetime. Defaults to None. 
- original_values (Optional[Sequence[str]], optional): A list of original values. Defaults to None. - converted_values (Optional[Sequence[str]], optional): A list of converted values. Defaults to None. - data_type (Optional[str], optional): The data type to filter by. Defaults to None. - not_data_type (Optional[str], optional): The data type to exclude. Defaults to None. - converted_value_sha256 (Optional[Sequence[str]], optional): A list of SHA256 hashes of converted values. - Defaults to None. - file_path (Optional[Path], optional): The path to the file where the data will be exported. - Defaults to None. - export_type (str, optional): The format of the export. Defaults to "json". - - Returns: - Path: The path to the exported file. - """ - print_deprecation_message( - old_item="MemoryInterface.export_conversations", - new_item="the pyrit.output module or direct serialization of get_message_pieces results", - removed_in="0.15.0", - ) - data = self.get_message_pieces( - attack_id=attack_id, - conversation_id=conversation_id, - prompt_ids=prompt_ids, - labels=labels, - sent_after=sent_after, - sent_before=sent_before, - original_values=original_values, - converted_values=converted_values, - data_type=data_type, - not_data_type=not_data_type, - converted_value_sha256=converted_value_sha256, - ) - - # If file_path is not provided, construct a default using the exporter's results_path - if not file_path: - file_name = f"exported_conversations_on_{datetime.now(tz=timezone.utc).strftime('%Y_%m_%d')}.{export_type}" - file_path = DB_DATA_PATH / file_name - - self.exporter.export_data(list(data), file_path=file_path, export_type=export_type) - - return file_path - def add_attack_results_to_memory(self, *, attack_results: Sequence[AttackResult]) -> None: """ Insert a list of attack results into the memory storage. 
@@ -1720,13 +1629,11 @@ def get_attack_results( objective: Optional[str] = None, objective_sha256: Optional[Sequence[str]] = None, outcome: Optional[str] = None, - attack_class: Optional[str] = None, attack_classes: Optional[Sequence[str]] = None, atomic_attack_eval_hashes: Optional[Sequence[str]] = None, converter_classes: Optional[Sequence[str]] = None, converter_classes_match: Literal["all", "any"] = "all", has_converters: Optional[bool] = None, - targeted_harm_categories: Optional[Sequence[str]] = None, labels: Optional[dict[str, str | Sequence[str]]] = None, identifier_filters: Optional[Sequence[IdentifierFilter]] = None, scenario_result_id: Optional[str] = None, @@ -1742,9 +1649,6 @@ def get_attack_results( Defaults to None. outcome (Optional[str], optional): The outcome to filter by (success, failure, undetermined). Defaults to None. - attack_class (Optional[str], optional): Deprecated. Filter by a single exact attack - class_name in attack_identifier. Equivalent to passing ``attack_classes=[attack_class]``. - Cannot be combined with ``attack_classes``. Defaults to None. attack_classes (Optional[Sequence[str]], optional): Filter by exact attack class_name in attack_identifier. Returns attacks matching ANY of the listed class names (OR logic, case-sensitive). An empty sequence applies no filter. Defaults to None. @@ -1766,13 +1670,6 @@ def get_attack_results( has_converters (Optional[bool], optional): Filter by converter presence. ``True`` returns only attacks that used at least one converter. ``False`` returns only attacks that used no converters. ``None`` applies no filter. Defaults to None. - targeted_harm_categories (Optional[Sequence[str]], optional): - A list of targeted harm categories to filter results by. - These targeted harm categories are associated with the prompts themselves, - meaning they are harm(s) we're trying to elicit with the prompt, - not necessarily one(s) that were found in the response. 
- By providing a list, this means ALL categories in the list must be present. - Defaults to None. labels (Optional[dict[str, str | Sequence[str]]], optional): Filter results by attack labels. Entries are AND-combined across label names; within a single entry, a string value is an equality match and a sequence value is @@ -1793,7 +1690,8 @@ def get_attack_results( Sequence[AttackResult]: A list of AttackResult objects that match the specified filters. Raises: - ValueError: If both ``attack_class`` (deprecated) and ``attack_classes`` are provided. + ValueError: If any label key contains characters outside the allowlist + ``[A-Za-z0-9_.-]+``. """ # Handle empty list cases if attack_result_ids is not None and len(attack_result_ids) == 0: @@ -1801,18 +1699,6 @@ def get_attack_results( if objective_sha256 is not None and len(objective_sha256) == 0: return [] - if attack_class is not None and attack_classes is not None: - raise ValueError( - "Pass either `attack_class` (deprecated, singular) or `attack_classes` (plural), not both." 
- ) - if attack_class is not None and attack_classes is None: - print_deprecation_message( - old_item="get_attack_results(attack_class=...)", - new_item="get_attack_results(attack_classes=...)", - removed_in="0.15.0", - ) - attack_classes = [attack_class] - # Build non-list conditions conditions: list[ColumnElement[bool]] = [] if conversation_id: @@ -1888,15 +1774,6 @@ def get_attack_results( ) conditions.append(not_(empty_condition) if has_converters else empty_condition) - if targeted_harm_categories: - print_deprecation_message( - old_item="get_attack_results(targeted_harm_categories=...)", - new_item="get_attack_results(labels={'harm_category': [...]})", - removed_in="0.15.0", - ) - conditions.append( - self._get_attack_result_harm_category_condition(targeted_harm_categories=targeted_harm_categories) - ) if labels: # Strip keys whose value is an empty sequence — an empty sequence means # "no OR-candidates", and per the docstring applies no filter for that diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index a5c277b3f1..64ec93d96a 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -233,7 +233,6 @@ class PromptMemoryEntry(Base): Can be the same number for multi-part requests or multi-part responses. timestamp (DateTime): The timestamp of the memory entry. labels (Dict[str, str]): The labels associated with the memory entry. Several can be standardized. - targeted_harm_categories (List[str]): The targeted harm categories for the memory entry. prompt_metadata (JSON): The metadata associated with the prompt. This can be specific to any scenarios. Because memory is how components talk with each other, this can be component specific. e.g. the URI from a file uploaded to a blob store, or a document type you want to upload. 
@@ -265,7 +264,6 @@ class PromptMemoryEntry(Base): timestamp = mapped_column(UTCDateTime, nullable=False) labels: Mapped[dict[str, str]] = mapped_column(JSON) prompt_metadata: Mapped[dict[str, str | int]] = mapped_column(JSON) - targeted_harm_categories: Mapped[list[str] | None] = mapped_column(JSON) converter_identifiers: Mapped[list[dict[str, str]] | None] = mapped_column(JSON) prompt_target_identifier: Mapped[dict[str, str]] = mapped_column(JSON) attack_identifier: Mapped[dict[str, str]] = mapped_column(JSON) @@ -308,7 +306,6 @@ def __init__(self, *, entry: MessagePiece) -> None: self.timestamp = entry.timestamp self.labels = entry.labels self.prompt_metadata = entry.prompt_metadata - self.targeted_harm_categories = entry.targeted_harm_categories self.converter_identifiers = _dump_identifiers(entry.converter_identifiers) self.prompt_target_identifier = _dump_identifier(entry.prompt_target_identifier) or {} self.attack_identifier = _dump_identifier(entry.attack_identifier) or {} @@ -331,7 +328,7 @@ def get_message_piece(self) -> MessagePiece: Convert this database entry back into a MessagePiece object. Returns: - MessagePiece: The reconstructed message piece with all its data and scores. + MessagePiece: The reconstructed message piece with all its data. """ # Reconstruct ComponentIdentifiers with the stored pyrit_version stored_version = self.pyrit_version or LEGACY_PYRIT_VERSION @@ -358,13 +355,11 @@ def get_message_piece(self) -> MessagePiece: original_prompt_id=self.original_prompt_id, timestamp=self.timestamp, ) - # Assign deprecated containers post-construction so the DB-load path - # does not trip the ``MessagePiece`` deprecation-kwarg validator. - # ``validate_assignment=False`` on the model makes this assignment - # bypass the model_validator entirely. + # Assign deprecated ``labels`` container post-construction so the DB-load + # path does not trip the ``MessagePiece`` deprecation-kwarg validator. 
+ # ``validate_assignment=False`` on the model makes this assignment bypass + # the model_validator entirely. message_piece.labels = self.labels or {} - message_piece.targeted_harm_categories = self.targeted_harm_categories or [] - message_piece.scores = [score.get_score() for score in self.scores] return message_piece def __str__(self) -> str: @@ -732,7 +727,8 @@ class AttackResultEntry(Base): id (Uuid): The unique identifier for the attack result entry. conversation_id (str): The unique identifier of the conversation that produced this result. objective (str): Natural-language description of the attacker's objective. - attack_identifier (dict[str, str]): Identifier of the attack (e.g., name, module). + atomic_attack_identifier (dict[str, Any] | None): Composite identifier of the attack + (technique, seeds, etc.). objective_sha256 (str): The SHA256 hash of the objective. last_response_id (Uuid): Foreign key to the last response MessagePiece. last_score_id (Uuid): Foreign key to the last score ScoreEntry. @@ -757,7 +753,6 @@ class AttackResultEntry(Base): id = mapped_column(CustomUUID, nullable=False, primary_key=True) conversation_id = mapped_column(String, nullable=False) objective = mapped_column(Unicode, nullable=False) - attack_identifier: Mapped[dict[str, str]] = mapped_column(JSON, nullable=False) atomic_attack_identifier: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True) objective_sha256 = mapped_column(String, nullable=True) last_response_id: Mapped[uuid.UUID | None] = mapped_column( @@ -822,10 +817,6 @@ def __init__(self, *, entry: AttackResult) -> None: self.id = uuid.UUID(entry.attack_result_id) self.conversation_id = entry.conversation_id self.objective = entry.objective - # Deprecated column: populated from atomic_attack_identifier for backward compatibility. - # Will be removed in 0.15.0. 
- _attack_strategy_id = entry.get_attack_strategy_identifier() - self.attack_identifier = _dump_identifier(_attack_strategy_id) or {} # Ensure eval_hash is set before truncation so it survives the DB round-trip if entry.atomic_attack_identifier and entry.atomic_attack_identifier.eval_hash is None: entry.atomic_attack_identifier = entry.atomic_attack_identifier.with_eval_hash( @@ -947,15 +938,7 @@ def get_attack_result(self) -> AttackResult: ) ) - # Reconstruct atomic_attack_identifier, with backward compatibility for - # legacy rows that only have the attack_identifier column. atomic_id = _load_identifier(self.atomic_attack_identifier) - if atomic_id is None and self.attack_identifier: - from pyrit.models import build_atomic_attack_identifier - - atomic_id = build_atomic_attack_identifier( - attack_identifier=ComponentIdentifier.model_validate(self.attack_identifier), - ) # Deserialize retry events from JSON retry_events = [] diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 461d2b871b..ecbd761e89 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -3,12 +3,11 @@ import json import logging -import uuid from collections.abc import MutableSequence, Sequence from contextlib import closing, suppress from datetime import datetime from pathlib import Path -from typing import Any, Literal, Optional, TypeVar, Union, cast +from typing import Any, Literal, Optional, TypeVar, Union from sqlalchemy import and_, create_engine, exists, func, or_, text from sqlalchemy.engine.base import Engine @@ -18,7 +17,6 @@ from sqlalchemy.pool import StaticPool from sqlalchemy.sql.expression import TextClause -from pyrit.common.deprecation import print_deprecation_message from pyrit.common.path import DB_DATA_PATH from pyrit.common.singleton import Singleton from pyrit.memory.memory_interface import MemoryInterface @@ -36,14 +34,6 @@ Model = TypeVar("Model") -class _ExportableConversationPiece: - def __init__(self, data: dict[str, 
Any]) -> None: - self._data = data - - def model_dump(self, *, mode: str = "python") -> dict[str, Any]: - return self._data - - class SQLiteMemory(MemoryInterface, metaclass=Singleton): """ A memory interface that uses SQLite as the backend database. @@ -484,100 +474,6 @@ def dispose_engine(self) -> None: finally: logging.raiseExceptions = previous_raise - def export_conversations( - self, - *, - attack_id: Optional[str | uuid.UUID] = None, - conversation_id: Optional[str | uuid.UUID] = None, - prompt_ids: Optional[Sequence[str] | Sequence[uuid.UUID]] = None, - labels: Optional[dict[str, str]] = None, - sent_after: Optional[datetime] = None, - sent_before: Optional[datetime] = None, - original_values: Optional[Sequence[str]] = None, - converted_values: Optional[Sequence[str]] = None, - data_type: Optional[str] = None, - not_data_type: Optional[str] = None, - converted_value_sha256: Optional[Sequence[str]] = None, - file_path: Optional[Path] = None, - export_type: str = "json", - ) -> Path: - """ - Export conversations and their associated scores from the database to a specified file. - - Returns: - Path: The path to the exported file. - - Raises: - ValueError: If the specified export format is not supported. 
- """ - print_deprecation_message( - old_item="SQLiteMemory.export_conversations", - new_item="the pyrit.output module or direct serialization of get_message_pieces results", - removed_in="0.15.0", - ) - # Import here to avoid circular import issues - from pyrit.memory.memory_exporter import MemoryExporter - - if not self.exporter: - self.exporter = MemoryExporter() - - # Get message pieces using the parent class method with appropriate filters - message_pieces = self.get_message_pieces( - attack_id=attack_id, - conversation_id=conversation_id, - prompt_ids=prompt_ids, - labels=labels, - sent_after=sent_after, - sent_before=sent_before, - original_values=original_values, - converted_values=converted_values, - data_type=data_type, - not_data_type=not_data_type, - converted_value_sha256=converted_value_sha256, - ) - - # Create the filename if not provided - if not file_path: - if attack_id: - file_name = f"{attack_id}.{export_type}" - elif conversation_id: - file_name = f"{conversation_id}.{export_type}" - else: - file_name = f"all_conversations.{export_type}" - file_path = Path(DB_DATA_PATH, file_name) - - # Get scores for the message pieces - if message_pieces: - message_piece_ids = [str(piece.id) for piece in message_pieces] - scores = self.get_prompt_scores(prompt_ids=message_piece_ids) - else: - scores = [] - - # Merge conversations and scores - create the data structure manually - merged_data = [] - for piece in message_pieces: - piece_data = piece.model_dump(mode="json") - # Find associated scores - piece_scores = [score for score in scores if score.message_piece_id == piece.id] - piece_data["scores"] = [score.model_dump(mode="json") for score in piece_scores] - merged_data.append(piece_data) - - if not merged_data: - if export_type == "json": - with open(file_path, "w", encoding="utf-8") as f: - json.dump(merged_data, f, indent=4) - elif export_type in self.exporter.export_strategies: - file_path.write_text("", encoding="utf-8") - else: - raise 
ValueError(f"Unsupported export format: {export_type}") - return file_path - - exportable_pieces = [_ExportableConversationPiece(data=piece_data) for piece_data in merged_data] - self.exporter.export_data( - cast("list[MessagePiece]", exportable_pieces), file_path=file_path, export_type=export_type - ) - return file_path - def print_schema(self) -> None: """ Print the schema of all tables in the SQLite database. @@ -592,55 +488,6 @@ def print_schema(self) -> None: default = f" DEFAULT {column.default}" if column.default else "" print(f" {column.name}: {column.type} {nullable}{default}") - def export_all_tables(self, *, export_type: str = "json") -> None: - """ - Export all table data using the specified exporter. - - Iterate over all tables, retrieves their data, and exports each to a file named after the table. - - Args: - export_type (str): The format to export the data in (defaults to "json"). - """ - print_deprecation_message( - old_item="SQLiteMemory.export_all_tables", - new_item="the pyrit.output module or direct serialization of table query results", - removed_in="0.15.0", - ) - table_models = self.get_all_table_models() - - for model in table_models: - data = self._query_entries(model) - table_name = model.__tablename__ - file_extension = f".{export_type}" - file_path = DB_DATA_PATH / f"{table_name}{file_extension}" - # Convert to list for exporter compatibility - self.exporter.export_data(list(data), file_path=file_path, export_type=export_type) - - def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any: - """ - SQLite implementation for filtering AttackResults by targeted harm categories. - Uses json_extract() function specific to SQLite. - - Returns: - Any: A SQLAlchemy subquery for filtering by targeted harm categories. 
- """ - targeted_harm_categories_subquery = exists().where( - and_( - PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, - # Exclude empty strings, None, and empty lists - PromptMemoryEntry.targeted_harm_categories.isnot(None), - PromptMemoryEntry.targeted_harm_categories != "", - PromptMemoryEntry.targeted_harm_categories != "[]", - and_( - *[ - func.json_extract(PromptMemoryEntry.targeted_harm_categories, "$").like(f'%"{category}"%') - for category in targeted_harm_categories - ] - ), - ) - ) - return targeted_harm_categories_subquery # noqa: RET504 - def _get_attack_result_label_condition(self, *, labels: dict[str, str | Sequence[str]]) -> Any: """ SQLite implementation for filtering AttackResults by labels. diff --git a/pyrit/models/chat_message.py b/pyrit/models/chat_message.py index c2f801862d..b32f1cb06c 100644 --- a/pyrit/models/chat_message.py +++ b/pyrit/models/chat_message.py @@ -45,44 +45,6 @@ def to_dict(self) -> dict[str, Any]: """ return self.model_dump(exclude_none=True) - def to_json(self) -> str: - """ - Serialize the ChatMessage to a JSON string (deprecated, use ``model_dump_json`` instead). - - Returns: - A JSON string representation of the message. - - """ - from pyrit.common.deprecation import print_deprecation_message - - print_deprecation_message( - old_item="ChatMessage.to_json", - new_item="ChatMessage.model_dump_json", - removed_in="0.15.0", - ) - return self.model_dump_json() - - @classmethod - def from_json(cls, json_str: str) -> "ChatMessage": - """ - Deserialize a ChatMessage from a JSON string (deprecated, use ``model_validate_json`` instead). - - Args: - json_str: A JSON string representation of a ChatMessage. - - Returns: - A ChatMessage instance. 
- - """ - from pyrit.common.deprecation import print_deprecation_message - - print_deprecation_message( - old_item="ChatMessage.from_json", - new_item="ChatMessage.model_validate_json", - removed_in="0.15.0", - ) - return cls.model_validate_json(json_str) - class ChatMessagesDataset(BaseModel): """ diff --git a/pyrit/models/embeddings.py b/pyrit/models/embeddings.py index e51ae48f8e..cff37122b4 100644 --- a/pyrit/models/embeddings.py +++ b/pyrit/models/embeddings.py @@ -68,23 +68,6 @@ def load_from_file(file_path: Path) -> EmbeddingResponse: embedding_json_data = file_path.read_text(encoding="utf-8") return EmbeddingResponse.model_validate_json(embedding_json_data) - def to_json(self) -> str: - """ - Serialize this embedding response to JSON (deprecated, use ``model_dump_json`` instead). - - Returns: - str: JSON-encoded embedding response. - - """ - from pyrit.common.deprecation import print_deprecation_message - - print_deprecation_message( - old_item="EmbeddingResponse.to_json", - new_item="EmbeddingResponse.model_dump_json", - removed_in="0.15.0", - ) - return self.model_dump_json() - class EmbeddingSupport(ABC): """Protocol-like interface for classes that generate text embeddings.""" diff --git a/pyrit/models/messages/message_piece.py b/pyrit/models/messages/message_piece.py index a8e6cb7eca..7eeccc708c 100644 --- a/pyrit/models/messages/message_piece.py +++ b/pyrit/models/messages/message_piece.py @@ -5,7 +5,7 @@ import uuid from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Optional from uuid import uuid4 from pydantic import ( @@ -25,7 +25,6 @@ ) from pyrit.models.score import ( # noqa: TC001 (runtime-required by Pydantic field annotations) ComponentIdentifierField, - Score, ) if TYPE_CHECKING: @@ -40,44 +39,12 @@ # These can be deleted entirely once their ``removed_in`` releases ship — the # Pydantic field definitions and ``extra="forbid"`` config will then reject # the 
kwargs naturally. -_DEPRECATED_KWARGS: tuple[tuple[str, str], ...] = ( - ("labels", "0.16.0"), - ("scorer_identifier", "0.15.0"), - ("scores", "0.15.0"), - ("targeted_harm_categories", "0.15.0"), -) - - -# ``ComponentIdentifierField`` (and ``Score``) are imported from ``pyrit.models.score`` -# above. Both round-trip through the flat dict storage shape via their own Pydantic -# serializers, so no local annotated aliases are needed here. - - -def __getattr__(name: str) -> Any: - """ - Lazily resolve deprecated module-level aliases. +_DEPRECATED_KWARGS: tuple[tuple[str, str], ...] = (("labels", "0.16.0"),) - Args: - name: The attribute name being accessed. - - Returns: - The resolved alias (currently only ``Originator``). - Raises: - AttributeError: If ``name`` is not a known deprecated alias. - """ - if name == "Originator": - print_deprecation_message( - old_item="pyrit.models.message_piece.Originator", - new_item=( - "inline Literal['attack', 'converter', 'undefined', 'scorer'] " - "(the type alias is being removed; the originator field itself is " - "deprecated and will be removed in 0.15.0)" - ), - removed_in="0.15.0", - ) - return Literal["attack", "converter", "undefined", "scorer"] - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") +# ``ComponentIdentifierField`` is imported from ``pyrit.models.score`` above. +# It round-trips through the flat dict storage shape via its own Pydantic +# serializer, so no local annotated alias is needed here. 
class MessagePiece(BaseModel): @@ -107,16 +74,12 @@ class MessagePiece(BaseModel): converted_value_data_type: PromptDataType = "text" converted_value_sha256: Optional[str] = None response_error: PromptResponseError = "none" - originator: Literal["attack", "converter", "undefined", "scorer"] = "undefined" original_prompt_id: Optional[uuid.UUID] = None labels: dict[str, Any] = Field(default_factory=dict) - targeted_harm_categories: list[str] = Field(default_factory=list) prompt_metadata: dict[str, Any] = Field(default_factory=dict) converter_identifiers: list[ComponentIdentifierField] = Field(default_factory=list) prompt_target_identifier: Optional[ComponentIdentifierField] = None attack_identifier: Optional[ComponentIdentifierField] = None - scorer_identifier: Optional[ComponentIdentifierField] = None - scores: list[Score] = Field(default_factory=list) # When True, the memory layer skips persisting this piece. Used for ephemeral # pieces a scorer creates to score arbitrary content; ``exclude=True`` keeps @@ -145,14 +108,6 @@ def _warn_on_deprecated_kwargs(cls, data: Any) -> Any: new_item="MessagePiece(...)", removed_in=removed_in, ) - # ``originator`` is special: only warn when the caller explicitly - # opts into a non-default value. 
- if data.get("originator", "undefined") != "undefined": - print_deprecation_message( - old_item="MessagePiece(..., originator=...)", - new_item="MessagePiece(...)", - removed_in="0.15.0", - ) return data @model_validator(mode="before") diff --git a/pyrit/models/results/attack_result.py b/pyrit/models/results/attack_result.py index 8d2043f16d..648c837214 100644 --- a/pyrit/models/results/attack_result.py +++ b/pyrit/models/results/attack_result.py @@ -8,11 +8,10 @@ from enum import Enum from typing import Any, TypeVar -from pydantic import AwareDatetime, Field, model_validator +from pydantic import AwareDatetime, Field from pyrit.common.deprecation import print_deprecation_message from pyrit.models.conversation_reference import ConversationReference, ConversationType -from pyrit.models.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.models.identifiers.component_identifier import ComponentIdentifier from pyrit.models.messages.message_piece import MessagePiece from pyrit.models.results.strategy_result import StrategyResult @@ -112,54 +111,6 @@ class AttackResult(StrategyResult): attribution_parent_id: str | None = None attribution_data: dict[str, Any] | None = None - @model_validator(mode="before") - @classmethod - def _promote_deprecated_attack_identifier(cls, data: Any) -> Any: - """ - Promote the deprecated ``attack_identifier`` kwarg to ``atomic_attack_identifier``. - - Runs ahead of ``extra="forbid"`` so the legacy kwarg is consumed before - Pydantic would reject it. Emits a deprecation warning when present. - - Returns: - The input ``data`` with ``attack_identifier`` removed and (when it was - set and ``atomic_attack_identifier`` was not) promoted. 
- """ - if not isinstance(data, dict): - return data - data = dict(data) - attack_identifier = data.pop("attack_identifier", None) - if attack_identifier is not None: - print_deprecation_message( - old_item="AttackResult(attack_identifier=...)", - new_item="AttackResult(atomic_attack_identifier=...)", - removed_in="0.15.0", - ) - if data.get("atomic_attack_identifier") is None: - data["atomic_attack_identifier"] = build_atomic_attack_identifier( - attack_identifier=attack_identifier, - ) - return data - - @property - def attack_identifier(self) -> ComponentIdentifier | None: - """ - Deprecated: use ``get_attack_strategy_identifier()`` or ``atomic_attack_identifier`` instead. - - Returns the attack strategy ``ComponentIdentifier`` extracted from - ``atomic_attack_identifier``, emitting a deprecation warning. - - Returns: - ComponentIdentifier | None: The attack strategy identifier, or ``None``. - - """ - print_deprecation_message( - old_item="AttackResult.attack_identifier", - new_item="AttackResult.atomic_attack_identifier or get_attack_strategy_identifier()", - removed_in="0.15.0", - ) - return self.get_attack_strategy_identifier() - def get_attack_strategy_identifier(self) -> ComponentIdentifier | None: """ Return the attack strategy identifier from the composite atomic identifier. 
diff --git a/pyrit/prompt_converter/add_image_text_converter.py b/pyrit/prompt_converter/add_image_text_converter.py index d4ee0809c8..14439d73cd 100644 --- a/pyrit/prompt_converter/add_image_text_converter.py +++ b/pyrit/prompt_converter/add_image_text_converter.py @@ -9,15 +9,12 @@ from PIL import Image, ImageFont from PIL.ImageFont import FreeTypeFont -from pyrit.common.deprecation import print_deprecation_message from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory from pyrit.prompt_converter.base_image_text_converter import _BaseImageTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult logger = logging.getLogger(__name__) -_UNSET = object() - class AddImageTextConverter(_BaseImageTextConverter): """ @@ -38,13 +35,11 @@ class AddImageTextConverter(_BaseImageTextConverter): def __init__( self, - *args: str, - img_to_add: str = "", + *, + img_to_add: str, font_name: str | None = None, color: tuple[int, int, int] = (0, 0, 0), font_size: int | tuple[int, int] = 15, - x_pos: int = _UNSET, # type: ignore[ty:invalid-parameter-default] - y_pos: int = _UNSET, # type: ignore[ty:invalid-parameter-default] bounding_box: tuple[int, int, int, int] | None = None, rotation: float = 0.0, center_text: bool = False, @@ -53,16 +48,12 @@ def __init__( Initialize the converter with the image file path and text properties. Args: - *args: Deprecated positional argument for img_to_add. Use img_to_add=... instead. - Will be removed in version 0.15.0. img_to_add (str): File path of image to add text to. font_name (str | None): Path of font to use. Must be a TrueType font (.ttf). Defaults to None which uses Pillow's built-in default font. color (tuple[int, int, int]): Color to print text in, using RGB values. Defaults to (0, 0, 0). font_size (int | tuple[int, int]): Font size as a fixed int, or a (min, max) tuple for automatic sizing that shrinks from max down to min to fit text in the bounding box. Defaults to 15. 
- x_pos (int): Deprecated. Use bounding_box instead. Will be removed in version 0.15.0. - y_pos (int): Deprecated. Use bounding_box instead. Will be removed in version 0.15.0. bounding_box (tuple[int, int, int, int] | None): Optional (x1, y1, x2, y2) region to constrain text within. When not set, the full image is used with a default margin. Defaults to None. @@ -71,38 +62,9 @@ def __init__( Defaults to False. Raises: - TypeError: If more than one positional argument is passed, or if img_to_add - is passed as both positional and keyword argument. ValueError: If img_to_add is empty, font_name doesn't end with ".ttf", - font_size tuple is invalid, bounding_box coordinates are invalid, - or x_pos/y_pos are used together with bounding_box. + font_size tuple is invalid, or bounding_box coordinates are invalid. """ - if args: - if len(args) > 1: - raise TypeError(f"AddImageTextConverter takes at most 1 positional argument, got {len(args)}") - if img_to_add: - raise TypeError("Cannot pass img_to_add as both positional and keyword argument") - print_deprecation_message( - old_item="Passing img_to_add as a positional argument to AddImageTextConverter", - new_item="AddImageTextConverter(img_to_add=...) keyword argument", - removed_in="0.15.0", - ) - img_to_add = args[0] - if x_pos is not _UNSET or y_pos is not _UNSET: - if bounding_box is not None: - raise ValueError( - "Cannot pass x_pos/y_pos together with bounding_box. Use bounding_box=(x, y, x2, y2) instead." 
- ) - print_deprecation_message( - old_item="AddImageTextConverter(x_pos=..., y_pos=...)", - new_item="AddImageTextConverter(bounding_box=(x1, y1, x2, y2))", - removed_in="0.15.0", - ) - # Resolve defaults after deprecation check - if x_pos is _UNSET: - x_pos = 10 - if y_pos is _UNSET: - y_pos = 10 if not img_to_add: raise ValueError("Please provide valid image path") if font_name is not None and not font_name.endswith(".ttf"): @@ -118,8 +80,6 @@ def __init__( self._font_load_failed = font_name is None self._font = self._load_font() self._color = color - self._x_pos = x_pos - self._y_pos = y_pos self._bounding_box = bounding_box self._rotation = rotation self._center_text = center_text @@ -251,7 +211,7 @@ def _add_text_to_image(self, text: str) -> Image.Image: else: # Default to full image with margin to preserve backward-compatible behavior margin = self._DEFAULT_MARGIN - bounding_box = (self._x_pos, self._y_pos, image.width - margin, image.height - margin) + bounding_box = (10, 10, image.width - margin, image.height - margin) if self._auto_font_size: x1, y1, x2, y2 = bounding_box diff --git a/pyrit/prompt_converter/add_text_image_converter.py b/pyrit/prompt_converter/add_text_image_converter.py index 57964c8e75..7bdf07a4ea 100644 --- a/pyrit/prompt_converter/add_text_image_converter.py +++ b/pyrit/prompt_converter/add_text_image_converter.py @@ -10,7 +10,6 @@ from PIL import Image, ImageFont from PIL.ImageFont import FreeTypeFont -from pyrit.common.deprecation import print_deprecation_message from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory from pyrit.prompt_converter.base_image_text_converter import _BaseImageTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult @@ -31,8 +30,8 @@ class AddTextImageConverter(_BaseImageTextConverter): def __init__( self, - *args: str, - text_to_add: str = "", + *, + text_to_add: str, font_name: str = "helvetica.ttf", color: tuple[int, int, int] = (0, 0, 0), 
font_size: int = 15, @@ -43,8 +42,6 @@ def __init__( Initialize the converter with the text and text properties. Args: - *args: Deprecated positional argument for text_to_add. Use text_to_add=... instead. - Will be removed in version 0.15.0. text_to_add (str): Text to add to an image. font_name (str): Path of font to use. Must be a TrueType font (.ttf). Defaults to "helvetica.ttf". color (tuple): Color to print text in, using RGB values. Defaults to (0, 0, 0). @@ -53,21 +50,8 @@ def __init__( y_pos (int): Y coordinate to place text in (0 is upper most). Defaults to 10. Raises: - TypeError: If more than one positional argument is passed, or if text_to_add - is passed as both positional and keyword argument. ValueError: If ``text_to_add`` is empty, or if ``font_name`` does not end with ".ttf". """ - if args: - if len(args) > 1: - raise TypeError(f"AddTextImageConverter takes at most 1 positional argument, got {len(args)}") - if text_to_add: - raise TypeError("Cannot pass text_to_add as both positional and keyword argument") - print_deprecation_message( - old_item="Passing text_to_add as a positional argument to AddTextImageConverter", - new_item="AddTextImageConverter(text_to_add=...) 
keyword argument", - removed_in="0.15.0", - ) - text_to_add = args[0] if text_to_add.strip() == "": raise ValueError("Please provide valid text_to_add value") if not font_name.endswith(".ttf"): diff --git a/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py b/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py index ae4190d59a..4dcf5c889d 100644 --- a/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py +++ b/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py @@ -10,9 +10,8 @@ if TYPE_CHECKING: import azure.cognitiveservices.speech as speechsdk -from pyrit.auth.azure_auth import get_speech_config, get_speech_config_async +from pyrit.auth.azure_auth import get_speech_config_async from pyrit.common import default_values -from pyrit.common.deprecation import print_deprecation_message from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter @@ -50,7 +49,6 @@ def __init__( azure_speech_region: Optional[str] = None, azure_speech_key: Optional[str | Callable[[], str | Awaitable[str]]] = None, azure_speech_resource_id: Optional[str] = None, - use_entra_auth: Optional[bool] = None, recognition_language: str = "en-US", ) -> None: """ @@ -66,16 +64,6 @@ def __init__( If omitted, Entra ID auth via ``DefaultAzureCredential`` is used automatically. azure_speech_resource_id (str, Optional): The resource ID for accessing the service when using Entra ID auth. Required when using a callable token provider or when no API key is available. - use_entra_auth (bool, Optional): **Deprecated.** Will be removed in 0.15.0. - Authentication is now selected automatically based on what you pass to - ``azure_speech_key`` (and ``AZURE_SPEECH_KEY`` env var): - - - Pass a **string** API key (or set ``AZURE_SPEECH_KEY``) to use API-key auth. 
- - Pass a **callable token provider** (sync or async returning a token string) - to use Entra ID with a custom token; ``azure_speech_resource_id`` must also - be set. - - Omit ``azure_speech_key`` entirely to use Entra ID via - ``DefaultAzureCredential``; ``azure_speech_resource_id`` must be set. recognition_language (str): Recognition voice language. Defaults to "en-US". For more on supported languages, see the following link: https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support @@ -83,16 +71,6 @@ def __init__( Raises: ValueError: If the required environment variables or parameters are not set. """ - if use_entra_auth is not None: - print_deprecation_message( - old_item="AzureSpeechAudioToTextConverter(use_entra_auth=...)", - new_item=( - "AzureSpeechAudioToTextConverter(" - "azure_speech_key=)" - ), - removed_in="0.15.0", - ) - self._azure_speech_region: str = default_values.get_required_value( env_var_name=self.AZURE_SPEECH_REGION_ENVIRONMENT_VARIABLE, passed_value=azure_speech_region, @@ -182,36 +160,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "audi raise return ConverterResult(output_text=transcript, output_type="text") - def recognize_audio(self, audio_bytes: bytes) -> str: - """ - Recognize audio file and return transcribed text. - - .. deprecated:: - Use ``convert_async`` instead, which resolves token providers correctly. - This method does not support callable token providers. - - Args: - audio_bytes (bytes): Audio bytes input. - - Returns: - str: Transcribed text. - - Raises: - ModuleNotFoundError: If the azure.cognitiveservices.speech module is not installed. 
- """ - if self._token_provider: - print_deprecation_message( - old_item="AzureSpeechAudioToTextConverter.recognize_audio", - new_item="AzureSpeechAudioToTextConverter.convert_async", - removed_in="0.15.0", - ) - speech_config = get_speech_config( - resource_id=self._azure_speech_resource_id, - key=self._azure_speech_key, - region=self._azure_speech_region, - ) - return self._recognize_audio(audio_bytes=audio_bytes, speech_config=speech_config) - def _recognize_audio(self, *, audio_bytes: bytes, speech_config: "speechsdk.SpeechConfig") -> str: """ Recognize audio from bytes using the given speech config. diff --git a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py index 03463a3b62..689a9dbb0d 100644 --- a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py +++ b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py @@ -10,7 +10,6 @@ from pyrit.auth.azure_auth import get_speech_config_async from pyrit.common import default_values -from pyrit.common.deprecation import print_deprecation_message from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter @@ -51,7 +50,6 @@ def __init__( azure_speech_region: Optional[str] = None, azure_speech_key: Optional[str | Callable[[], str | Awaitable[str]]] = None, azure_speech_resource_id: Optional[str] = None, - use_entra_auth: Optional[bool] = None, synthesis_language: str = "en_US", synthesis_voice_name: str = "en-US-AvaNeural", output_format: AzureSpeechAudioFormat = "wav", @@ -69,16 +67,6 @@ def __init__( If omitted, Entra ID auth via ``DefaultAzureCredential`` is used automatically. azure_speech_resource_id (str, Optional): The resource ID for accessing the service when using Entra ID auth. Required when using a callable token provider or when no API key is available. 
- use_entra_auth (bool, Optional): **Deprecated.** Will be removed in 0.15.0. - Authentication is now selected automatically based on what you pass to - ``azure_speech_key`` (and ``AZURE_SPEECH_KEY`` env var): - - - Pass a **string** API key (or set ``AZURE_SPEECH_KEY``) to use API-key auth. - - Pass a **callable token provider** (sync or async returning a token string) - to use Entra ID with a custom token; ``azure_speech_resource_id`` must also - be set. - - Omit ``azure_speech_key`` entirely to use Entra ID via - ``DefaultAzureCredential``; ``azure_speech_resource_id`` must be set. synthesis_language (str): Synthesis voice language. synthesis_voice_name (str): Synthesis voice name. For more details see the following link for synthesis language and synthesis voice: @@ -88,16 +76,6 @@ def __init__( Raises: ValueError: If the required environment variables or parameters are not set. """ - if use_entra_auth is not None: - print_deprecation_message( - old_item="AzureSpeechTextToAudioConverter(use_entra_auth=...)", - new_item=( - "AzureSpeechTextToAudioConverter(" - "azure_speech_key=)" - ), - removed_in="0.15.0", - ) - self._azure_speech_region: str = default_values.get_required_value( env_var_name=self.AZURE_SPEECH_REGION_ENVIRONMENT_VARIABLE, passed_value=azure_speech_region, diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py index b0d42c9a76..366b3f2f7f 100644 --- a/pyrit/prompt_target/__init__.py +++ b/pyrit/prompt_target/__init__.py @@ -36,7 +36,6 @@ get_http_target_regex_matching_callback_function, ) from pyrit.prompt_target.http_target.httpx_api_target import HTTPXAPITarget -from pyrit.prompt_target.hugging_face.hugging_face_endpoint_target import HuggingFaceEndpointTarget from pyrit.prompt_target.openai.openai_chat_audio_config import OpenAIChatAudioConfig from pyrit.prompt_target.openai.openai_chat_target import OpenAIChatTarget from pyrit.prompt_target.openai.openai_completion_target import OpenAICompletionTarget @@ -87,7 +86,6 
@@ def __getattr__(name: str) -> object: "HTTPTarget", "HTTPXAPITarget", "HuggingFaceChatTarget", - "HuggingFaceEndpointTarget", "limit_requests_per_minute", "OpenAICompletionTarget", "OpenAIChatAudioConfig", diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index 80ca154b83..0bb58fa2d1 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -9,26 +9,20 @@ from pyrit.auth import ensure_async_token_provider from pyrit.common import default_values, net_utility -from pyrit.common.deprecation import print_deprecation_message from pyrit.exceptions import ( EmptyResponseException, RateLimitException, handle_bad_request_exception, pyrit_target_retry, ) -from pyrit.message_normalizer import ChatMessageNormalizer, MessageListNormalizer +from pyrit.message_normalizer import ChatMessageNormalizer from pyrit.models import ( ComponentIdentifier, Message, construct_response_from_request, ) from pyrit.prompt_target.common.prompt_target import PromptTarget -from pyrit.prompt_target.common.target_capabilities import ( - CapabilityHandlingPolicy, - CapabilityName, - TargetCapabilities, - UnsupportedCapabilityBehavior, -) +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p @@ -65,7 +59,6 @@ def __init__( endpoint: str | None = None, api_key: str | Callable[[], str | Awaitable[str]] | None = None, model_name: str = "", - message_normalizer: MessageListNormalizer[Any] | None = None, max_new_tokens: int = 400, temperature: float = 1.0, top_p: float = 1.0, @@ -88,10 +81,6 @@ def __init__( Defaults to the value of the ``AZURE_ML_KEY`` environment variable. model_name (str): The name of the model being used (e.g., "Llama-3.2-3B-Instruct"). Used for identification purposes. 
Defaults to empty string. - message_normalizer (MessageListNormalizer[Any] | None): **Deprecated.** Use - ``custom_configuration`` with ``CapabilityHandlingPolicy`` instead. Previously used for - models that do not allow system prompts. - Will be removed in 0.15.0. max_new_tokens (int): The maximum number of tokens to generate in the response. Defaults to 400. temperature (float): The temperature for generating diverse responses. 1.0 is most random, @@ -111,46 +100,11 @@ def __init__( Note that the link above may not be comprehensive, and specific acceptable parameters may be model-dependent. If a model does not accept a certain parameter that is passed in, it will be skipped without throwing an error. - - Raises: - ValueError: If both `message_normalizer` and `custom_configuration` are provided, - since `message_normalizer` is deprecated and the two configurations may conflict. """ endpoint_value = default_values.get_required_value( env_var_name=self.endpoint_uri_environment_variable, passed_value=endpoint ) - # Translate legacy message_normalizer into TargetConfiguration - if message_normalizer is not None: - if custom_configuration is not None: - raise ValueError( - "Cannot specify both 'message_normalizer' and 'custom_configuration'. " - "Use 'custom_configuration' only; 'message_normalizer' is deprecated and " - "will be removed in 0.15.0." - ) - print_deprecation_message( - old_item="AzureMLChatTarget(message_normalizer=...)", - new_item="AzureMLChatTarget(custom_configuration=...)", - removed_in="0.15.0", - ) - # The legacy message_normalizer was primarily used to handle system prompts - # for models that don't support them (e.g. GenericSystemSquashNormalizer). - # We translate it into a TargetConfiguration that marks system_prompt as - # unsupported + ADAPT so the pipeline invokes the user's normalizer. 
- default_caps = self._DEFAULT_CONFIGURATION.capabilities - default_behaviors = dict(self._DEFAULT_CONFIGURATION.policy.behaviors) - default_behaviors[CapabilityName.SYSTEM_PROMPT] = UnsupportedCapabilityBehavior.ADAPT - custom_configuration = TargetConfiguration( - capabilities=TargetCapabilities( - supports_multi_message_pieces=default_caps.supports_multi_message_pieces, - supports_editable_history=default_caps.supports_editable_history, - supports_multi_turn=default_caps.supports_multi_turn, - supports_system_prompt=False, - ), - policy=CapabilityHandlingPolicy(behaviors=default_behaviors), - normalizer_overrides={CapabilityName.SYSTEM_PROMPT: message_normalizer}, - ) - PromptTarget.__init__( self, max_requests_per_minute=max_requests_per_minute, diff --git a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py deleted file mode 100644 index eca7899ef7..0000000000 --- a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import logging -import warnings - -from pyrit.common.deprecation import print_deprecation_message -from pyrit.common.net_utility import make_request_and_raise_if_error_async -from pyrit.models import ComponentIdentifier, Message, construct_response_from_request -from pyrit.prompt_target.common.prompt_target import PromptTarget -from pyrit.prompt_target.common.target_configuration import TargetConfiguration -from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p - -logger = logging.getLogger(__name__) - - -class HuggingFaceEndpointTarget(PromptTarget): - """ - The HuggingFaceEndpointTarget interacts with HuggingFace models hosted on cloud endpoints. - - .. 
deprecated:: 0.13.0 - Use ``OpenAIChatTarget`` with ``endpoint="https://router.huggingface.co/v1"`` - and ``api_key=HUGGINGFACE_TOKEN`` instead. The HuggingFace Inference Providers API - is OpenAI-compatible, making this target redundant. Will be removed in v0.15.0. - """ - - def __init__( - self, - *, - hf_token: str, - endpoint: str, - model_id: str, - max_tokens: int = 400, - temperature: float = 1.0, - top_p: float = 1.0, - top_k: int | None = None, - do_sample: bool | None = None, - repetition_penalty: float | None = None, - max_requests_per_minute: int | None = None, - verbose: bool = False, - custom_configuration: TargetConfiguration | None = None, - ) -> None: - """ - Initialize the HuggingFaceEndpointTarget with API credentials and model parameters. - - Args: - hf_token (str): The Hugging Face token for authenticating with the Hugging Face endpoint. - endpoint (str): The endpoint URL for the Hugging Face model. - model_id (str): The model ID to be used at the endpoint. - max_tokens (int): The maximum number of tokens to generate. Defaults to 400. - temperature (float): The sampling temperature to use. Defaults to 1.0. - top_p (float): The cumulative probability for nucleus sampling. Defaults to 1.0. - top_k (int | None): Top-K sampling parameter. Only used when do_sample is True. - Defaults to None (uses model default). - do_sample (bool | None): Whether to use sampling instead of greedy decoding. - Defaults to None. - repetition_penalty (float | None): Penalty for repeating tokens. Values > 1.0 - discourage repetition. Defaults to None (uses model default). - max_requests_per_minute (int | None): The maximum number of requests per minute. Defaults to None. - verbose (bool): Flag to enable verbose logging. Defaults to False. - custom_configuration (TargetConfiguration | None): Custom configuration for this target instance. 
- """ - print_deprecation_message( - old_item=HuggingFaceEndpointTarget, - new_item="OpenAIChatTarget with endpoint='https://router.huggingface.co/v1'", - removed_in="0.15.0", - ) - - super().__init__( - max_requests_per_minute=max_requests_per_minute, - verbose=verbose, - endpoint=endpoint, - model_name=model_id, - custom_configuration=custom_configuration, - ) - - validate_temperature(temperature) - validate_top_p(top_p) - - self.hf_token = hf_token - self.endpoint = endpoint - self.model_id = model_id - self.max_tokens = max_tokens - self._temperature = temperature - self._top_p = top_p - self._top_k = top_k - self._do_sample = do_sample - self._repetition_penalty = repetition_penalty - - self._warn_if_sampling_params_without_do_sample() - - def _build_identifier(self) -> ComponentIdentifier: - """ - Build the identifier with HuggingFace endpoint-specific parameters. - - Returns: - ComponentIdentifier: The identifier for this target instance. - """ - return self._create_identifier( - params={ - "temperature": self._temperature, - "top_p": self._top_p, - "top_k": self._top_k, - "do_sample": self._do_sample, - "repetition_penalty": self._repetition_penalty, - "max_tokens": self.max_tokens, - }, - ) - - @limit_requests_per_minute - async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: - """ - Send a normalized prompt asynchronously to a cloud-based HuggingFace model endpoint. - - Args: - normalized_conversation (list[Message]): The full conversation - (history + current message) after running the normalization - pipeline. The current message is the last element. - - Returns: - list[Message]: A list containing the response object with generated text pieces. - - Raises: - ValueError: If the response from the Hugging Face API is not successful. - Exception: If an error occurs during the HTTP request to the Hugging Face endpoint. 
- """ - message = normalized_conversation[-1] - request = message.message_pieces[0] - headers = {"Authorization": f"Bearer {self.hf_token}"} - parameters: dict[str, object] = { - "max_tokens": self.max_tokens, - "temperature": self._temperature, - "top_p": self._top_p, - } - if self._top_k is not None: - parameters["top_k"] = self._top_k - if self._do_sample is not None: - parameters["do_sample"] = self._do_sample - if self._repetition_penalty is not None: - parameters["repetition_penalty"] = self._repetition_penalty - payload: dict[str, object] = { - "inputs": request.converted_value, - "parameters": parameters, - } - - logger.info(f"Sending the following prompt to the cloud endpoint: {request.converted_value}") - - try: - # Use the utility method to make the request - response = await make_request_and_raise_if_error_async( - endpoint_uri=self.endpoint, - method="POST", - request_body=payload, - headers=headers, - post_type="json", - ) - - response_data = response.json() - - # Check if the response is a list and handle appropriately - if isinstance(response_data, list): - # Access the first element if it's a list and extract 'generated_text' safely - response_message = response_data[0].get("generated_text", "") - else: - response_message = response_data.get("generated_text", "") - - message = construct_response_from_request( - request=request, - response_text_pieces=[response_message], - prompt_metadata={"model_id": self.model_id}, - ) - return [message] - - except Exception as e: - logger.error(f"Error occurred during HTTP request to the Hugging Face endpoint: {e}") - raise - - def _validate_request(self, *, normalized_conversation: list[Message]) -> None: - """ - Validate the provided message. - - Args: - normalized_conversation: The normalized conversation to validate. - - Raises: - ValueError: If the request is not valid for this target. 
- """ - message = normalized_conversation[-1] - n_pieces = len(message.message_pieces) - if n_pieces != 1: - raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.") - - def _warn_if_sampling_params_without_do_sample(self) -> None: - """ - Emit a warning when sampling parameters are set but do_sample is not explicitly True. - - Sampling-specific parameters (temperature != 1.0, top_p != 1.0, top_k) are - ignored by HuggingFace unless do_sample=True. - """ - has_sampling_override = self._temperature != 1.0 or self._top_p != 1.0 or self._top_k is not None - if has_sampling_override and self._do_sample is not True: - warnings.warn( - "Sampling parameters (temperature, top_p, top_k) are set but do_sample is not True. " - "HuggingFace ignores these parameters during greedy decoding. " - "Set do_sample=True to enable sampling.", - UserWarning, - stacklevel=3, - ) diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index bc56207c7a..eefca27fe8 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -4,9 +4,6 @@ import logging from typing import Any, Literal, Optional -import httpx - -from pyrit.common.deprecation import print_deprecation_message from pyrit.exceptions import ( EmptyResponseException, pyrit_target_retry, @@ -48,11 +45,6 @@ class OpenAIImageTarget(OpenAITarget): ) ) - # DALL-E-only image sizes that are deprecated in favor of GPT image model sizes. - _DEPRECATED_SIZES = {"256x256", "512x512", "1792x1024", "1024x1792"} - # DALL-E-only quality values that are deprecated in favor of GPT image model values. - _DEPRECATED_QUALITY_VALUES = {"standard", "hd"} - # Grandfathered: positional params predate the kwargs-only contract; the # sandwiched ``*args``/``**kwargs`` shape forwards extras to ``OpenAITarget``. 
# TODO: remove this opt-out and move ``*args`` up to immediately after @@ -67,14 +59,9 @@ def __init__( "1024x1024", "1536x1024", "1024x1536", - "256x256", - "512x512", - "1792x1024", - "1024x1792", ] = "1024x1024", output_format: Optional[Literal["png", "jpeg", "webp"]] = None, - quality: Optional[Literal["auto", "low", "medium", "high", "standard", "hd"]] = None, - style: Optional[Literal["natural", "vivid"]] = None, + quality: Optional[Literal["auto", "low", "medium", "high"]] = None, background: Optional[Literal["transparent", "opaque", "auto"]] = None, custom_configuration: Optional[TargetConfiguration] = None, *args: Any, @@ -98,21 +85,12 @@ def __init__( image_size (Literal, Optional): The size of the generated image. GPT image models support "auto", "1024x1024", "1536x1024", and "1024x1536". Defaults to "1024x1024". - - **Deprecated sizes (will be removed in v0.15.0):** - "256x256", "512x512" (DALL-E-2 only), "1792x1024", "1024x1792" (DALL-E-3 only). output_format (Literal["png", "jpeg", "webp"], Optional): The output format of the generated images. Default is to not specify (which will use the model's default format, e.g. PNG). quality (Literal["auto", "low", "medium", "high"], Optional): The quality of the generated images. GPT image models support "auto", "high", "medium", and "low". Default is to not specify, which will use "auto" behavior for platform OpenAI endpoints and "high" behavior for Azure OpenAI endpoints. - - **Deprecated values (will be removed in v0.15.0):** - "standard", "hd" (DALL-E only). - style (Literal["natural", "vivid"], Optional): **Deprecated.** This parameter was only - supported for DALL-E-3 and is not supported by GPT image models. - Will be removed in v0.15.0. background (Literal["transparent", "opaque", "auto"], Optional): Background behavior for the generated image. When "transparent", the output format must support transparency ("png" or "webp"). When "auto", the model automatically determines the best background. 
@@ -129,36 +107,6 @@ def __init__( ValueError: If background is "transparent" and output_format is "jpeg", since JPEG does not support transparency. """ - # Emit deprecation warnings for DALL-E-only parameters - if style is not None: - print_deprecation_message( - old_item="OpenAIImageTarget(style=...)", - new_item="OpenAIImageTarget(...) without style (DALL-E-3 is being shut down on 2026-05-12)", - removed_in="0.15.0", - ) - - if image_size in self._DEPRECATED_SIZES: - print_deprecation_message( - old_item=f"OpenAIImageTarget(image_size='{image_size}')", - new_item=( - "OpenAIImageTarget(image_size=...) with a GPT image model value " - "('auto', '1024x1024', '1536x1024', or '1024x1536'); " - "DALL-E models are being shut down on 2026-05-12" - ), - removed_in="0.15.0", - ) - - if quality is not None and quality in self._DEPRECATED_QUALITY_VALUES: - print_deprecation_message( - old_item=f"OpenAIImageTarget(quality='{quality}')", - new_item=( - "OpenAIImageTarget(quality=...) with a GPT image model value " - "('auto', 'low', 'medium', or 'high'); " - "DALL-E models are being shut down on 2026-05-12" - ), - removed_in="0.15.0", - ) - if background == "transparent" and output_format == "jpeg": raise ValueError( "background='transparent' requires an output format that supports transparency ('png' or 'webp'). 
" @@ -167,7 +115,6 @@ def __init__( self.output_format = output_format self.quality = quality - self.style = style self.image_size = image_size self.background = background @@ -200,7 +147,6 @@ def _build_identifier(self) -> ComponentIdentifier: params={ "image_size": self.image_size, "quality": self.quality, - "style": self.style, "background": self.background, }, ) @@ -262,8 +208,6 @@ async def _send_generate_request_async(self, message: Message) -> Message: image_generation_args["output_format"] = self.output_format if self.quality: image_generation_args["quality"] = self.quality - if self.style: - image_generation_args["style"] = self.style if self.background: image_generation_args["background"] = self.background @@ -317,8 +261,6 @@ async def _send_edit_request_async(self, message: Message) -> Message: image_edit_args["output_format"] = self.output_format if self.quality: image_edit_args["quality"] = self.quality - if self.style: - image_edit_args["style"] = self.style if self.background: image_edit_args["background"] = self.background @@ -369,25 +311,12 @@ async def _get_image_bytes_async(self, image_data: Any) -> bytes: bytes: The raw image bytes. Raises: - EmptyResponseException: If neither base64 data nor URL is available. + EmptyResponseException: If base64 data is not available. """ b64_data = getattr(image_data, "b64_json", None) if b64_data: return base64.b64decode(b64_data) - # Legacy fallback for DALL-E models that may return URLs instead of base64. - # This code path is deprecated and will be removed in v0.15.0. - image_url = getattr(image_data, "url", None) - if image_url: - logger.warning( - "Image model returned a URL instead of base64 data. " - "This is a DALL-E behavior that is deprecated. Downloading image from URL." 
- ) - async with httpx.AsyncClient() as http_client: - image_response = await http_client.get(image_url) - image_response.raise_for_status() - return image_response.content - raise EmptyResponseException(message="The image generation returned an empty response.") def _validate_request(self, *, normalized_conversation: list[Message]) -> None: diff --git a/pyrit/scenario/__init__.py b/pyrit/scenario/__init__.py index 1be2caea53..2bbc17429e 100644 --- a/pyrit/scenario/__init__.py +++ b/pyrit/scenario/__init__.py @@ -29,7 +29,7 @@ ) # Import scenario submodules directly and register them as virtual subpackages -# This allows: from pyrit.scenario.airt import ContentHarms +# This allows: from pyrit.scenario.airt import Jailbreak # without needing separate pyrit/scenario/airt/ directories from pyrit.scenario.scenarios import adaptive as _adaptive_module from pyrit.scenario.scenarios import airt as _airt_module diff --git a/pyrit/scenario/core/scenario_strategy.py b/pyrit/scenario/core/scenario_strategy.py index de557d724e..47fa029f81 100644 --- a/pyrit/scenario/core/scenario_strategy.py +++ b/pyrit/scenario/core/scenario_strategy.py @@ -190,32 +190,6 @@ def get_aggregate_strategies(cls: type[T]) -> list[T]: aggregate_tags = cls.get_aggregate_tags() return [s for s in cls if s.value in aggregate_tags] - @classmethod - def normalize_strategies(cls: type[T], strategies: set[T]) -> set[T]: - """ - Normalize a set of attack strategies by expanding aggregate tags. - - This method processes a set of strategies and expands any aggregate tags - (like EASY, MODERATE, DIFFICULT or FAST, MEDIUM) into their constituent concrete strategies. - The aggregate tag markers themselves are removed from the result. - - The special "all" tag is automatically supported and expands to all non-aggregate strategies. - - Args: - strategies (Set[T]): The initial set of attack strategies, which may include - aggregate tags. 
- - Returns: - Set[T]: The normalized set of concrete attack strategies with aggregate tags - expanded and removed. - """ - print_deprecation_message( - old_item="ScenarioStrategy.normalize_strategies", - new_item="ScenarioStrategy.expand", - removed_in="0.15.0", - ) - return set(cls.expand(strategies)) - @classmethod def expand(cls: type[T], strategies: set[T]) -> list[T]: """ diff --git a/pyrit/scenario/scenarios/adaptive/selectors/epsilon_greedy.py b/pyrit/scenario/scenarios/adaptive/selectors/epsilon_greedy.py index a415182733..bb701a6a60 100644 --- a/pyrit/scenario/scenarios/adaptive/selectors/epsilon_greedy.py +++ b/pyrit/scenario/scenarios/adaptive/selectors/epsilon_greedy.py @@ -132,7 +132,6 @@ async def select_async( stats = compute_technique_stats( technique_eval_hashes=technique_list, scenario_result_id=effective_run_id, - targeted_harm_categories=self._scope.targeted_harm_categories, ) chosen: list[str] = [] diff --git a/pyrit/scenario/scenarios/adaptive/selectors/technique_selector.py b/pyrit/scenario/scenarios/adaptive/selectors/technique_selector.py index eada0fb5ed..56b03e26f3 100644 --- a/pyrit/scenario/scenarios/adaptive/selectors/technique_selector.py +++ b/pyrit/scenario/scenarios/adaptive/selectors/technique_selector.py @@ -19,9 +19,8 @@ class SelectorScope: queries when estimating technique success rates. All fields default to "no restriction"; combine fields to narrow the - scope (e.g. current run only, same harm category). Filter values flow - through ``compute_technique_stats`` to - ``MemoryInterface.get_attack_results``. + scope (e.g. current run only). Filter values flow through + ``compute_technique_stats`` to ``MemoryInterface.get_attack_results``. The scope is held by the selector at construction time. The per-call ``scenario_result_id`` is supplied by the dispatcher and is forwarded @@ -38,10 +37,6 @@ class SelectorScope: """Restrict to the dispatcher-supplied ``scenario_result_id`` for the in-flight run. 
When ``False`` (default), query across all runs.""" - targeted_harm_categories: Sequence[str] | None = None - """Filter to results whose prompts targeted these harm categories. - ``None`` means no harm-category filter.""" - @classmethod def all_runs(cls) -> SelectorScope: """ diff --git a/pyrit/scenario/scenarios/airt/__init__.py b/pyrit/scenario/scenarios/airt/__init__.py index 0bd10033fd..d1efdbd82d 100644 --- a/pyrit/scenario/scenarios/airt/__init__.py +++ b/pyrit/scenario/scenarios/airt/__init__.py @@ -3,8 +3,7 @@ """AIRT scenario classes.""" -import importlib -from typing import TYPE_CHECKING, Any +from typing import Any from pyrit.scenario.scenarios.airt.cyber import Cyber, _build_cyber_strategy from pyrit.scenario.scenarios.airt.jailbreak import Jailbreak, JailbreakStrategy @@ -13,18 +12,13 @@ from pyrit.scenario.scenarios.airt.rapid_response import RapidResponse, _build_rapid_response_strategy from pyrit.scenario.scenarios.airt.scam import Scam, ScamStrategy -if TYPE_CHECKING: - from pyrit.scenario.scenarios.airt.rapid_response import RapidResponse as ContentHarms - - ContentHarmsStrategy = Any - def __getattr__(name: str) -> Any: """ - Lazily resolve dynamic strategy classes and deprecated aliases. + Lazily resolve dynamic strategy classes. Returns: - Any: The resolved strategy class or deprecated alias. + Any: The resolved strategy class. Raises: AttributeError: If the attribute name is not recognized. @@ -35,18 +29,10 @@ def __getattr__(name: str) -> Any: return _build_leakage_strategy() if name == "CyberStrategy": return _build_cyber_strategy() - if name in ("ContentHarms", "ContentHarmsStrategy"): - # Delegate to the content_harms module so it can emit the deprecation - # warning. We import lazily here to avoid triggering the warning on - # every `import pyrit.scenario.scenarios.airt`. 
- content_harms = importlib.import_module("pyrit.scenario.scenarios.airt.content_harms") - return getattr(content_harms, name) raise AttributeError(f"module {__name__!r} has no attribute {name!r}") __all__ = [ - "ContentHarms", - "ContentHarmsStrategy", "Cyber", "CyberStrategy", "Jailbreak", diff --git a/pyrit/scenario/scenarios/airt/content_harms.py b/pyrit/scenario/scenarios/airt/content_harms.py deleted file mode 100644 index 7704ed9d4f..0000000000 --- a/pyrit/scenario/scenarios/airt/content_harms.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Deprecated — use ``rapid_response`` instead. - -``ContentHarms`` and ``ContentHarmsStrategy`` are thin aliases kept for -backward compatibility. They will be removed in v0.15.0. -""" - -from typing import TYPE_CHECKING, Any - -from pyrit.common.deprecation import print_deprecation_message - -if TYPE_CHECKING: - from pyrit.scenario.scenarios.airt.rapid_response import RapidResponse as ContentHarms - - ContentHarmsStrategy = Any - - -def __getattr__(name: str) -> Any: - """ - Lazily resolve deprecated aliases and emit a deprecation warning. - - Returns: - Any: The resolved alias (``RapidResponse`` or its strategy class). - - Raises: - AttributeError: If the attribute name is not recognized. 
- """ - if name == "ContentHarms": - print_deprecation_message( - old_item="pyrit.scenario.scenarios.airt.content_harms.ContentHarms", - new_item="pyrit.scenario.scenarios.airt.rapid_response.RapidResponse", - removed_in="0.15.0", - ) - from pyrit.scenario.scenarios.airt.rapid_response import RapidResponse - - return RapidResponse - if name == "ContentHarmsStrategy": - print_deprecation_message( - old_item="pyrit.scenario.scenarios.airt.content_harms.ContentHarmsStrategy", - new_item="pyrit.scenario.scenarios.airt.rapid_response.RapidResponseStrategy", - removed_in="0.15.0", - ) - from pyrit.scenario.scenarios.airt.rapid_response import _build_rapid_response_strategy - - return _build_rapid_response_strategy() - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - -__all__ = ["ContentHarms", "ContentHarmsStrategy"] diff --git a/pyrit/score/audio_transcript_scorer.py b/pyrit/score/audio_transcript_scorer.py index 7b1c40f7a6..9678fab3f8 100644 --- a/pyrit/score/audio_transcript_scorer.py +++ b/pyrit/score/audio_transcript_scorer.py @@ -10,7 +10,6 @@ import av -from pyrit.common.deprecation import print_deprecation_message from pyrit.memory import CentralMemory from pyrit.models import MessagePiece, Score from pyrit.prompt_converter import AzureSpeechAudioToTextConverter @@ -107,7 +106,6 @@ def __init__( self, *, text_capable_scorer: Scorer, - use_entra_auth: Optional[bool] = None, ) -> None: """ Initialize the base audio scorer. @@ -115,25 +113,10 @@ def __init__( Args: text_capable_scorer (Scorer): A scorer capable of processing text that will be used to score the transcribed audio content. - use_entra_auth (bool, Optional): **Deprecated.** Will be removed in 0.15.0. 
- Authentication is now configured on the underlying - ``AzureSpeechAudioToTextConverter`` via its ``azure_speech_key`` parameter: - pass a string API key (or set ``AZURE_SPEECH_KEY``) for key auth, a callable - token provider for Entra ID with a custom token, or omit it to use Entra ID - via ``DefaultAzureCredential``. Raises: ValueError: If text_capable_scorer does not support text data type. """ - if use_entra_auth is not None: - print_deprecation_message( - old_item="AudioTranscriptHelper(use_entra_auth=...)", - new_item=( - "AudioTranscriptHelper(...) (configure auth on the underlying " - "AzureSpeechAudioToTextConverter via azure_speech_key)" - ), - removed_in="0.15.0", - ) self._validate_text_scorer(text_capable_scorer) self.text_scorer = text_capable_scorer diff --git a/pyrit/score/conversation_scorer.py b/pyrit/score/conversation_scorer.py index d4e824d1fe..234c8ce597 100644 --- a/pyrit/score/conversation_scorer.py +++ b/pyrit/score/conversation_scorer.py @@ -107,7 +107,6 @@ async def _score_async(self, message: Message, *, objective: Optional[str] = Non original_value_data_type="text", converted_value_data_type="text", response_error="none", - originator=original_piece.originator, original_prompt_id=( cast("UUID", original_piece.original_prompt_id) if isinstance(original_piece.original_prompt_id, str) diff --git a/pyrit/score/float_scale/audio_float_scale_scorer.py b/pyrit/score/float_scale/audio_float_scale_scorer.py index 203f8e1281..edac657bde 100644 --- a/pyrit/score/float_scale/audio_float_scale_scorer.py +++ b/pyrit/score/float_scale/audio_float_scale_scorer.py @@ -24,7 +24,6 @@ def __init__( *, text_capable_scorer: FloatScaleScorer, validator: Optional[ScorerPromptValidator] = None, - use_entra_auth: Optional[bool] = None, ) -> None: """ Initialize the AudioFloatScaleScorer. @@ -33,12 +32,6 @@ def __init__( text_capable_scorer: A FloatScaleScorer capable of processing text. This scorer will be used to evaluate the transcribed audio content. 
validator: Validator for the scorer. Defaults to audio_path data type validator. - use_entra_auth: **Deprecated.** Will be removed in 0.15.0. - Authentication is now configured on the underlying - ``AzureSpeechAudioToTextConverter`` via its ``azure_speech_key`` parameter: - pass a string API key (or set ``AZURE_SPEECH_KEY``) for key auth, a callable - token provider for Entra ID with a custom token, or omit it to use Entra ID - via ``DefaultAzureCredential``. Raises: ValueError: If text_capable_scorer does not support text data type. @@ -46,7 +39,6 @@ def __init__( super().__init__(validator=validator or self._default_validator) self._audio_helper = AudioTranscriptHelper( text_capable_scorer=text_capable_scorer, - use_entra_auth=use_entra_auth, ) def _build_identifier(self) -> ComponentIdentifier: diff --git a/pyrit/score/scorer_evaluation/scorer_metrics.py b/pyrit/score/scorer_evaluation/scorer_metrics.py index fab1f9a505..237bb18e02 100644 --- a/pyrit/score/scorer_evaluation/scorer_metrics.py +++ b/pyrit/score/scorer_evaluation/scorer_metrics.py @@ -95,29 +95,6 @@ def from_json_file(cls: type[T], file_path: Union[str, Path]) -> T: return cls(**filtered_data) - @classmethod - def from_json(cls: type[T], file_path: Union[str, Path]) -> T: - """ - Load a metrics instance from a JSON file (deprecated alias for ``from_json_file``). - - The name ``from_json`` is misleading because it accepts a *file path*, not a JSON - string. Use ``from_json_file`` instead. - - Args: - file_path (Union[str, Path]): The path to the JSON file. - - Returns: - ScorerMetrics: An instance of ScorerMetrics (or subclass) with the loaded data. 
- """ - from pyrit.common.deprecation import print_deprecation_message - - print_deprecation_message( - old_item=f"{cls.__name__}.from_json", - new_item=f"{cls.__name__}.from_json_file", - removed_in="0.15.0", - ) - return cls.from_json_file(file_path) - @dataclass class HarmScorerMetrics(ScorerMetrics): diff --git a/pyrit/score/true_false/audio_true_false_scorer.py b/pyrit/score/true_false/audio_true_false_scorer.py index c10befbf44..2af0fc01ec 100644 --- a/pyrit/score/true_false/audio_true_false_scorer.py +++ b/pyrit/score/true_false/audio_true_false_scorer.py @@ -24,7 +24,6 @@ def __init__( *, text_capable_scorer: TrueFalseScorer, validator: Optional[ScorerPromptValidator] = None, - use_entra_auth: Optional[bool] = None, ) -> None: """ Initialize the AudioTrueFalseScorer. @@ -33,12 +32,6 @@ def __init__( text_capable_scorer: A TrueFalseScorer capable of processing text. This scorer will be used to evaluate the transcribed audio content. validator: Validator for the scorer. Defaults to audio_path data type validator. - use_entra_auth: **Deprecated.** Will be removed in 0.15.0. - Authentication is now configured on the underlying - ``AzureSpeechAudioToTextConverter`` via its ``azure_speech_key`` parameter: - pass a string API key (or set ``AZURE_SPEECH_KEY``) for key auth, a callable - token provider for Entra ID with a custom token, or omit it to use Entra ID - via ``DefaultAzureCredential``. Raises: ValueError: If text_capable_scorer does not support text data type. 
@@ -46,7 +39,6 @@ def __init__( super().__init__(validator=validator or self._DEFAULT_VALIDATOR) self._audio_helper = AudioTranscriptHelper( text_capable_scorer=text_capable_scorer, - use_entra_auth=use_entra_auth, ) def _build_identifier(self) -> ComponentIdentifier: diff --git a/tests/unit/analytics/test_result_analysis.py b/tests/unit/analytics/test_result_analysis.py index 21d18541ed..49a4cb393b 100644 --- a/tests/unit/analytics/test_result_analysis.py +++ b/tests/unit/analytics/test_result_analysis.py @@ -21,6 +21,7 @@ IdentifierFilter, IdentifierType, ObjectiveTargetEvaluationIdentifier, + build_atomic_attack_identifier, ) @@ -33,14 +34,15 @@ def make_attack( """ Minimal valid AttackResult for analytics tests. """ - attack_identifier: Optional[ComponentIdentifier] = None + atomic_attack_identifier: Optional[ComponentIdentifier] = None if attack_type is not None: attack_identifier = ComponentIdentifier(class_name=attack_type, class_module="tests.unit.analytics") + atomic_attack_identifier = build_atomic_attack_identifier(attack_identifier=attack_identifier) return AttackResult( conversation_id=conversation_id, objective="test objective", - attack_identifier=attack_identifier, + atomic_attack_identifier=atomic_attack_identifier, outcome=outcome, ) diff --git a/tests/unit/analytics/test_technique_analysis.py b/tests/unit/analytics/test_technique_analysis.py index 04b1d94890..20546cd2c8 100644 --- a/tests/unit/analytics/test_technique_analysis.py +++ b/tests/unit/analytics/test_technique_analysis.py @@ -94,7 +94,6 @@ def test_passes_eval_hashes_to_memory_query(self, _patch_memory): call_kwargs = _patch_memory.get_attack_results.call_args[1] assert call_kwargs["atomic_attack_eval_hashes"] == ["x", "y"] assert call_kwargs["scenario_result_id"] is None - assert call_kwargs["targeted_harm_categories"] is None def test_passes_scenario_result_id_to_memory_query(self, _patch_memory): compute_technique_stats(technique_eval_hashes=["x"], scenario_result_id="run-123") @@ 
-124,15 +123,6 @@ def test_success_rate_computed(self, _patch_memory): assert stats["a"].success_rate == pytest.approx(0.5) - def test_passes_harm_categories_to_memory_query(self, _patch_memory): - compute_technique_stats( - technique_eval_hashes=["x"], - targeted_harm_categories=["misinformation", "hate"], - ) - - call_kwargs = _patch_memory.get_attack_results.call_args[1] - assert call_kwargs["targeted_harm_categories"] == ["misinformation", "hate"] - def test_injected_memory_bypasses_central_memory(self, _patch_memory): injected = MagicMock() injected.get_attack_results.return_value = [ diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index 2d6e4da37c..7057016942 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -30,7 +30,7 @@ ) from pyrit.backend.mappers.converter_mappers import converter_object_to_instance from pyrit.backend.mappers.target_mappers import target_object_to_instance -from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier +from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier, build_atomic_attack_identifier from pyrit.models.conversation_stats import ConversationStats from pyrit.prompt_target import PromptTarget, TargetCapabilities @@ -66,13 +66,15 @@ def _make_attack_result( conversation_id=conversation_id, objective="test", attack_result_id=str(uuid.uuid4()), - attack_identifier=ComponentIdentifier( - class_name=name, - class_module="pyrit.backend", - params={ - "source": "gui", - }, - children=children, + atomic_attack_identifier=build_atomic_attack_identifier( + attack_identifier=ComponentIdentifier( + class_name=name, + class_module="pyrit.backend", + params={ + "source": "gui", + }, + children=children, + ) ), outcome=outcome, metadata={ @@ -263,29 +265,31 @@ def test_converters_extracted_from_identifier(self) -> None: conversation_id="attack-conv", objective="test", attack_result_id=str(uuid.uuid4()), - 
attack_identifier=ComponentIdentifier( - class_name="TestAttack", - class_module="pyrit.backend", - children={ - "request_converters": [ - ComponentIdentifier( - class_name="Base64Converter", - class_module="pyrit.converters", - params={ - "supported_input_types": ("text",), - "supported_output_types": ("text",), - }, - ), - ComponentIdentifier( - class_name="ROT13Converter", - class_module="pyrit.converters", - params={ - "supported_input_types": ("text",), - "supported_output_types": ("text",), - }, - ), - ], - }, + atomic_attack_identifier=build_atomic_attack_identifier( + attack_identifier=ComponentIdentifier( + class_name="TestAttack", + class_module="pyrit.backend", + children={ + "request_converters": [ + ComponentIdentifier( + class_name="Base64Converter", + class_module="pyrit.converters", + params={ + "supported_input_types": ("text",), + "supported_output_types": ("text",), + }, + ), + ComponentIdentifier( + class_name="ROT13Converter", + class_module="pyrit.converters", + params={ + "supported_input_types": ("text",), + "supported_output_types": ("text",), + }, + ), + ], + }, + ) ), outcome=AttackOutcome.UNDETERMINED, metadata={"created_at": now.isoformat(), "updated_at": now.isoformat()}, diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index ed56813139..7a91d63cfd 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -889,8 +889,25 @@ async def test_multipart_message_extracts_scores_from_all_pieces( # All pieces in a Message must share the same conversation_id piece_conversation_id = str(uuid.uuid4()) - # Create score for first piece - # Prepended conversations are simulated, so only false scores are extracted + piece1 = MessagePiece( + role="assistant", + original_value="Here is the analysis:", + original_value_data_type="text", + 
conversation_id=piece_conversation_id, + ) + piece2 = MessagePiece( + role="assistant", + original_value="chart_image.png", + original_value_data_type="image_path", + conversation_id=piece_conversation_id, + ) + + # Pre-stage the original pieces + scores in memory so add_scores_to_memory + # passes its existence check. initialize_context_async will then duplicate + # the pieces under the target conversation_id, keeping ``original_prompt_id`` + # set to the input id (which is what ScoreEntry.prompt_request_response_id + # points at), so the per-conversation score lookup resolves them. + manager._memory.add_message_pieces_to_memory(message_pieces=[piece1, piece2]) score1 = Score( score_type="true_false", score_value="false", @@ -898,19 +915,9 @@ async def test_multipart_message_extracts_scores_from_all_pieces( score_value_description="Score for text piece", score_rationale="Test rationale for text", score_metadata={}, - message_piece_id=str(uuid.uuid4()), + message_piece_id=str(piece1.id), scorer_class_identifier=get_mock_scorer_identifier(), ) - piece1 = MessagePiece( - role="assistant", - original_value="Here is the analysis:", - original_value_data_type="text", - conversation_id=piece_conversation_id, - scores=[score1], # Attach score directly to piece - ) - - # Create score for second piece - # Also false since prepended conversations only extract false scores score2 = Score( score_type="true_false", score_value="false", @@ -918,16 +925,10 @@ async def test_multipart_message_extracts_scores_from_all_pieces( score_value_description="Score for image piece", score_rationale="Test rationale for image", score_metadata={}, - message_piece_id=str(uuid.uuid4()), + message_piece_id=str(piece2.id), scorer_class_identifier=get_mock_scorer_identifier(), ) - piece2 = MessagePiece( - role="assistant", - original_value="chart_image.png", - original_value_data_type="image_path", - conversation_id=piece_conversation_id, - scores=[score2], # Attach score directly to piece - ) + 
manager._memory.add_scores_to_memory(scores=[score1, score2]) multipart_response = Message(message_pieces=[piece1, piece2]) context.prepended_conversation = [ @@ -944,8 +945,9 @@ async def test_multipart_message_extracts_scores_from_all_pieces( # Verify scores from both pieces are returned assert len(state.last_assistant_message_scores) == 2 - assert score1 in state.last_assistant_message_scores - assert score2 in state.last_assistant_message_scores + returned_ids = {s.id for s in state.last_assistant_message_scores} + assert score1.id in returned_ids + assert score2.id in returned_ids async def test_prepended_conversation_ignores_true_scores( self, @@ -959,9 +961,28 @@ async def test_prepended_conversation_ignores_true_scores( are extracted to provide feedback rationale for continued attack attempts. """ manager = ConversationManager(attack_identifier=attack_identifier) - conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) + piece_with_true = MessagePiece( + role="assistant", + original_value="Simulated success response", + original_value_data_type="text", + conversation_id=str(uuid.uuid4()), + ) + + piece_with_false = MessagePiece( + role="assistant", + original_value="Simulated refusal response", + original_value_data_type="text", + conversation_id=str(uuid.uuid4()), + ) + + # Pre-stage the pieces in memory so add_scores_to_memory passes its + # existence check. initialize_context_async will duplicate them under + # the target conversation_id, preserving ``original_prompt_id`` so the + # score lookup resolves the staged scores. 
+ manager._memory.add_message_pieces_to_memory(message_pieces=[piece_with_true, piece_with_false]) + # Create a score with true value - should be ignored true_score = Score( score_type="true_false", @@ -970,7 +991,7 @@ async def test_prepended_conversation_ignores_true_scores( score_value_description="Should be ignored", score_rationale="This simulated success should not be extracted", score_metadata={}, - message_piece_id=str(uuid.uuid4()), + message_piece_id=str(piece_with_true.id), scorer_class_identifier=get_mock_scorer_identifier(), ) @@ -982,25 +1003,11 @@ async def test_prepended_conversation_ignores_true_scores( score_value_description="Should be extracted", score_rationale="This refusal can provide feedback", score_metadata={}, - message_piece_id=str(uuid.uuid4()), + message_piece_id=str(piece_with_false.id), scorer_class_identifier=get_mock_scorer_identifier(), ) - piece_with_true = MessagePiece( - role="assistant", - original_value="Simulated success response", - original_value_data_type="text", - conversation_id=str(uuid.uuid4()), - scores=[true_score], - ) - - piece_with_false = MessagePiece( - role="assistant", - original_value="Simulated refusal response", - original_value_data_type="text", - conversation_id=str(uuid.uuid4()), - scores=[false_score], - ) + manager._memory.add_scores_to_memory(scores=[true_score, false_score]) # Test with true score only - should get no scores context.prepended_conversation = [ @@ -1011,7 +1018,7 @@ async def test_prepended_conversation_ignores_true_scores( state = await manager.initialize_context_async( context=context, target=mock_chat_target, - conversation_id=conversation_id, + conversation_id=str(uuid.uuid4()), max_turns=10, ) @@ -1033,8 +1040,9 @@ async def test_prepended_conversation_ignores_true_scores( ) assert len(state2.last_assistant_message_scores) == 1 - assert false_score in state2.last_assistant_message_scores - assert context2.last_score == false_score + assert 
state2.last_assistant_message_scores[0].id == false_score.id + assert context2.last_score is not None + assert context2.last_score.id == false_score.id # ============================================================================= diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index cdb611e64b..a11c65ecca 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -27,16 +27,14 @@ from collections.abc import Sequence -def create_message_piece(conversation_id: str, prompt_num: int, targeted_harm_categories=None, labels=None): - """Helper function to create MessagePiece with optional targeted harm categories and labels.""" +def create_message_piece(conversation_id: str, prompt_num: int, labels=None): + """Helper function to create MessagePiece with optional labels.""" kwargs: dict = { "role": "user", "original_value": f"Test prompt {prompt_num}", "converted_value": f"Test prompt {prompt_num}", "conversation_id": conversation_id, } - if targeted_harm_categories is not None: - kwargs["targeted_harm_categories"] = targeted_harm_categories if labels is not None: kwargs["labels"] = labels return MessagePiece(**kwargs) @@ -736,62 +734,6 @@ def test_update_attack_result_stale_entry_does_not_overwrite(sqlite_instance: Me assert results[0].related_conversations.pop().conversation_id == "branch-1" -def test_get_attack_results_by_harm_category_single(sqlite_instance: MemoryInterface): - """Test filtering attack results by a single harm category.""" - - # Create message pieces with harm categories using helper function - message_piece1 = create_message_piece("conv_1", 1, targeted_harm_categories=["violence", "illegal"]) - message_piece2 = create_message_piece("conv_2", 2, targeted_harm_categories=["illegal"]) - message_piece3 = create_message_piece("conv_3", 3, 
targeted_harm_categories=["violence"]) - - # Add message pieces to memory - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) - - # Create attack results using helper function - attack_result1 = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) - attack_result2 = create_attack_result("conv_2", 2, AttackOutcome.FAILURE) - attack_result3 = create_attack_result("conv_3", 3, AttackOutcome.SUCCESS) - - sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2, attack_result3]) - - violence_results = sqlite_instance.get_attack_results(targeted_harm_categories=["violence"]) - assert len(violence_results) == 2 - conversation_ids = {result.conversation_id for result in violence_results} - assert conversation_ids == {"conv_1", "conv_3"} - - illegal_results = sqlite_instance.get_attack_results(targeted_harm_categories=["illegal"]) - assert len(illegal_results) == 2 - conversation_ids = {result.conversation_id for result in illegal_results} - assert conversation_ids == {"conv_1", "conv_2"} - - -def test_get_attack_results_by_harm_category_multiple(sqlite_instance: MemoryInterface): - """Test filtering attack results by multiple harm categories (AND logic).""" - - # Create message pieces with different harm category combinations - message_piece1 = create_message_piece("conv_1", 1, targeted_harm_categories=["violence", "illegal", "hate"]) - message_piece2 = create_message_piece("conv_2", 2, targeted_harm_categories=["violence", "illegal"]) - message_piece3 = create_message_piece("conv_3", 3, targeted_harm_categories=["violence"]) - - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) - - # Create attack results - attack_result1 = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) - attack_result2 = create_attack_result("conv_2", 2, AttackOutcome.SUCCESS) - attack_result3 = create_attack_result("conv_3", 3, 
AttackOutcome.FAILURE) - - sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2, attack_result3]) - - # Test filtering by multiple harm categories - violence_and_illegal_results = sqlite_instance.get_attack_results(targeted_harm_categories=["violence", "illegal"]) - assert len(violence_and_illegal_results) == 2 - conversation_ids = {result.conversation_id for result in violence_and_illegal_results} - assert conversation_ids == {"conv_1", "conv_2"} - all_three_results = sqlite_instance.get_attack_results(targeted_harm_categories=["violence", "illegal", "hate"]) - assert len(all_three_results) == 1 - assert all_three_results[0].conversation_id == "conv_1" - - def test_get_attack_results_by_labels_single(sqlite_instance: MemoryInterface): """Test filtering attack results by single label.""" @@ -948,56 +890,6 @@ def test_get_attack_results_by_labels_or_within_key_and_across_keys(sqlite_insta assert {r.conversation_id for r in results} == {"conv_1", "conv_2"} -def test_get_attack_results_by_harm_category_and_labels(sqlite_instance: MemoryInterface): - """Test filtering attack results by both harm categories and labels.""" - - # Create message pieces with harm categories (harm categories still live on PromptMemoryEntry) - message_piece1 = create_message_piece("conv_1", 1, targeted_harm_categories=["violence", "illegal"]) - message_piece2 = create_message_piece("conv_2", 2, targeted_harm_categories=["violence"]) - message_piece3 = create_message_piece("conv_3", 3, targeted_harm_categories=["violence", "illegal"]) - - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) - - # Create attack results with labels - attack_results = [ - create_attack_result("conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"}), - create_attack_result("conv_2", 2, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"}), - 
create_attack_result("conv_3", 3, AttackOutcome.FAILURE, labels={"operation": "other_op", "operator": "bob"}), - ] - - sqlite_instance.add_attack_results_to_memory(attack_results=attack_results) - - # Test filtering by both harm categories and labels - violence_illegal_roakey_results = sqlite_instance.get_attack_results( - targeted_harm_categories=["violence", "illegal"], labels={"operator": "roakey"} - ) - assert len(violence_illegal_roakey_results) == 1 - assert violence_illegal_roakey_results[0].conversation_id == "conv_1" - - # Test filtering by harm category and operation - violence_test_op_results = sqlite_instance.get_attack_results( - targeted_harm_categories=["violence"], labels={"operation": "test_op"} - ) - assert len(violence_test_op_results) == 2 - conversation_ids = {result.conversation_id for result in violence_test_op_results} - assert conversation_ids == {"conv_1", "conv_2"} - - -def test_get_attack_results_harm_category_no_matches(sqlite_instance: MemoryInterface): - """Test filtering by harm category that doesn't exist.""" - - # Create attack result without the harm category we'll search for - message_piece = create_message_piece("conv_1", 1, targeted_harm_categories=["violence"]) - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece]) - - attack_result = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) - sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result]) - - # Search for non-existent harm category - results = sqlite_instance.get_attack_results(targeted_harm_categories=["nonexistent"]) - assert len(results) == 0 - - def test_get_attack_results_labels_no_matches(sqlite_instance: MemoryInterface): """Test filtering by labels that don't exist.""" @@ -1574,22 +1466,6 @@ def test_get_attack_results_attack_classes_converter_classes_empty_matches_no_co assert {r.conversation_id for r in results} == {"conv_2"} -def test_get_attack_results_attack_class_backcompat_singular(sqlite_instance: 
MemoryInterface): - """Deprecated singular attack_class=... still works and is equivalent to attack_classes=[...].""" - ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") - ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") - sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2]) - - results = sqlite_instance.get_attack_results(attack_class="CrescendoAttack") - assert {r.conversation_id for r in results} == {"conv_1"} - - -def test_get_attack_results_attack_class_and_attack_classes_both_raises(sqlite_instance: MemoryInterface): - """Passing both attack_class and attack_classes is rejected.""" - with pytest.raises(ValueError, match="attack_class"): - sqlite_instance.get_attack_results(attack_class="A", attack_classes=["B"]) - - def test_get_attack_results_has_converters_true(sqlite_instance: MemoryInterface): """has_converters=True returns only attacks with at least one converter.""" ar_with_conv = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) @@ -1757,20 +1633,3 @@ def test_get_attack_results_by_attack_identifier_filter_no_match(sqlite_instance ], ) assert len(results) == 0 - - -def test_get_attack_results_targeted_harm_categories_emits_deprecation_warning(sqlite_instance: MemoryInterface): - """Test that passing targeted_harm_categories emits a DeprecationWarning.""" - import warnings - - message_piece = create_message_piece("conv_1", 1, targeted_harm_categories=["violence"]) - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece]) - - attack_result = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) - sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result]) - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - sqlite_instance.get_attack_results(targeted_harm_categories=["violence"]) - deprecation_msgs = [x for x in w if issubclass(x.category, DeprecationWarning)] - assert 
any("targeted_harm_categories" in str(m.message) for m in deprecation_msgs) diff --git a/tests/unit/memory/memory_interface/test_interface_export.py b/tests/unit/memory/memory_interface/test_interface_export.py deleted file mode 100644 index 34252b7547..0000000000 --- a/tests/unit/memory/memory_interface/test_interface_export.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import csv -import json -import os -import tempfile -from collections.abc import Sequence -from pathlib import Path -from unittest.mock import MagicMock, patch - -import pytest - -from pyrit.common.path import DB_DATA_PATH -from pyrit.memory import MemoryExporter, MemoryInterface -from pyrit.models import MessagePiece - - -def test_export_conversation_by_attack_id_file_created( - sqlite_instance: MemoryInterface, sample_conversations: Sequence[MessagePiece] -): - attack1_id = sample_conversations[0].attack_identifier.hash - - # Default path in export_conversations() - file_name = f"{attack1_id}.json" - file_path = Path(DB_DATA_PATH, file_name) - - sqlite_instance.exporter = MemoryExporter() - - with patch("pyrit.memory.sqlite_memory.SQLiteMemory.get_message_pieces") as mock_get: - mock_get.return_value = sample_conversations - sqlite_instance.export_conversations(attack_id=attack1_id, file_path=file_path) - - # Verify file was created - assert file_path.exists() - - -def test_export_all_conversations_file_created(sqlite_instance: MemoryInterface): - sqlite_instance.exporter = MemoryExporter() - - with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as temp_file: - with ( - patch("pyrit.memory.sqlite_memory.SQLiteMemory.get_message_pieces") as mock_get_pieces, - patch("pyrit.memory.sqlite_memory.SQLiteMemory.get_prompt_scores") as mock_get_scores, - ): - file_path = Path(temp_file.name) - - # Create mock with serializable data - - mock_get_pieces.return_value = [ - MagicMock( - original_prompt_id="1234", - 
converted_value="sample piece", - model_dump=lambda mode="json": {"message_piece_id": "1234", "conversation": ["sample piece"]}, - ) - ] - mock_get_scores.return_value = [ - MagicMock( - message_piece_id="1234", - score_value=10, - model_dump=lambda mode="json": {"message_piece_id": "1234", "score_value": 10}, - ) - ] - - result_path = sqlite_instance.export_conversations(file_path=file_path) - - assert result_path == file_path - assert file_path.exists() - - -def test_export_all_conversations_with_scores_correct_data(sqlite_instance: MemoryInterface): - sqlite_instance.exporter = MemoryExporter() - - with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as temp_file: - file_path = Path(temp_file.name) - temp_file.close() # Close the file to allow Windows to open it for writing - - try: - with ( - patch.object(sqlite_instance, "get_message_pieces") as mock_get_pieces, - patch.object(sqlite_instance, "get_prompt_scores") as mock_get_scores, - ): - # Create a mock piece - mock_piece = MagicMock() - mock_piece.id = "piece_id_1234" - mock_piece.original_prompt_id = "1234" - mock_piece.converted_value = "sample piece" - mock_piece.to_dict.return_value = { - "id": "piece_id_1234", - "original_prompt_id": "1234", - "converted_value": "sample piece", - } - mock_piece.model_dump.return_value = { - "id": "piece_id_1234", - "original_prompt_id": "1234", - "converted_value": "sample piece", - } - - # Create a mock score - mock_score = MagicMock() - mock_score.message_piece_id = "piece_id_1234" - mock_score.score_value = 10 - mock_score.model_dump.return_value = {"message_piece_id": "piece_id_1234", "score_value": 10} - - mock_get_pieces.return_value = [mock_piece] - mock_get_scores.return_value = [mock_score] - - result_path = sqlite_instance.export_conversations(file_path=file_path) - - # Verify the file was created and contains correct data - assert result_path == file_path - assert file_path.exists() - - # Read and verify the exported JSON content - with 
open(file_path) as f: - exported_data = json.load(f) - - assert len(exported_data) == 1 - assert exported_data[0]["id"] == "piece_id_1234" - assert exported_data[0]["original_prompt_id"] == "1234" - assert exported_data[0]["converted_value"] == "sample piece" - assert len(exported_data[0]["scores"]) == 1 - assert exported_data[0]["scores"][0]["score_value"] == 10 - finally: - # Clean up the temp file - if file_path.exists(): - os.remove(file_path) - - -def test_export_all_conversations_with_scores_empty_data(sqlite_instance: MemoryInterface): - sqlite_instance.exporter = MemoryExporter() - with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as temp_file: - file_path = Path(temp_file.name) - temp_file.close() # Close the file to allow Windows to open it for writing - - try: - with ( - patch("pyrit.memory.sqlite_memory.SQLiteMemory.get_message_pieces") as mock_get_pieces, - patch("pyrit.memory.sqlite_memory.SQLiteMemory.get_prompt_scores") as mock_get_scores, - ): - mock_get_pieces.return_value = [] - mock_get_scores.return_value = [] - - result_path = sqlite_instance.export_conversations(file_path=file_path) - - # Verify the file was created and is empty JSON array - assert result_path == file_path - assert file_path.exists() - - # Read and verify the exported JSON content is empty - with open(file_path) as f: - exported_data = json.load(f) - - assert exported_data == [] - finally: - # Clean up the temp file - if file_path.exists(): - os.remove(file_path) - - -@pytest.mark.parametrize("export_type, suffix", [("json", ".json"), ("csv", ".csv"), ("md", ".md")]) -def test_export_all_conversations_with_scores_respects_export_type( - sqlite_instance: MemoryInterface, export_type: str, suffix: str -): - sqlite_instance.exporter = MemoryExporter() - - with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file: - file_path = Path(temp_file.name) - temp_file.close() - - try: - with ( - patch.object(sqlite_instance, "get_message_pieces") as 
mock_get_pieces, - patch.object(sqlite_instance, "get_prompt_scores") as mock_get_scores, - ): - mock_piece = MagicMock() - mock_piece.id = "piece_id_1234" - mock_piece.to_dict.return_value = { - "id": "piece_id_1234", - "converted_value": "sample piece", - } - mock_piece.model_dump.return_value = { - "id": "piece_id_1234", - "converted_value": "sample piece", - } - - mock_score = MagicMock() - mock_score.message_piece_id = "piece_id_1234" - mock_score.model_dump.return_value = {"message_piece_id": "piece_id_1234", "score_value": 10} - - mock_get_pieces.return_value = [mock_piece] - mock_get_scores.return_value = [mock_score] - - sqlite_instance.export_conversations(file_path=file_path, export_type=export_type) - - assert file_path.exists() - exported_content = file_path.read_text(encoding="utf-8") - assert "piece_id_1234" in exported_content - assert "sample piece" in exported_content - - if export_type == "json": - exported_data = json.loads(exported_content) - assert len(exported_data) == 1 - assert exported_data[0]["id"] == "piece_id_1234" - elif export_type == "csv": - with open(file_path, newline="") as exported_file: - reader = csv.DictReader(exported_file) - assert reader.fieldnames == ["id", "converted_value", "scores"] - rows = list(reader) - assert len(rows) == 1 - assert rows[0]["id"] == "piece_id_1234" - elif export_type == "md": - assert exported_content.startswith("| id | converted_value | scores |") - finally: - if file_path.exists(): - os.remove(file_path) diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index f1261b9597..eb247757c3 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -1155,6 +1155,7 @@ def test_get_message_pieces_sorts( def test_message_piece_scores_duplicate_piece(sqlite_instance: MemoryInterface): + """Scores for duplicated pieces are returned via 
get_prompt_scores.""" original_id = uuid4() duplicate_id = uuid4() @@ -1186,14 +1187,15 @@ def test_message_piece_scores_duplicate_piece(sqlite_instance: MemoryInterface): ) sqlite_instance.add_scores_to_memory(scores=[score]) - retrieved_pieces = sqlite_instance.get_message_pieces() + # Both the original and the duplicate piece resolve back to the same score + # via get_prompt_scores, which queries ScoreEntry by original_prompt_id. + scores_for_original = sqlite_instance.get_prompt_scores(prompt_ids=[str(original_id)]) + scores_for_duplicate = sqlite_instance.get_prompt_scores(prompt_ids=[str(duplicate_id)]) - assert len(retrieved_pieces[0].scores) == 1 - assert retrieved_pieces[0].scores[0].score_value == "0.8" - - # Check that the duplicate piece has the same score as the original - assert len(retrieved_pieces[1].scores) == 1 - assert retrieved_pieces[1].scores[0].score_value == "0.8" + assert len(scores_for_original) == 1 + assert scores_for_original[0].score_value == "0.8" + assert len(scores_for_duplicate) == 1 + assert scores_for_duplicate[0].score_value == "0.8" async def test_message_piece_hash_stored_and_retrieved(sqlite_instance: MemoryInterface): diff --git a/tests/unit/memory/test_memory_exporter.py b/tests/unit/memory/test_memory_exporter.py deleted file mode 100644 index d7efbf35dd..0000000000 --- a/tests/unit/memory/test_memory_exporter.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
- -import csv -import json -from collections.abc import Sequence - -import pytest -from sqlalchemy.inspection import inspect - -from pyrit.memory.memory_exporter import MemoryExporter -from pyrit.memory.memory_models import PromptMemoryEntry -from pyrit.models import Message -from unit.mocks import get_sample_conversation_entries, get_sample_conversations - - -@pytest.fixture -def sample_conversation_entries() -> Sequence[PromptMemoryEntry]: - return get_sample_conversation_entries() - - -def model_to_dict(instance): - """Converts a SQLAlchemy model instance into a dictionary.""" - return {c.key: getattr(instance, c.key) for c in inspect(instance).mapper.column_attrs} - - -def read_file(file_path, export_type): - if export_type == "json": - with open(file_path) as f: - return json.load(f) - elif export_type == "csv": - with open(file_path, newline="") as f: - reader = csv.DictReader(f) - return list(reader) - else: - raise ValueError(f"Invalid export type: {export_type}") - - -def export(export_type, exporter, data, file_path): - if export_type == "json": - exporter.export_to_json(data, file_path) - elif export_type == "csv": - exporter.export_to_csv(data, file_path) - else: - raise ValueError(f"Invalid export type: {export_type}") - - -@pytest.mark.parametrize("export_type", ["json", "csv"]) -def test_export_to_json_creates_file(tmp_path, export_type): - exporter = MemoryExporter() - file_path = tmp_path / f"conversations.{export_type}" - conversations = get_sample_conversations() - sample_conversation_entries = list(Message.flatten_to_message_pieces(conversations)) - export(export_type=export_type, exporter=exporter, data=sample_conversation_entries, file_path=file_path) - - assert file_path.exists() # Check that the file was created - content = read_file(file_path=file_path, export_type=export_type) - # Perform more detailed checks on content if necessary - assert len(content) == 3 # Simple check for the number of items - # Convert each MessagePiece instance to 
a dictionary - expected_content = [message_piece.model_dump(mode="json") for message_piece in sample_conversation_entries] - - for expected, actual in zip(expected_content, content, strict=False): - assert expected["role"] == actual["role"] - assert expected["converted_value"] == actual["converted_value"] - assert expected["conversation_id"] == actual["conversation_id"] - assert expected["original_value_data_type"] == actual["original_value_data_type"] - assert expected["original_value"] == actual["original_value"] - - -@pytest.mark.parametrize("export_type", ["json", "csv"]) -def test_export_to_json_data_with_conversations(tmp_path, export_type): - exporter = MemoryExporter() - conversations = get_sample_conversations() - sample_conversation_entries = list(Message.flatten_to_message_pieces(conversations)) - conversation_id = sample_conversation_entries[0].conversation_id - - # Define the file path using tmp_path - file_path = tmp_path / "exported_conversations.json" - - # Call the method under test - export(export_type=export_type, exporter=exporter, data=sample_conversation_entries, file_path=file_path) - - # Verify the file was created - assert file_path.exists() - - # Read the file and verify its contents - content = read_file(file_path=file_path, export_type=export_type) - assert len(content) == 3 # Check for the expected number of items - assert content[0]["role"] == "user" - assert content[0]["converted_value"] == "Hello, how are you?" - assert content[0]["conversation_id"] == conversation_id - assert content[1]["role"] == "assistant" - assert content[1]["converted_value"] == "I'm fine, thank you!" 
- assert content[1]["conversation_id"] == conversation_id - - -@pytest.mark.parametrize("export_type", ["json", "csv", "md"]) -def test_export_data_creates_file(tmp_path, export_type): - exporter = MemoryExporter() - file_path = tmp_path / f"conversations.{export_type}" - conversations = get_sample_conversations() - sample_conversation_entries = list(Message.flatten_to_message_pieces(conversations)) - exporter.export_data(data=sample_conversation_entries, file_path=file_path, export_type=export_type) - - assert file_path.exists() diff --git a/tests/unit/models/test_attack_result.py b/tests/unit/models/test_attack_result.py index ea50d4de7e..ae08bf61d1 100644 --- a/tests/unit/models/test_attack_result.py +++ b/tests/unit/models/test_attack_result.py @@ -8,7 +8,7 @@ import pytest from pyrit.memory.memory_models import AttackResultEntry -from pyrit.models import ComponentIdentifier, build_atomic_attack_identifier +from pyrit.models import ComponentIdentifier from pyrit.models.conversation_reference import ConversationReference, ConversationType from pyrit.models.messages.message_piece import MessagePiece from pyrit.models.results.attack_result import AttackOutcome, AttackResult @@ -16,133 +16,6 @@ from pyrit.models.score import Score -class TestAttackResultDeprecation: - """Tests for the AttackResult attack_identifier deprecation behaviour.""" - - def _make_attack_identifier(self) -> ComponentIdentifier: - return ComponentIdentifier(class_name="TestAttack", class_module="tests.unit") - - def _make_atomic_identifier(self) -> ComponentIdentifier: - attack_id = self._make_attack_identifier() - return build_atomic_attack_identifier(attack_identifier=attack_id) - - # -- property deprecation ------------------------------------------------- - - def test_attack_identifier_property_emits_deprecation_warning(self) -> None: - """Accessing .attack_identifier should emit a DeprecationWarning.""" - result = AttackResult( - conversation_id="c1", - objective="test", - 
atomic_attack_identifier=self._make_atomic_identifier(), - ) - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - _ = result.attack_identifier - - deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] - assert len(deprecation_warnings) >= 1, "Expected a DeprecationWarning from .attack_identifier" - assert "attack_identifier" in str(deprecation_warnings[0].message).lower() - - def test_attack_identifier_property_returns_correct_value(self) -> None: - """Accessing .attack_identifier should return the attack strategy child.""" - result = AttackResult( - conversation_id="c1", - objective="test", - atomic_attack_identifier=self._make_atomic_identifier(), - ) - with warnings.catch_warnings(record=True): - warnings.simplefilter("always") - value = result.attack_identifier - - assert value is not None - assert value.class_name == "TestAttack" - - def test_attack_identifier_property_returns_none_when_unset(self) -> None: - """Property returns None when atomic_attack_identifier is not set.""" - result = AttackResult(conversation_id="c1", objective="test") - with warnings.catch_warnings(record=True): - warnings.simplefilter("always") - assert result.attack_identifier is None - - # -- get_attack_strategy_identifier (non-deprecated) ---------------------- - - def test_get_attack_strategy_identifier_no_warning(self) -> None: - """get_attack_strategy_identifier() must NOT emit a deprecation warning.""" - result = AttackResult( - conversation_id="c1", - objective="test", - atomic_attack_identifier=self._make_atomic_identifier(), - ) - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - value = result.get_attack_strategy_identifier() - - deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] - assert len(deprecation_warnings) == 0, "get_attack_strategy_identifier should not warn" - assert value is not None - assert value.class_name == 
"TestAttack" - - def test_get_attack_strategy_identifier_returns_none_when_unset(self) -> None: - result = AttackResult(conversation_id="c1", objective="test") - assert result.get_attack_strategy_identifier() is None - - # -- backward-compat constructor ------------------------------------------ - - def test_constructor_with_attack_identifier_kwarg_emits_warning(self) -> None: - """Passing attack_identifier= to the constructor should emit DeprecationWarning.""" - attack_id = self._make_attack_identifier() - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - result = AttackResult( - conversation_id="c1", - objective="test", - attack_identifier=attack_id, - ) - - deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] - assert len(deprecation_warnings) >= 1, "Constructor should warn on attack_identifier=" - # The value should be promoted to atomic_attack_identifier - assert result.atomic_attack_identifier is not None - assert result.get_attack_strategy_identifier() == attack_id - - def test_constructor_attack_identifier_does_not_override_atomic(self) -> None: - """If both are supplied, atomic_attack_identifier takes precedence.""" - attack_id = self._make_attack_identifier() - atomic_id = self._make_atomic_identifier() - with warnings.catch_warnings(record=True): - warnings.simplefilter("always") - result = AttackResult( - conversation_id="c1", - objective="test", - attack_identifier=attack_id, - atomic_attack_identifier=atomic_id, - ) - - assert result.atomic_attack_identifier is atomic_id - - # -- construction without deprecated kwarg -------------------------------- - - def test_constructor_with_atomic_attack_identifier_only(self) -> None: - """Normal construction with atomic_attack_identifier should work with no warnings.""" - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - result = AttackResult( - conversation_id="c1", - objective="test", - 
atomic_attack_identifier=self._make_atomic_identifier(), - ) - - deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] - assert len(deprecation_warnings) == 0 - assert result.get_attack_strategy_identifier() is not None - - def test_constructor_with_no_identifier_at_all(self) -> None: - """Construction with neither identifier should be fine.""" - result = AttackResult(conversation_id="c1", objective="test") - assert result.atomic_attack_identifier is None - assert result.get_attack_strategy_identifier() is None - - class TestAttackResultTimestamp: """Tests for the AttackResult.timestamp field and its round-trip through AttackResultEntry.""" @@ -470,37 +343,6 @@ def test_aware_iso_string_timestamp_is_preserved(self) -> None: result = AttackResult(conversation_id="c1", objective="test", timestamp="2026-01-01T12:00:00+00:00") assert result.timestamp == datetime(2026, 1, 1, 12, 0, 0, tzinfo=timezone.utc) - def test_deprecated_kwarg_promotes_without_extra_field_error(self) -> None: - """The promote before-validator pops attack_identifier before extra='forbid' runs.""" - attack_id = ComponentIdentifier(class_name="TestAttack", class_module="tests.unit") - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - result = AttackResult( - conversation_id="c1", - objective="test", - attack_identifier=attack_id, - ) - - deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] - assert len(deprecation_warnings) >= 1 - assert result.atomic_attack_identifier is not None - - def test_model_validate_does_not_mutate_input_dict(self) -> None: - """The promote before-validator must copy, not mutate, the caller-provided payload dict.""" - attack_id = ComponentIdentifier(class_name="TestAttack", class_module="tests.unit") - payload = { - "conversation_id": "c1", - "objective": "test", - "attack_identifier": attack_id, - "timestamp": "2026-01-01T12:00:00+00:00", - } - original = 
dict(payload) - with warnings.catch_warnings(record=True): - warnings.simplefilter("always") - AttackResult.model_validate(payload) - - assert payload == original, "model_validate must not mutate the input dict" - class TestAttackResultLegacyDictDeprecation: """to_dict()/from_dict() are retained as deprecated shims and must warn.""" diff --git a/tests/unit/models/test_chat_message.py b/tests/unit/models/test_chat_message.py index 8391e9340b..c3475285b7 100644 --- a/tests/unit/models/test_chat_message.py +++ b/tests/unit/models/test_chat_message.py @@ -104,21 +104,6 @@ def test_chat_message_accepts_all_valid_roles(role): assert msg.role == role -def test_chat_message_to_json_is_deprecated_alias_for_model_dump_json(): - msg = ChatMessage(role="user", content="test") - with pytest.warns(DeprecationWarning, match="ChatMessage.to_json"): - result = msg.to_json() - assert result == msg.model_dump_json() - - -def test_chat_message_from_json_is_deprecated_alias_for_model_validate_json(): - original = ChatMessage(role="system", content="you are helpful") - json_str = original.model_dump_json() - with pytest.warns(DeprecationWarning, match="ChatMessage.from_json"): - restored = ChatMessage.from_json(json_str) - assert restored == original - - def test_chat_messages_dataset_init(): msgs = [[ChatMessage(role="user", content="hi"), ChatMessage(role="assistant", content="hello")]] dataset = ChatMessagesDataset(name="test_ds", description="A test dataset", list_of_chat_messages=msgs) diff --git a/tests/unit/models/test_embedding_response.py b/tests/unit/models/test_embedding_response.py index 03e338f446..fe39fa102e 100644 --- a/tests/unit/models/test_embedding_response.py +++ b/tests/unit/models/test_embedding_response.py @@ -45,9 +45,3 @@ def test_save_load_loop_is_idempotent(my_embedding): output_file = my_embedding.save_to_file(Path(tmp_dir)) loaded_embedding = EmbeddingResponse.load_from_file(Path(output_file)) assert my_embedding == loaded_embedding - - -def 
test_to_json_is_deprecated_alias_for_model_dump_json(my_embedding: EmbeddingResponse): - with pytest.warns(DeprecationWarning, match="EmbeddingResponse.to_json"): - result = my_embedding.to_json() - assert result == my_embedding.model_dump_json() diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index e8d4457d81..6f4cf03cfb 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -680,7 +680,6 @@ def test_message_piece_to_dict(): conversation_id="test_conversation", sequence=1, labels={"label1": "value1"}, - targeted_harm_categories=["violence", "illegal"], prompt_metadata={"key": "metadata"}, converter_identifiers=[ ComponentIdentifier( @@ -697,34 +696,11 @@ def test_message_piece_to_dict(): class_name="PromptSendingAttack", class_module="pyrit.executor.attack.single_turn.prompt_sending_attack", ), - scorer_identifier=ComponentIdentifier( - class_name="TestScorer", - class_module="pyrit.score.test_scorer", - ), original_value_data_type="text", converted_value_data_type="text", response_error="none", - originator="undefined", original_prompt_id=uuid.uuid4(), timestamp=datetime.now(tz=timezone.utc), - scores=[ - Score( - id=str(uuid.uuid4()), - score_value="false", - score_value_description="true false score", - score_type="true_false", - score_category=["Category1"], - score_rationale="Rationale text", - score_metadata={"key": "value"}, - scorer_class_identifier=ComponentIdentifier( - class_name="Scorer1", - class_module="pyrit.score", - ), - message_piece_id=str(uuid.uuid4()), - timestamp=datetime.now(tz=timezone.utc), - objective="Task1", - ) - ], ) result = entry.model_dump(mode="json") @@ -736,12 +712,10 @@ def test_message_piece_to_dict(): "sequence", "timestamp", "labels", - "targeted_harm_categories", "prompt_metadata", "converter_identifiers", "prompt_target_identifier", "attack_identifier", - "scorer_identifier", "original_value_data_type", "original_value", 
"original_value_sha256", @@ -749,9 +723,7 @@ def test_message_piece_to_dict(): "converted_value", "converted_value_sha256", "response_error", - "originator", "original_prompt_id", - "scores", ] for key in expected_keys: @@ -764,12 +736,10 @@ def test_message_piece_to_dict(): # Pydantic v2 serializes UTC datetimes with a trailing "Z" rather than "+00:00". assert result["timestamp"] == entry.timestamp.isoformat().replace("+00:00", "Z") assert result["labels"] == entry.labels - assert result["targeted_harm_categories"] == entry.targeted_harm_categories assert result["prompt_metadata"] == entry.prompt_metadata assert result["converter_identifiers"] == [conv.to_dict() for conv in entry.converter_identifiers] assert result["prompt_target_identifier"] == entry.prompt_target_identifier.to_dict() assert result["attack_identifier"] == entry.attack_identifier.to_dict() - assert result["scorer_identifier"] == entry.scorer_identifier.to_dict() assert result["original_value_data_type"] == entry.original_value_data_type assert result["original_value"] == entry.original_value assert result["original_value_sha256"] == entry.original_value_sha256 @@ -777,30 +747,7 @@ def test_message_piece_to_dict(): assert result["converted_value"] == entry.converted_value assert result["converted_value_sha256"] == entry.converted_value_sha256 assert result["response_error"] == entry.response_error - assert result["originator"] == entry.originator assert result["original_prompt_id"] == str(entry.original_prompt_id) - assert result["scores"] == [score.to_dict() for score in entry.scores] - - -def test_message_piece_scorer_identifier_none_default(): - """Test that scorer_identifier defaults to None when not provided.""" - entry = MessagePiece( - role="user", - original_value="Hello", - ) - - assert entry.scorer_identifier is None - - -def test_message_piece_to_dict_scorer_identifier_none(): - """Test that to_dict() returns None for scorer_identifier when not set.""" - entry = MessagePiece( - 
role="user", - original_value="Hello", - ) - - result = entry.model_dump(mode="json") - assert result["scorer_identifier"] is None def test_construct_response_from_request_combines_metadata(): @@ -915,66 +862,6 @@ def test_message_piece_has_error_and_is_blocked_consistency(): assert no_error_entry.has_error() is False -def test_message_piece_harm_categories_none(): - """Test that harm_categories defaults to None.""" - entry = MessagePiece( - role="user", - original_value="Hello", - converted_value="Hello", - ) - assert entry.targeted_harm_categories == [] - - -def test_message_piece_harm_categories_single(): - """Test that harm_categories can be set to a single category.""" - entry = MessagePiece( - role="user", original_value="Hello", converted_value="Hello", targeted_harm_categories=["violence"] - ) - assert entry.targeted_harm_categories == ["violence"] - - -def test_message_piece_harm_categories_multiple(): - """Test that harm_categories can be set to multiple categories.""" - harm_categories = ["violence", "illegal", "hate_speech"] - entry = MessagePiece( - role="user", original_value="Hello", converted_value="Hello", targeted_harm_categories=harm_categories - ) - assert entry.targeted_harm_categories == harm_categories - - -def test_message_piece_harm_categories_serialization(): - """Test that harm_categories is properly serialized in to_dict().""" - harm_categories = ["violence", "illegal"] - entry = MessagePiece( - role="user", original_value="Hello", converted_value="Hello", targeted_harm_categories=harm_categories - ) - - result = entry.model_dump(mode="json") - assert "targeted_harm_categories" in result - assert result["targeted_harm_categories"] == harm_categories - - -def test_message_piece_harm_categories_with_labels(): - """Test that harm_categories and labels can coexist.""" - harm_categories = ["violence", "illegal"] - labels = {"operation": "test_op", "researcher": "alice"} - - entry = MessagePiece( - role="user", - original_value="Hello", - 
converted_value="Hello", - targeted_harm_categories=harm_categories, - labels=labels, - ) - - assert entry.targeted_harm_categories == harm_categories - assert entry.labels == labels - - result = entry.model_dump(mode="json") - assert result["targeted_harm_categories"] == harm_categories - assert result["labels"] == labels - - class TestSimulatedAssistantRole: """Tests for simulated_assistant role properties.""" @@ -1217,16 +1104,12 @@ def test_to_dict_golden_shape(self) -> None: "converted_value_data_type", "converted_value_sha256", "response_error", - "originator", "original_prompt_id", "labels", - "targeted_harm_categories", "prompt_metadata", "converter_identifiers", "prompt_target_identifier", "attack_identifier", - "scorer_identifier", - "scores", ] assert list(d.keys()) == expected_keys assert d["id"] == str(piece_id) @@ -1235,20 +1118,16 @@ def test_to_dict_golden_shape(self) -> None: assert d["sequence"] == 2 assert d["timestamp"] == ts.isoformat().replace("+00:00", "Z") assert d["labels"] == {} - assert d["targeted_harm_categories"] == [] assert d["prompt_metadata"] == {} assert d["converter_identifiers"] == [] assert d["prompt_target_identifier"] is None assert d["attack_identifier"] is None - assert d["scorer_identifier"] is None assert d["original_value_data_type"] == "text" assert d["original_value"] == "hello" assert d["converted_value_data_type"] == "text" assert d["converted_value"] == "hello" assert d["response_error"] == "none" - assert d["originator"] == "undefined" assert d["original_prompt_id"] == str(piece_id) - assert d["scores"] == [] def test_message_piece_is_unhashable(self) -> None: assert MessagePiece.__hash__ is None @@ -1272,47 +1151,6 @@ def _emit_deprecation_msgs(self, **kwargs) -> list[warnings.WarningMessage]: MessagePiece(role="user", original_value="hello", **kwargs) return [x for x in w if issubclass(x.category, DeprecationWarning)] - def test_scorer_identifier_emits_deprecation_warning(self): - scorer_id = 
ComponentIdentifier(class_name="X", class_module="x") - msgs = self._emit_deprecation_msgs(scorer_identifier=scorer_id) - assert any("scorer_identifier" in str(m.message) for m in msgs) - - def test_scorer_identifier_omitted_no_warning(self): - msgs = self._emit_deprecation_msgs() - assert not any("scorer_identifier" in str(m.message) for m in msgs) - - def test_originator_non_default_emits_deprecation_warning(self): - msgs = self._emit_deprecation_msgs(originator="attack") - assert any("originator" in str(m.message) for m in msgs) - - def test_originator_default_no_warning(self): - msgs = self._emit_deprecation_msgs(originator="undefined") - assert not any("originator" in str(m.message) for m in msgs) - - def test_scores_emits_deprecation_warning(self): - score = Score( - score_value="true", - score_value_description="d", - score_type="true_false", - score_rationale="r", - scorer_class_identifier=ComponentIdentifier(class_name="S", class_module="s"), - message_piece_id="mp-1", - ) - msgs = self._emit_deprecation_msgs(scores=[score]) - assert any("scores" in str(m.message) for m in msgs) - - def test_scores_omitted_no_warning(self): - msgs = self._emit_deprecation_msgs() - assert not any("scores" in str(m.message) for m in msgs) - - def test_targeted_harm_categories_emits_deprecation_warning(self): - msgs = self._emit_deprecation_msgs(targeted_harm_categories=["violence"]) - assert any("targeted_harm_categories" in str(m.message) for m in msgs) - - def test_targeted_harm_categories_omitted_no_warning(self): - msgs = self._emit_deprecation_msgs() - assert not any("targeted_harm_categories" in str(m.message) for m in msgs) - def test_labels_emits_deprecation_warning(self): msgs = self._emit_deprecation_msgs(labels={"k": "v"}) assert any("labels" in str(m.message) for m in msgs) @@ -1324,10 +1162,8 @@ def test_labels_omitted_no_warning(self): def test_memory_load_roundtrip_does_not_emit_deprecation_warnings(self) -> None: """Reconstructing a MessagePiece from 
PromptMemoryEntry must not emit deprecations. - The memory-layer load path assigns deprecated containers (``labels``, - ``scores``, ``targeted_harm_categories``) post-construction so the - deprecation-kwarg validator is not triggered. This regression-guards - that pattern. + The memory-layer load path assigns deprecated ``labels`` post-construction so the + deprecation-kwarg validator is not triggered. This regression-guards that pattern. """ from pyrit.memory.memory_models import PromptMemoryEntry @@ -1337,7 +1173,6 @@ def test_memory_load_roundtrip_does_not_emit_deprecation_warnings(self) -> None: conversation_id="conv-deprec", ) piece.labels = {"k": "v"} - piece.targeted_harm_categories = ["violence"] entry = PromptMemoryEntry(entry=piece) @@ -1348,7 +1183,6 @@ def test_memory_load_roundtrip_does_not_emit_deprecation_warnings(self) -> None: deprecation_msgs = [w for w in caught if issubclass(w.category, DeprecationWarning)] assert deprecation_msgs == [], [str(m.message) for m in deprecation_msgs] assert reconstructed.labels == {"k": "v"} - assert reconstructed.targeted_harm_categories == ["violence"] class TestMessagePieceDeprecatedMethodShims: diff --git a/tests/unit/prompt_converter/test_add_image_text_converter.py b/tests/unit/prompt_converter/test_add_image_text_converter.py index e2518f8828..a372fdb9b7 100644 --- a/tests/unit/prompt_converter/test_add_image_text_converter.py +++ b/tests/unit/prompt_converter/test_add_image_text_converter.py @@ -41,45 +41,6 @@ def test_add_image_text_converter_initialization(image_text_converter_sample_ima assert type(converter._font) is ImageFont.FreeTypeFont -def test_add_image_text_converter_positional_arg_deprecation(image_text_converter_sample_image): - with pytest.warns(DeprecationWarning, match="Passing img_to_add as a positional argument to AddImageTextConverter"): - converter = AddImageTextConverter(image_text_converter_sample_image) - assert converter._img_to_add == image_text_converter_sample_image - - -def 
test_add_image_text_converter_positional_and_keyword_raises(image_text_converter_sample_image): - with pytest.raises(TypeError, match="Cannot pass img_to_add as both positional and keyword"): - AddImageTextConverter(image_text_converter_sample_image, img_to_add=image_text_converter_sample_image) - - -def test_add_image_text_converter_too_many_positional_args_raises(image_text_converter_sample_image): - with pytest.raises(TypeError, match="takes at most 1 positional argument"): - AddImageTextConverter(image_text_converter_sample_image, "extra") - - -def test_add_image_text_converter_x_pos_y_pos_deprecation(image_text_converter_sample_image): - with pytest.warns(DeprecationWarning, match=r"AddImageTextConverter\(x_pos=\.\.\., y_pos=\.\.\.\)"): - AddImageTextConverter(img_to_add=image_text_converter_sample_image, x_pos=50, y_pos=50) - - -def test_add_image_text_converter_x_pos_y_pos_deprecation_default_value(image_text_converter_sample_image): - with pytest.warns(DeprecationWarning, match=r"AddImageTextConverter\(x_pos=\.\.\., y_pos=\.\.\.\)"): - AddImageTextConverter(img_to_add=image_text_converter_sample_image, x_pos=10) - - -def test_add_image_text_converter_no_x_pos_y_pos_no_warning(image_text_converter_sample_image): - import warnings - - with warnings.catch_warnings(): - warnings.simplefilter("error", DeprecationWarning) - AddImageTextConverter(img_to_add=image_text_converter_sample_image) - - -def test_add_image_text_converter_x_pos_with_bounding_box_raises(image_text_converter_sample_image): - with pytest.raises(ValueError, match="Cannot pass x_pos/y_pos together with bounding_box"): - AddImageTextConverter(img_to_add=image_text_converter_sample_image, x_pos=10, bounding_box=(0, 0, 100, 100)) - - def test_add_image_text_converter_invalid_font(image_text_converter_sample_image): with pytest.raises(ValueError): AddImageTextConverter(img_to_add=image_text_converter_sample_image, font_name="helvetica.otf") diff --git 
a/tests/unit/prompt_converter/test_add_text_image_converter.py b/tests/unit/prompt_converter/test_add_text_image_converter.py index 45c2887142..0c1f461675 100644 --- a/tests/unit/prompt_converter/test_add_text_image_converter.py +++ b/tests/unit/prompt_converter/test_add_text_image_converter.py @@ -38,24 +38,6 @@ def test_add_text_image_converter_invalid_font(): AddTextImageConverter(text_to_add="Sample text", font_name="helvetica.otf") # Invalid font extension -def test_add_text_image_converter_positional_arg_deprecation(): - with pytest.warns( - DeprecationWarning, match="Passing text_to_add as a positional argument to AddTextImageConverter" - ): - converter = AddTextImageConverter("Sample text") - assert converter._text_to_add == "Sample text" - - -def test_add_text_image_converter_positional_and_keyword_raises(): - with pytest.raises(TypeError, match="Cannot pass text_to_add as both positional and keyword"): - AddTextImageConverter("Sample text", text_to_add="Sample text") - - -def test_add_text_image_converter_too_many_positional_args_raises(): - with pytest.raises(TypeError, match="takes at most 1 positional argument"): - AddTextImageConverter("Sample text", "extra") - - def test_add_text_image_converter_invalid_text_to_add(): with pytest.raises(ValueError): AddTextImageConverter(text_to_add="", font_name="helvetica.ttf") diff --git a/tests/unit/prompt_converter/test_azure_speech_converter.py b/tests/unit/prompt_converter/test_azure_speech_converter.py index 9b9b2e0f89..fb2c654439 100644 --- a/tests/unit/prompt_converter/test_azure_speech_converter.py +++ b/tests/unit/prompt_converter/test_azure_speech_converter.py @@ -132,14 +132,6 @@ def my_provider(): with pytest.raises(ValueError, match="AZURE_SPEECH_RESOURCE_ID"): AzureSpeechTextToAudioConverter(azure_speech_region="test_region", azure_speech_key=my_provider) - def test_use_entra_auth_emits_deprecation_warning(self): - with pytest.warns(DeprecationWarning, match="use_entra_auth.*deprecated"): - 
AzureSpeechTextToAudioConverter( - azure_speech_region="test_region", - azure_speech_resource_id="test_resource_id", - use_entra_auth=True, - ) - @patch("azure.cognitiveservices.speech.SpeechConfig") @patch( "pyrit.common.default_values.get_required_value", diff --git a/tests/unit/prompt_converter/test_azure_speech_text_converter.py b/tests/unit/prompt_converter/test_azure_speech_text_converter.py index 98057a66ee..3fa1c74c1d 100644 --- a/tests/unit/prompt_converter/test_azure_speech_text_converter.py +++ b/tests/unit/prompt_converter/test_azure_speech_text_converter.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import warnings from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -118,15 +117,6 @@ async def test_azure_speech_audio_text_converter_non_wav_file(self, mock_path_ex with pytest.raises(ValueError): assert await converter.convert_async(prompt=prompt, input_type="audio_path") - def test_use_entra_auth_emits_deprecation_warning(self): - """Test that use_entra_auth emits DeprecationWarning.""" - with pytest.warns(DeprecationWarning, match="use_entra_auth.*deprecated"): - AzureSpeechAudioToTextConverter( - azure_speech_region="test_region", - azure_speech_resource_id="test_resource_id", - use_entra_auth=True, - ) - @patch( "pyrit.common.default_values.get_required_value", side_effect=lambda env_var_name, passed_value: passed_value or "dummy_value", @@ -259,50 +249,3 @@ async def test_convert_async_happy_path(self, mock_required, mock_factory, mock_ assert result.output_type == "text" mock_get_config.assert_called_once() mock_recognize.assert_called_once_with(audio_bytes=b"fake audio bytes", speech_config=mock_speech_config) - - @patch("pyrit.prompt_converter.azure_speech_audio_to_text_converter.get_speech_config") - @patch( - "pyrit.common.default_values.get_required_value", - side_effect=lambda env_var_name, passed_value: passed_value or "dummy_value", - ) - def 
test_recognize_audio_calls_get_speech_config(self, mock_required, mock_get_config): - """Test that recognize_audio() calls get_speech_config and _recognize_audio.""" - mock_speech_config = MagicMock() - mock_get_config.return_value = mock_speech_config - - converter = AzureSpeechAudioToTextConverter(azure_speech_region="test_region", azure_speech_key="test_key") - - with patch.object(converter, "_recognize_audio", return_value="transcribed") as mock_recognize: - result = converter.recognize_audio(audio_bytes=b"fake audio") - - assert result == "transcribed" - mock_get_config.assert_called_once_with(resource_id=None, key="test_key", region="test_region") - mock_recognize.assert_called_once_with(audio_bytes=b"fake audio", speech_config=mock_speech_config) - - @patch( - "pyrit.common.default_values.get_required_value", - side_effect=lambda env_var_name, passed_value: passed_value or "dummy_value", - ) - def test_recognize_audio_warns_when_token_provider_set(self, mock_required): - """Test that recognize_audio() emits DeprecationWarning when _token_provider is set.""" - - def my_provider(): - return "my_token" - - converter = AzureSpeechAudioToTextConverter( - azure_speech_region="test_region", - azure_speech_key=my_provider, - azure_speech_resource_id="test_resource_id", - ) - - with ( - patch("pyrit.prompt_converter.azure_speech_audio_to_text_converter.get_speech_config") as mock_config, - patch.object(converter, "_recognize_audio", return_value="text"), - warnings.catch_warnings(record=True) as w, - ): - warnings.simplefilter("always") - mock_config.return_value = MagicMock() - converter.recognize_audio(audio_bytes=b"fake audio") - - deprecation_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] - assert any("recognize_audio" in str(x.message) and "deprecated" in str(x.message) for x in deprecation_warnings) diff --git a/tests/unit/prompt_target/target/test_hugging_face_endpoint_target.py 
b/tests/unit/prompt_target/target/test_hugging_face_endpoint_target.py deleted file mode 100644 index 72a58173fe..0000000000 --- a/tests/unit/prompt_target/target/test_hugging_face_endpoint_target.py +++ /dev/null @@ -1,334 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from pyrit.models import Message, MessagePiece -from pyrit.prompt_target.hugging_face.hugging_face_endpoint_target import ( - HuggingFaceEndpointTarget, -) - -# HuggingFaceEndpointTarget emits a DeprecationWarning on construction -pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning") - - -@pytest.fixture -def hugging_face_endpoint_target(patch_central_database) -> HuggingFaceEndpointTarget: - return HuggingFaceEndpointTarget( - hf_token="test_token", - endpoint="https://api-inference.huggingface.co/models/test-model", - model_id="test-model", - ) - - -def test_hugging_face_endpoint_initializes(hugging_face_endpoint_target: HuggingFaceEndpointTarget): - assert hugging_face_endpoint_target - - -def test_hugging_face_endpoint_sets_endpoint_and_rate_limit(): - target = HuggingFaceEndpointTarget( - hf_token="test_token", - endpoint="https://api-inference.huggingface.co/models/test-model", - model_id="test-model", - max_requests_per_minute=30, - ) - identifier = target.get_identifier() - assert identifier.params["endpoint"] == "https://api-inference.huggingface.co/models/test-model" - assert target._max_requests_per_minute == 30 - - -def test_invalid_temperature_too_low_raises(patch_central_database): - with pytest.raises(Exception, match="temperature must be between 0 and 2"): - HuggingFaceEndpointTarget( - hf_token="test_token", - endpoint="https://api-inference.huggingface.co/models/test-model", - model_id="test-model", - temperature=-0.1, - ) - - -def test_invalid_temperature_too_high_raises(patch_central_database): - with pytest.raises(Exception, match="temperature must be 
between 0 and 2"): - HuggingFaceEndpointTarget( - hf_token="test_token", - endpoint="https://api-inference.huggingface.co/models/test-model", - model_id="test-model", - temperature=2.1, - ) - - -def test_invalid_top_p_too_low_raises(patch_central_database): - with pytest.raises(Exception, match="top_p must be between 0 and 1"): - HuggingFaceEndpointTarget( - hf_token="test_token", - endpoint="https://api-inference.huggingface.co/models/test-model", - model_id="test-model", - top_p=-0.1, - ) - - -def test_invalid_top_p_too_high_raises(patch_central_database): - with pytest.raises(Exception, match="top_p must be between 0 and 1"): - HuggingFaceEndpointTarget( - hf_token="test_token", - endpoint="https://api-inference.huggingface.co/models/test-model", - model_id="test-model", - top_p=1.1, - ) - - -def test_valid_temperature_and_top_p(patch_central_database): - # Should not raise any exceptions - target = HuggingFaceEndpointTarget( - hf_token="test_token", - endpoint="https://api-inference.huggingface.co/models/test-model", - model_id="test-model", - temperature=1.5, - top_p=0.9, - ) - assert target._temperature == 1.5 - assert target._top_p == 0.9 - - -def test_identifier_includes_generation_params(): - """New generation params (top_k, do_sample, repetition_penalty) appear in the identifier.""" - target = HuggingFaceEndpointTarget( - hf_token="test_token", - endpoint="https://api-inference.huggingface.co/models/test-model", - model_id="test-model", - top_k=40, - do_sample=True, - repetition_penalty=1.2, - ) - identifier = target.get_identifier() - assert identifier.params["top_k"] == 40 - assert identifier.params["do_sample"] is True - assert identifier.params["repetition_penalty"] == 1.2 - - -def test_identifier_excludes_none_generation_params(): - """None-valued generation params are excluded from the identifier.""" - target = HuggingFaceEndpointTarget( - hf_token="test_token", - endpoint="https://api-inference.huggingface.co/models/test-model", - 
model_id="test-model", - ) - identifier = target.get_identifier() - assert "top_k" not in identifier.params - assert "do_sample" not in identifier.params - assert "repetition_penalty" not in identifier.params - - -def test_sampling_params_without_do_sample_warns(): - """Setting temperature != 1.0 without do_sample=True emits a warning.""" - with pytest.warns(UserWarning, match="do_sample is not True"): - HuggingFaceEndpointTarget( - hf_token="test_token", - endpoint="https://api-inference.huggingface.co/models/test-model", - model_id="test-model", - temperature=0.7, - ) - - -def test_sampling_params_with_do_sample_no_warning(): - """Setting temperature != 1.0 with do_sample=True does not warn.""" - import warnings as _warnings - - with _warnings.catch_warnings(): - _warnings.simplefilter("error", UserWarning) - HuggingFaceEndpointTarget( - hf_token="test_token", - endpoint="https://api-inference.huggingface.co/models/test-model", - model_id="test-model", - temperature=0.7, - do_sample=True, - ) - - -@pytest.mark.filterwarnings("default::DeprecationWarning") -def test_init_emits_deprecation_warning(): - """HuggingFaceEndpointTarget emits a DeprecationWarning on construction.""" - with pytest.warns(DeprecationWarning, match="deprecated and will be removed"): - HuggingFaceEndpointTarget( - hf_token="test_token", - endpoint="https://api-inference.huggingface.co/models/test-model", - model_id="test-model", - ) - - -def _make_user_message(text: str) -> Message: - """Helper to create a single-piece user Message.""" - return Message( - message_pieces=[ - MessagePiece( - role="user", - original_value=text, - converted_value=text, - converted_value_data_type="text", - ) - ] - ) - - -@pytest.mark.asyncio -@pytest.mark.usefixtures("patch_central_database") -async def test_send_prompt_async_list_response(): - """Verify send_prompt_async handles a list response from the HF API.""" - target = HuggingFaceEndpointTarget( - hf_token="test_token", - 
endpoint="https://api-inference.huggingface.co/models/test-model", - model_id="test-model", - ) - - mock_response = MagicMock() - mock_response.json.return_value = [{"generated_text": "Hello from HF"}] - - with patch( - "pyrit.prompt_target.hugging_face.hugging_face_endpoint_target.make_request_and_raise_if_error_async", - new_callable=AsyncMock, - return_value=mock_response, - ): - message = _make_user_message("test prompt") - response = await target.send_prompt_async(message=message) - - assert len(response) == 1 - assert response[0].message_pieces[0].original_value == "Hello from HF" - - -@pytest.mark.asyncio -@pytest.mark.usefixtures("patch_central_database") -async def test_send_prompt_async_dict_response(): - """Verify send_prompt_async handles a dict response from the HF API.""" - target = HuggingFaceEndpointTarget( - hf_token="test_token", - endpoint="https://api-inference.huggingface.co/models/test-model", - model_id="test-model", - ) - - mock_response = MagicMock() - mock_response.json.return_value = {"generated_text": "Dict response"} - - with patch( - "pyrit.prompt_target.hugging_face.hugging_face_endpoint_target.make_request_and_raise_if_error_async", - new_callable=AsyncMock, - return_value=mock_response, - ): - message = _make_user_message("test prompt") - response = await target.send_prompt_async(message=message) - - assert len(response) == 1 - assert response[0].message_pieces[0].original_value == "Dict response" - - -@pytest.mark.asyncio -@pytest.mark.usefixtures("patch_central_database") -async def test_send_prompt_async_passes_optional_params_in_payload(): - """Verify optional generation params are included in the HTTP payload.""" - target = HuggingFaceEndpointTarget( - hf_token="test_token", - endpoint="https://api-inference.huggingface.co/models/test-model", - model_id="test-model", - top_k=40, - do_sample=True, - repetition_penalty=1.2, - ) - - mock_response = MagicMock() - mock_response.json.return_value = [{"generated_text": "response"}] - 
- with patch( - "pyrit.prompt_target.hugging_face.hugging_face_endpoint_target.make_request_and_raise_if_error_async", - new_callable=AsyncMock, - return_value=mock_response, - ) as mock_request: - message = _make_user_message("test prompt") - await target.send_prompt_async(message=message) - - call_kwargs = mock_request.call_args[1] - params = call_kwargs["request_body"]["parameters"] - assert params["top_k"] == 40 - assert params["do_sample"] is True - assert params["repetition_penalty"] == 1.2 - - -@pytest.mark.asyncio -@pytest.mark.usefixtures("patch_central_database") -async def test_send_prompt_async_omits_none_params_from_payload(): - """Verify None-valued optional params are not in the HTTP payload.""" - target = HuggingFaceEndpointTarget( - hf_token="test_token", - endpoint="https://api-inference.huggingface.co/models/test-model", - model_id="test-model", - ) - - mock_response = MagicMock() - mock_response.json.return_value = [{"generated_text": "response"}] - - with patch( - "pyrit.prompt_target.hugging_face.hugging_face_endpoint_target.make_request_and_raise_if_error_async", - new_callable=AsyncMock, - return_value=mock_response, - ) as mock_request: - message = _make_user_message("test prompt") - await target.send_prompt_async(message=message) - - call_kwargs = mock_request.call_args[1] - params = call_kwargs["request_body"]["parameters"] - assert "top_k" not in params - assert "do_sample" not in params - assert "repetition_penalty" not in params - - -@pytest.mark.asyncio -@pytest.mark.usefixtures("patch_central_database") -async def test_send_prompt_async_metadata_contains_model_id(): - """Verify prompt_metadata includes the model_id.""" - target = HuggingFaceEndpointTarget( - hf_token="test_token", - endpoint="https://api-inference.huggingface.co/models/test-model", - model_id="test-model", - ) - - mock_response = MagicMock() - mock_response.json.return_value = [{"generated_text": "response"}] - - with patch( - 
"pyrit.prompt_target.hugging_face.hugging_face_endpoint_target.make_request_and_raise_if_error_async", - new_callable=AsyncMock, - return_value=mock_response, - ): - message = _make_user_message("test prompt") - response = await target.send_prompt_async(message=message) - - metadata = response[0].message_pieces[0].prompt_metadata - assert metadata["model_id"] == "test-model" - - -def test_validate_request_rejects_multiple_pieces(): - """Verify _validate_request raises for messages with multiple pieces.""" - target = HuggingFaceEndpointTarget( - hf_token="test_token", - endpoint="https://api-inference.huggingface.co/models/test-model", - model_id="test-model", - ) - - piece1 = MessagePiece( - role="user", - original_value="first", - converted_value="first", - converted_value_data_type="text", - conversation_id="conv1", - ) - piece2 = MessagePiece( - role="user", - original_value="second", - converted_value="second", - converted_value_data_type="text", - conversation_id="conv1", - ) - message = Message(message_pieces=[piece1, piece2]) - - with pytest.raises(ValueError, match="single message piece"): - target._validate_request(normalized_conversation=[message]) diff --git a/tests/unit/prompt_target/target/test_image_target.py b/tests/unit/prompt_target/target/test_image_target.py index c14d89cf8d..3eca3e7607 100644 --- a/tests/unit/prompt_target/target/test_image_target.py +++ b/tests/unit/prompt_target/target/test_image_target.py @@ -3,7 +3,6 @@ import os import uuid -import warnings from collections.abc import MutableSequence from unittest.mock import AsyncMock, MagicMock, patch @@ -400,54 +399,6 @@ async def test_send_prompt_async_bad_request_content_policy_violation( assert result[0].message_pieces[0].converted_value_data_type == "error" -async def test_send_prompt_async_url_response_downloads_image( - image_target: OpenAIImageTarget, - sample_conversations: MutableSequence[MessagePiece], -): - """Test that when model returns URL instead of base64, the image is 
downloaded from URL.""" - request = sample_conversations[0] - request.conversation_id = str(uuid.uuid4()) - - # Response returns URL (no b64_json) - mock_response_url = MagicMock() - mock_image_url = MagicMock() - mock_image_url.b64_json = None - mock_image_url.url = "https://example.com/image.png" - mock_response_url.data = [mock_image_url] - - # Mock httpx response for URL download - mock_http_response = MagicMock() - mock_http_response.content = b"hello" - mock_http_response.raise_for_status = MagicMock() - - with patch.object(image_target._async_client.images, "generate", new_callable=AsyncMock) as mock_generate: - mock_generate.return_value = mock_response_url - - with patch("pyrit.prompt_target.openai.openai_image_target.httpx.AsyncClient") as mock_httpx: - mock_client_instance = AsyncMock() - mock_client_instance.get = AsyncMock(return_value=mock_http_response) - mock_httpx.return_value.__aenter__.return_value = mock_client_instance - - resp = await image_target.send_prompt_async(message=Message([request])) - - # Should have called generate once - assert mock_generate.call_count == 1 - - # Should have downloaded from the URL - mock_client_instance.get.assert_called_once_with("https://example.com/image.png") - - # Should have successfully returned the image - assert len(resp) == 1 - path = resp[0].message_pieces[0].original_value - assert os.path.isfile(path) - - with open(path, "rb") as file: - data = file.read() - assert data == b"hello" - - os.remove(path) - - async def test_validate_no_text_piece(image_target: OpenAIImageTarget): image_piece = get_image_message_piece() @@ -556,101 +507,6 @@ async def test_validate_previous_conversations( await image_target.send_prompt_async(message=request) -def test_style_param_emits_deprecation_warning(patch_central_database): - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - target = OpenAIImageTarget( - model_name="gpt-image-1", - endpoint="test", - api_key="test", - 
style="vivid", - ) - deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] - style_warnings = [w for w in deprecation_warnings if "style" in str(w.message)] - assert len(style_warnings) == 1 - assert "0.15.0" in str(style_warnings[0].message) - assert "2026-05-12" in str(style_warnings[0].message) - assert target.style == "vivid" - - -def test_no_style_does_not_emit_deprecation_warning(patch_central_database): - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - OpenAIImageTarget( - model_name="gpt-image-1", - endpoint="test", - api_key="test", - ) - style_warnings = [ - w for w in caught if issubclass(w.category, DeprecationWarning) and "OpenAIImageTarget(style" in str(w.message) - ] - assert len(style_warnings) == 0 - - -@pytest.mark.parametrize("deprecated_size", ["256x256", "512x512", "1792x1024", "1024x1792"]) -def test_deprecated_image_size_emits_warning(patch_central_database, deprecated_size): - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - target = OpenAIImageTarget( - model_name="gpt-image-1", - endpoint="test", - api_key="test", - image_size=deprecated_size, - ) - deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] - size_warnings = [w for w in deprecation_warnings if "image_size" in str(w.message)] - assert len(size_warnings) == 1 - assert "0.15.0" in str(size_warnings[0].message) - assert "2026-05-12" in str(size_warnings[0].message) - assert target.image_size == deprecated_size - - -@pytest.mark.parametrize("valid_size", ["auto", "1024x1024", "1536x1024", "1024x1536"]) -def test_valid_image_size_does_not_emit_warning(patch_central_database, valid_size): - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - OpenAIImageTarget( - model_name="gpt-image-1", - endpoint="test", - api_key="test", - image_size=valid_size, - ) - size_warnings = [w for w in caught if 
issubclass(w.category, DeprecationWarning) and "image_size" in str(w.message)] - assert len(size_warnings) == 0 - - -@pytest.mark.parametrize("deprecated_quality", ["standard", "hd"]) -def test_deprecated_quality_emits_warning(patch_central_database, deprecated_quality): - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - target = OpenAIImageTarget( - model_name="gpt-image-1", - endpoint="test", - api_key="test", - quality=deprecated_quality, - ) - deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] - quality_warnings = [w for w in deprecation_warnings if "quality" in str(w.message)] - assert len(quality_warnings) == 1 - assert "0.15.0" in str(quality_warnings[0].message) - assert "2026-05-12" in str(quality_warnings[0].message) - assert target.quality == deprecated_quality - - -@pytest.mark.parametrize("valid_quality", ["auto", "low", "medium", "high"]) -def test_valid_quality_does_not_emit_warning(patch_central_database, valid_quality): - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - OpenAIImageTarget( - model_name="gpt-image-1", - endpoint="test", - api_key="test", - quality=valid_quality, - ) - quality_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning) and "quality" in str(w.message)] - assert len(quality_warnings) == 0 - - def test_background_param_stored(patch_central_database): target = OpenAIImageTarget( model_name="gpt-image-1", diff --git a/tests/unit/prompt_target/target/test_normalize_async_integration.py b/tests/unit/prompt_target/target/test_normalize_async_integration.py index 2bd58d18a0..104c2b086e 100644 --- a/tests/unit/prompt_target/target/test_normalize_async_integration.py +++ b/tests/unit/prompt_target/target/test_normalize_async_integration.py @@ -4,7 +4,6 @@ from __future__ import annotations import json -import warnings from typing import TYPE_CHECKING from unittest.mock import AsyncMock, MagicMock, 
patch @@ -15,7 +14,6 @@ from openai.types.chat import ChatCompletion from pyrit.memory.memory_interface import MemoryInterface -from pyrit.message_normalizer import GenericSystemSquashNormalizer from pyrit.models import ComponentIdentifier, Message, MessagePiece from pyrit.prompt_target import AzureMLChatTarget, OpenAIChatTarget from pyrit.prompt_target.common.target_capabilities import ( @@ -304,61 +302,6 @@ async def test_azure_ml_target_memory_not_mutated(): assert memory_conversation[0].get_piece().api_role == "system" -# --------------------------------------------------------------------------- -# AzureMLChatTarget — message_normalizer deprecation -# --------------------------------------------------------------------------- - - -@pytest.mark.usefixtures("patch_central_database") -def test_azure_ml_generic_system_squash_normalizer_emits_deprecation_warning(): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - target = AzureMLChatTarget( - endpoint="http://aml-test-endpoint.com", - api_key="valid_api_key", - message_normalizer=GenericSystemSquashNormalizer(), - ) - deprecation_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] - assert len(deprecation_warnings) == 1 - assert "message_normalizer" in str(deprecation_warnings[0].message) - assert "deprecated" in str(deprecation_warnings[0].message) - - -@pytest.mark.usefixtures("patch_central_database") -def test_azure_ml_generic_system_squash_normalizer_creates_adapt_configuration(): - """Legacy message_normalizer should be translated into a TargetConfiguration with ADAPT policy.""" - with warnings.catch_warnings(record=True): - warnings.simplefilter("always") - target = AzureMLChatTarget( - endpoint="http://aml-test-endpoint.com", - api_key="valid_api_key", - message_normalizer=GenericSystemSquashNormalizer(), - ) - # The shim should create a config with supports_system_prompt=False - assert not target.capabilities.supports_system_prompt - assert 
target.configuration.includes(capability=CapabilityName.MULTI_TURN) - assert not target.configuration.includes(capability=CapabilityName.SYSTEM_PROMPT) - - -@pytest.mark.usefixtures("patch_central_database") -def test_azure_ml_message_normalizer_and_custom_config_raises(): - """Passing both message_normalizer and custom_configuration should raise ValueError.""" - custom_config = TargetConfiguration( - capabilities=TargetCapabilities( - supports_multi_turn=True, - supports_system_prompt=True, - supports_multi_message_pieces=True, - ) - ) - with pytest.raises(ValueError, match="Cannot specify both"): - AzureMLChatTarget( - endpoint="http://aml-test-endpoint.com", - api_key="valid_api_key", - message_normalizer=GenericSystemSquashNormalizer(), - custom_configuration=custom_config, - ) - - @pytest.mark.usefixtures("patch_central_database") async def test_azure_ml_system_squash_via_configuration_pipeline(): """End-to-end: GenericSystemSquashNormalizer-equivalent behavior via TargetConfiguration pipeline.""" diff --git a/tests/unit/prompt_target/target/test_target_capabilities.py b/tests/unit/prompt_target/target/test_target_capabilities.py index c8e6de4c3a..3899223261 100644 --- a/tests/unit/prompt_target/target/test_target_capabilities.py +++ b/tests/unit/prompt_target/target/test_target_capabilities.py @@ -124,7 +124,6 @@ def _all_concrete_target_classes(self): HTTPTarget, HTTPXAPITarget, HuggingFaceChatTarget, - HuggingFaceEndpointTarget, OpenAIChatTarget, OpenAICompletionTarget, OpenAIImageTarget, @@ -146,7 +145,6 @@ def _all_concrete_target_classes(self): HTTPTarget, HTTPXAPITarget, HuggingFaceChatTarget, - HuggingFaceEndpointTarget, OpenAIChatTarget, OpenAICompletionTarget, OpenAIImageTarget, diff --git a/tests/unit/scenario/airt/test_rapid_response.py b/tests/unit/scenario/airt/test_rapid_response.py index ba464b129f..ca2471907d 100644 --- a/tests/unit/scenario/airt/test_rapid_response.py +++ b/tests/unit/scenario/airt/test_rapid_response.py @@ -502,52 +502,6 @@ 
def test_factories_always_use_default_adversarial(self, mock_objective_scorer): # =========================================================================== -# Deprecated alias tests -# =========================================================================== - - -@pytest.mark.usefixtures(*FIXTURES) -class TestDeprecatedAliases: - """Tests for backward-compatible ContentHarms aliases.""" - - def test_content_harms_is_rapid_response(self): - with pytest.warns(DeprecationWarning, match="ContentHarms"): - from pyrit.scenario.scenarios.airt.content_harms import ContentHarms - - assert ContentHarms is RapidResponse - - def test_content_harms_strategy_is_rapid_response_strategy(self): - with pytest.warns(DeprecationWarning, match="ContentHarmsStrategy"): - from pyrit.scenario.scenarios.airt.content_harms import ContentHarmsStrategy - - assert ContentHarmsStrategy is _strategy_class() - - def test_content_harms_instance_name_is_rapid_response(self, mock_objective_scorer): - """ContentHarms() creates a RapidResponse with name 'RapidResponse'.""" - with pytest.warns(DeprecationWarning, match="ContentHarms"): - from pyrit.scenario.scenarios.airt.content_harms import ContentHarms - - scenario = ContentHarms( - objective_scorer=mock_objective_scorer, - ) - assert scenario.name == "RapidResponse" - assert isinstance(scenario, RapidResponse) - - def test_content_harms_via_airt_package_emits_deprecation_warning(self): - """Importing ``ContentHarms`` from the parent ``airt`` package emits the warning.""" - with pytest.warns(DeprecationWarning, match="ContentHarms"): - from pyrit.scenario.scenarios.airt import ContentHarms - - assert ContentHarms is RapidResponse - - def test_content_harms_strategy_via_airt_package_emits_deprecation_warning(self): - """Importing ``ContentHarmsStrategy`` from the parent ``airt`` package emits the warning.""" - with pytest.warns(DeprecationWarning, match="ContentHarmsStrategy"): - from pyrit.scenario.scenarios.airt import ContentHarmsStrategy - - 
assert ContentHarmsStrategy is _strategy_class() - - # =========================================================================== # Registry integration tests # =========================================================================== diff --git a/tests/unit/scenario/scenarios/adaptive/test_epsilon_greedy.py b/tests/unit/scenario/scenarios/adaptive/test_epsilon_greedy.py index 21144721f5..bfade07b3c 100644 --- a/tests/unit/scenario/scenarios/adaptive/test_epsilon_greedy.py +++ b/tests/unit/scenario/scenarios/adaptive/test_epsilon_greedy.py @@ -137,7 +137,6 @@ async def test_default_scope_passes_none_scenario_result_id(self, mock_compute): # Default scope is all_runs(): the per-call scenario_result_id is dropped. assert mock_compute.call_args.kwargs["scenario_result_id"] is None - assert mock_compute.call_args.kwargs["targeted_harm_categories"] is None @patch(_COMPUTE_PATH, side_effect=_empty_rates) async def test_current_run_scope_forwards_scenario_result_id(self, mock_compute): @@ -146,17 +145,6 @@ async def test_current_run_scope_forwards_scenario_result_id(self, mock_compute) assert mock_compute.call_args.kwargs["scenario_result_id"] == "run-42" - @patch(_COMPUTE_PATH, side_effect=_empty_rates) - async def test_scope_filter_fields_forwarded(self, mock_compute): - scope = SelectorScope( - targeted_harm_categories=["misinformation"], - ) - selector = EpsilonGreedyTechniqueSelector(epsilon=0.0, random_seed=0, scope=scope) - await selector.select_async(technique_identifiers=TECHNIQUES, objective="obj") - - kwargs = mock_compute.call_args.kwargs - assert kwargs["targeted_harm_categories"] == ["misinformation"] - class TestEpsilonGreedyEstimate: def test_estimate_unseen_is_one(self): diff --git a/tests/unit/scenario/scenarios/adaptive/test_selector_scope.py b/tests/unit/scenario/scenarios/adaptive/test_selector_scope.py index f024d2c5a5..71cd8aad20 100644 --- a/tests/unit/scenario/scenarios/adaptive/test_selector_scope.py +++ 
b/tests/unit/scenario/scenarios/adaptive/test_selector_scope.py @@ -12,7 +12,6 @@ class TestSelectorScopeDefaults: def test_default_constructs_all_runs(self): scope = SelectorScope() assert scope.current_run_only is False - assert scope.targeted_harm_categories is None def test_all_runs_classmethod_equivalent_to_default(self): assert SelectorScope.all_runs() == SelectorScope() @@ -20,7 +19,6 @@ def test_all_runs_classmethod_equivalent_to_default(self): def test_current_run_classmethod_sets_flag(self): scope = SelectorScope.current_run() assert scope.current_run_only is True - assert scope.targeted_harm_categories is None class TestSelectorScopeFrozen: @@ -29,24 +27,15 @@ def test_assigning_field_raises(self): with pytest.raises(dataclasses.FrozenInstanceError): scope.current_run_only = True # type: ignore[misc] - def test_assigning_new_field_raises(self): - scope = SelectorScope() - with pytest.raises(dataclasses.FrozenInstanceError): - scope.targeted_harm_categories = ("a",) # type: ignore[misc] - class TestSelectorScopeCombinations: def test_fields_combine(self): - scope = SelectorScope( - current_run_only=True, - targeted_harm_categories=["misinformation"], - ) + scope = SelectorScope(current_run_only=True) assert scope.current_run_only is True - assert scope.targeted_harm_categories == ["misinformation"] def test_equality_value_based(self): - a = SelectorScope(targeted_harm_categories=("y",)) - b = SelectorScope(targeted_harm_categories=("y",)) + a = SelectorScope(current_run_only=True) + b = SelectorScope(current_run_only=True) assert a == b def test_inequality_when_fields_differ(self): diff --git a/tests/unit/score/test_audio_scorer.py b/tests/unit/score/test_audio_scorer.py index b228e71cb3..a99d28b1b3 100644 --- a/tests/unit/score/test_audio_scorer.py +++ b/tests/unit/score/test_audio_scorer.py @@ -225,15 +225,7 @@ async def test_score_piece_empty_transcript(self, audio_message_piece): @pytest.mark.usefixtures("patch_central_database") class 
TestAudioTranscriptHelper: - """Tests for AudioTranscriptHelper deprecation and transcription.""" - - def test_use_entra_auth_emits_deprecation_warning(self): - """Test that passing use_entra_auth to AudioTranscriptHelper emits DeprecationWarning.""" - from pyrit.score.audio_transcript_scorer import AudioTranscriptHelper - - text_scorer = MockTextTrueFalseScorer() - with pytest.warns(DeprecationWarning, match="use_entra_auth.*deprecated"): - AudioTranscriptHelper(text_capable_scorer=text_scorer, use_entra_auth=True) + """Tests for AudioTranscriptHelper transcription.""" async def test_transcribe_audio_async_creates_converter(self, audio_message_piece): """Test that _transcribe_audio_async creates AzureSpeechAudioToTextConverter and calls convert_async.""" diff --git a/tests/unit/score/test_scorer_metrics.py b/tests/unit/score/test_scorer_metrics.py index 20318caf57..6974467f2b 100644 --- a/tests/unit/score/test_scorer_metrics.py +++ b/tests/unit/score/test_scorer_metrics.py @@ -5,8 +5,6 @@ from pathlib import Path from unittest.mock import patch -import pytest - from pyrit.models import ComponentIdentifier from pyrit.score import ( HarmScorerMetrics, @@ -66,24 +64,6 @@ def test_objective_metrics_to_json_and_from_json_file(self, tmp_path): loaded = ObjectiveScorerMetrics.from_json_file(str(file_path)) assert loaded == metrics - def test_from_json_is_deprecated_alias_for_from_json_file(self, tmp_path): - metrics = ObjectiveScorerMetrics( - num_responses=10, - num_human_raters=3, - accuracy=0.9, - accuracy_standard_error=0.05, - f1_score=0.8, - precision=0.85, - recall=0.75, - ) - file_path = tmp_path / "metrics.json" - with open(file_path, "w") as f: - f.write(metrics.to_json()) - - with pytest.warns(DeprecationWarning, match="ObjectiveScorerMetrics.from_json"): - loaded = ObjectiveScorerMetrics.from_json(str(file_path)) - assert loaded == metrics - class TestScorerMetricsWithIdentity: """Tests for ScorerMetricsWithIdentity dataclass.""" From 
846bc9c8953ce4cd031c538390183ba686ef4e12 Mon Sep 17 00:00:00 2001 From: Copilot <223556219+Copilot@users.noreply.github.com> Date: Tue, 9 Jun 2026 12:48:25 -0700 Subject: [PATCH 2/2] Fix ty type errors in memory_models surfaced by merge These errors pre-existed on main but only get checked when the file is modified in a PR. The merge of main into this branch put memory_models.py back in the pre-commit diff and surfaced 7 ty diagnostics: Real type bugs fixed: - get_message_piece: filter None from converter_ids before passing to MessagePiece.converter_identifiers (whose type is list[ComponentIdentifier], not list[ComponentIdentifier | None]). - get_scenario_result: wrap self.scenario_run_state (Mapped[str]) with ScenarioRunState(...) so the ScenarioResult.scenario_run_state ScenarioRunState field gets the correct type. SQLAlchemy descriptor quirks (ty cannot model them): - Added # type: ignore[ty:invalid-assignment] to 5 column assignments where the runtime value is compatible but the typed Mapped[...] descriptor's __set__ confuses ty (converter_identifiers, response_error, attribution_data, objective_target_identifier, scenario_metadata). 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/memory/memory_models.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index c4d80b4ea7..964e9ffefe 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -51,6 +51,7 @@ SeedSimulatedConversation, SeedType, ) +from pyrit.models.scenario_result import ScenarioRunState logger = logging.getLogger(__name__) @@ -306,7 +307,7 @@ def __init__(self, *, entry: MessagePiece) -> None: self.timestamp = entry.timestamp self.labels = entry.labels self.prompt_metadata = entry.prompt_metadata - self.converter_identifiers = _dump_identifiers(entry.converter_identifiers) + self.converter_identifiers = _dump_identifiers(entry.converter_identifiers) # type: ignore[ty:invalid-assignment] self.prompt_target_identifier = _dump_identifier(entry.prompt_target_identifier) or {} self.attack_identifier = _dump_identifier(entry.attack_identifier) or {} @@ -318,7 +319,7 @@ def __init__(self, *, entry: MessagePiece) -> None: self.converted_value_data_type = entry.converted_value_data_type self.converted_value_sha256 = entry.converted_value_sha256 - self.response_error = entry.response_error + self.response_error = entry.response_error # type: ignore[ty:invalid-assignment] self.original_prompt_id = entry.original_prompt_id self.pyrit_version = pyrit.__version__ @@ -346,7 +347,7 @@ def get_message_piece(self) -> MessagePiece: conversation_id=self.conversation_id, sequence=self.sequence, prompt_metadata=self.prompt_metadata, - converter_identifiers=converter_ids or [], + converter_identifiers=[c for c in (converter_ids or []) if c is not None], prompt_target_identifier=target_id, attack_identifier=attack_id, original_value_data_type=self.original_value_data_type, @@ -863,7 +864,7 @@ def __init__(self, *, entry: AttackResult) -> None: # Attribution / parent linkage (set by the attack persistence path 
when # an AttackResultAttribution is present on the AttackContext; otherwise None) self.attribution_parent_id = uuid.UUID(entry.attribution_parent_id) if entry.attribution_parent_id else None - self.attribution_data = entry.attribution_data + self.attribution_data = entry.attribution_data # type: ignore[ty:invalid-assignment] @staticmethod def _get_id_as_uuid(obj: Any) -> uuid.UUID | None: @@ -1054,7 +1055,7 @@ def __init__(self, *, entry: ScenarioResult) -> None: self.pyrit_version = entry.scenario_identifier.pyrit_version self.scenario_init_data = entry.scenario_identifier.init_data # Convert ComponentIdentifier to dict for JSON storage - self.objective_target_identifier = _dump_identifier(entry.objective_target_identifier) + self.objective_target_identifier = _dump_identifier(entry.objective_target_identifier) # type: ignore[ty:invalid-assignment] # Ensure eval_hash is set before truncation so it survives the DB round-trip. if entry.objective_scorer_identifier and entry.objective_scorer_identifier.eval_hash is None: entry.objective_scorer_identifier = entry.objective_scorer_identifier.with_eval_hash( @@ -1078,7 +1079,7 @@ def __init__(self, *, entry: ScenarioResult) -> None: self.error_message = entry.error_message self.error_type = entry.error_type - self.scenario_metadata = entry.metadata if entry.metadata else None + self.scenario_metadata = entry.metadata if entry.metadata else None # type: ignore[ty:invalid-assignment] self.timestamp = datetime.now(tz=timezone.utc) @@ -1123,7 +1124,7 @@ def get_scenario_result(self) -> ScenarioResult: objective_target_identifier=target_identifier, attack_results=attack_results, objective_scorer_identifier=scorer_identifier, - scenario_run_state=self.scenario_run_state, + scenario_run_state=ScenarioRunState(self.scenario_run_state), labels=self.labels or {}, creation_time=self.timestamp, number_tries=self.number_tries,