Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions pyrit/analytics/technique_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def compute_technique_stats(
*,
technique_eval_hashes: Sequence[str],
scenario_result_id: str | None = None,
targeted_harm_categories: Sequence[str] | None = None,
memory: MemoryInterface | None = None,
) -> dict[str, AttackStats]:
"""
Expand All @@ -40,8 +39,6 @@ def compute_technique_stats(
Returned dict is keyed by these.
scenario_result_id (str | None): Restrict to a single scenario run.
Defaults to ``None`` (aggregate across all runs).
targeted_harm_categories (Sequence[str] | None): Restrict to results
whose prompts targeted these harm categories. Defaults to ``None``.
memory (MemoryInterface | None): Memory backend to query. Defaults to
``CentralMemory.get_memory_instance()``.

Expand All @@ -57,7 +54,6 @@ def compute_technique_stats(
results = memory.get_attack_results(
atomic_attack_eval_hashes=list(technique_eval_hashes),
scenario_result_id=scenario_result_id,
targeted_harm_categories=targeted_harm_categories,
)

requested = set(technique_eval_hashes)
Expand Down
27 changes: 16 additions & 11 deletions pyrit/executor/attack/component/conversation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,17 +566,22 @@ async def _process_prepended_for_chat_target_async(
if hasattr(context, "executed_turns"):
context.executed_turns = state.turn_count # type: ignore[ty:invalid-assignment]

# Extract scores on final prepended assistant message if it exists and are relavent
# Multi-part messages (e.g., text + image) may have scores on multiple pieces
# only extract true_false scores with score_value=False. This allows attacks
# to use the score's rationale for feedback without re-scoring.
for piece in final_prepended_message.message_pieces:
for score in piece.scores:
if score.score_type == "true_false" and score.get_value() is False:
state.last_assistant_message_scores.append(score)
# context.last_score gets the first matching score for single-score use cases.
if hasattr(context, "last_score") and context.last_score is None:
context.last_score = score # type: ignore[ty:invalid-assignment]
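# Extract scores on the final prepended assistant message, if it exists and the
# scores are relevant. The prepended pieces were re-keyed with new ids when added
# to memory, so they are looked up by conversation_id rather than by the ids on
# the in-memory message. Only true_false scores with score_value=False are
# extracted, so attacks can use the rationale for feedback without re-scoring.
# NOTE(review): the intent stated for this lookup is to filter to the last
# assistant turn, but the id list built below includes every assistant piece in
# the conversation - confirm whether scores from earlier assistant turns should
# be excluded.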
memory_pieces = self._memory.get_message_pieces(conversation_id=conversation_id)
assistant_piece_ids = [str(piece.id) for piece in memory_pieces if piece.api_role == "assistant"]
existing_scores = (
self._memory.get_prompt_scores(prompt_ids=assistant_piece_ids) if assistant_piece_ids else []
)
for score in existing_scores:
if score.score_type == "true_false" and score.get_value() is False:
state.last_assistant_message_scores.append(score)
# context.last_score gets the first matching score for single-score use cases.
if hasattr(context, "last_score") and context.last_score is None:
context.last_score = score # type: ignore[ty:invalid-assignment]

return state

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Drop deprecated columns scheduled for removal in v0.15.0.

* ``AttackResultEntries.attack_identifier`` (superseded by
``atomic_attack_identifier``).
* ``PromptMemoryEntries.targeted_harm_categories`` (callers should use the
attack-level ``labels`` column with ``{"harm_category": [...]}`` instead).

Revision ID: f1a2b3c4d5e6
Revises: 9c8b7a6d5e4f
Create Date: 2026-06-05 14:39:00.000000
"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "f1a2b3c4d5e6"
down_revision: str | None = "9c8b7a6d5e4f"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    """Drop the deprecated columns from their tables."""
    # Each drop goes through batch_alter_table so the same migration is
    # portable across SQLite (older versions cannot DROP COLUMN on a table
    # with constraints) and Azure SQL.
    deprecated_columns = (
        ("AttackResultEntries", "attack_identifier"),
        ("PromptMemoryEntries", "targeted_harm_categories"),
    )
    for table_name, column_name in deprecated_columns:
        with op.batch_alter_table(table_name) as batch_op:
            batch_op.drop_column(column_name)


def downgrade() -> None:
    """Re-create the columns removed by ``upgrade`` as nullable JSON columns."""
    # The columns come back nullable so legacy code can still write to them.
    # The original not-null default on attack_identifier is intentionally
    # relaxed here because there is no way to backfill the original values.
    restored_columns = (
        ("PromptMemoryEntries", "targeted_harm_categories"),
        ("AttackResultEntries", "attack_identifier"),
    )
    for table_name, column_name in restored_columns:
        with op.batch_alter_table(table_name) as batch_op:
            batch_op.add_column(sa.Column(column_name, sa.JSON(), nullable=True))
37 changes: 0 additions & 37 deletions pyrit/memory/azure_sql_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,43 +445,6 @@ def _get_condition_json_array_match(
combined = joiner.join(conditions)
return text(f"""ISJSON("{table_name}".{column_name}) = 1 AND ({combined})""").bindparams(**bindparams_dict)

def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any:
    """
    Get the SQL Azure implementation for filtering AttackResults by targeted harm categories.

    Builds an EXISTS subquery over PromptMemoryEntry rows that share the
    AttackResultEntry's conversation_id. A row matches when its
    ``targeted_harm_categories`` column is non-null, is neither ``""`` nor
    ``"[]"``, passes ``ISJSON(...) = 1``, and contains every requested category
    (one ``OPENJSON`` membership check per category, AND-combined).

    Args:
        targeted_harm_categories (Sequence[str]): Harm category strings that must
            ALL be present. An empty sequence adds no per-category check, so the
            condition reduces to "has a non-empty, valid JSON category list".

    Returns:
        Any: SQLAlchemy exists subquery condition with bound parameters.
    """
    # Category values are passed as bound parameters, never interpolated into
    # the SQL text; only the generated parameter names appear in the string.
    bindparams_dict = {f"harm_cat_{i}": category for i, category in enumerate(targeted_harm_categories)}

    # Start from the ISJSON guard and append one OPENJSON membership check per
    # category. Joining a list (instead of concatenating "... AND " + conditions)
    # keeps the SQL well-formed when no categories are supplied.
    json_clauses = ["ISJSON(targeted_harm_categories) = 1"]
    json_clauses.extend(
        f"EXISTS(SELECT 1 FROM OPENJSON(targeted_harm_categories) WHERE value = :{param_name})"
        for param_name in bindparams_dict
    )

    return exists().where(
        and_(
            PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id,
            PromptMemoryEntry.targeted_harm_categories.isnot(None),
            PromptMemoryEntry.targeted_harm_categories != "",
            PromptMemoryEntry.targeted_harm_categories != "[]",
            text(" AND ".join(json_clauses)).bindparams(**bindparams_dict),
        )
    )

def _get_attack_result_label_condition(self, *, labels: dict[str, str | Sequence[str]]) -> Any:
"""
Azure SQL implementation for filtering AttackResults by labels.
Expand Down
76 changes: 14 additions & 62 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
if TYPE_CHECKING:
from pyrit.memory.memory_embedding import MemoryEmbedding

from pyrit.common.deprecation import print_deprecation_message
from pyrit.memory.memory_models import (
AttackResultEntry,
Base,
Expand Down Expand Up @@ -589,19 +588,6 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict
update_fields (dict): A dictionary of field names and their new values.
"""

@abc.abstractmethod
def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any:
    """
    Return a database-specific condition for filtering AttackResults by targeted harm categories
    in the associated PromptMemoryEntry records.

    This method is abstract: it has no default implementation here and must be
    overridden by concrete subclasses.

    Args:
        targeted_harm_categories (Sequence[str]): List of harm categories that must ALL be present.

    Returns:
        Any: Database-specific SQLAlchemy condition.
    """

@abc.abstractmethod
def _get_attack_result_label_condition(self, *, labels: dict[str, str | Sequence[str]]) -> Any:
"""
Expand Down Expand Up @@ -822,21 +808,19 @@ def get_prompt_scores(
converted_value_sha256=converted_value_sha256,
)

# Deduplicate message pieces by original_prompt_id to avoid duplicate scores
# since duplicated pieces share scores with their originals
seen_original_ids = set()
unique_pieces = []
for piece in message_pieces:
if piece.original_prompt_id not in seen_original_ids:
seen_original_ids.add(piece.original_prompt_id)
unique_pieces.append(piece)

scores = []
for piece in unique_pieces:
if piece.scores:
scores.extend(piece.scores)
# Deduplicate by original_prompt_id since duplicated pieces share scores
# with their originals.
original_ids = {piece.original_prompt_id for piece in message_pieces if piece.original_prompt_id is not None}
if not original_ids:
return []

return list(scores)
score_entries = self._execute_batched_query(
ScoreEntry,
batch_column=ScoreEntry.prompt_request_response_id,
batch_values=list(original_ids),
other_conditions=[],
)
return [entry.get_score() for entry in score_entries]

def get_conversation(self, *, conversation_id: str) -> MutableSequence[Message]:
"""
Expand Down Expand Up @@ -1646,13 +1630,11 @@ def get_attack_results(
objective: str | None = None,
objective_sha256: Sequence[str] | None = None,
outcome: str | None = None,
attack_class: str | None = None,
attack_classes: Sequence[str] | None = None,
atomic_attack_eval_hashes: Sequence[str] | None = None,
converter_classes: Sequence[str] | None = None,
converter_classes_match: Literal["all", "any"] = "all",
has_converters: bool | None = None,
targeted_harm_categories: Sequence[str] | None = None,
labels: dict[str, str | Sequence[str]] | None = None,
identifier_filters: Sequence[IdentifierFilter] | None = None,
scenario_result_id: str | None = None,
Expand All @@ -1668,9 +1650,6 @@ def get_attack_results(
Defaults to None.
outcome (str | None, optional): The outcome to filter by (success, failure, undetermined).
Defaults to None.
attack_class (str | None, optional): Deprecated. Filter by a single exact attack
class_name in attack_identifier. Equivalent to passing ``attack_classes=[attack_class]``.
Cannot be combined with ``attack_classes``. Defaults to None.
attack_classes (Sequence[str] | None, optional): Filter by exact attack class_name in
attack_identifier. Returns attacks matching ANY of the listed class names (OR logic,
case-sensitive). An empty sequence applies no filter. Defaults to None.
Expand All @@ -1692,13 +1671,6 @@ def get_attack_results(
has_converters (bool | None, optional): Filter by converter presence.
``True`` returns only attacks that used at least one converter. ``False`` returns
only attacks that used no converters. ``None`` applies no filter. Defaults to None.
targeted_harm_categories (Sequence[str] | None, optional):
A list of targeted harm categories to filter results by.
These targeted harm categories are associated with the prompts themselves,
meaning they are harm(s) we're trying to elicit with the prompt,
not necessarily one(s) that were found in the response.
By providing a list, this means ALL categories in the list must be present.
Defaults to None.
labels (dict[str, str | Sequence[str]] | None, optional): Filter results
by attack labels. Entries are AND-combined across label names; within a
single entry, a string value is an equality match and a sequence value is
Expand All @@ -1719,26 +1691,15 @@ def get_attack_results(
Sequence[AttackResult]: A list of AttackResult objects that match the specified filters.

Raises:
ValueError: If both ``attack_class`` (deprecated) and ``attack_classes`` are provided.
ValueError: If any label key contains characters outside the allowlist
``[A-Za-z0-9_.-]+``.
"""
# Handle empty list cases
if attack_result_ids is not None and len(attack_result_ids) == 0:
return []
if objective_sha256 is not None and len(objective_sha256) == 0:
return []

if attack_class is not None and attack_classes is not None:
raise ValueError(
"Pass either `attack_class` (deprecated, singular) or `attack_classes` (plural), not both."
)
if attack_class is not None and attack_classes is None:
print_deprecation_message(
old_item="get_attack_results(attack_class=...)",
new_item="get_attack_results(attack_classes=...)",
removed_in="0.15.0",
)
attack_classes = [attack_class]

# Build non-list conditions
conditions: list[ColumnElement[bool]] = []
if conversation_id:
Expand Down Expand Up @@ -1814,15 +1775,6 @@ def get_attack_results(
)
conditions.append(not_(empty_condition) if has_converters else empty_condition)

if targeted_harm_categories:
print_deprecation_message(
old_item="get_attack_results(targeted_harm_categories=...)",
new_item="get_attack_results(labels={'harm_category': [...]})",
removed_in="0.15.0",
)
conditions.append(
self._get_attack_result_harm_category_condition(targeted_harm_categories=targeted_harm_categories)
)
if labels:
# Strip keys whose value is an empty sequence — an empty sequence means
# "no OR-candidates", and per the docstring applies no filter for that
Expand Down
Loading
Loading