Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions backend/app/core/validators/prompts/topic_relevance_llm/v1.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
{{USER_PROMPT}}

Rules:

- Use semantic meaning, not keyword matching.
Expand Down
2 changes: 0 additions & 2 deletions backend/app/core/validators/prompts/topic_relevance_llm/v2.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
{{USER_PROMPT}}

Rules:

- Use semantic meaning, not keyword matching.
Expand Down
2 changes: 0 additions & 2 deletions backend/app/core/validators/prompts/topic_relevance_llm/v3.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
{{USER_PROMPT}}

Rules:

- Use semantic meaning, not keyword matching.
Expand Down
30 changes: 9 additions & 21 deletions backend/app/core/validators/topic_relevance_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
supports_response_format,
)

# Placeholder in user-message templates marking where the user's query is injected.
_USER_PROMPT_PLACEHOLDER = "{{USER_PROMPT}}"
_PROMPTS_DIR = Path(__file__).parent / "prompts" / "topic_relevance_llm"

# Valid scope scores returned by the model; the highest means "clearly in scope".
Expand All @@ -35,7 +33,7 @@

@lru_cache(maxsize=8)
def _load_prompt_template(prompt_schema_version: int) -> str:
"""Load and cache the user-message prompt template for the given schema version."""
"""Load and cache the scoring instruction block for the given schema version."""
if prompt_schema_version < 1:
raise ValueError("prompt_schema_version must be a positive integer")

Expand All @@ -45,12 +43,7 @@ def _load_prompt_template(prompt_schema_version: int) -> str:
f"Topic relevance (LLM) prompt template for version {prompt_schema_version} not found"
)

template = prompt_file.read_text(encoding="utf-8")
if _USER_PROMPT_PLACEHOLDER not in template:
raise ValueError(
f"Prompt template v{prompt_schema_version} must contain {_USER_PROMPT_PLACEHOLDER}"
)
return template
return prompt_file.read_text(encoding="utf-8")


@register_validator(name="topic-relevance-llm", data_type="string")
Expand All @@ -59,10 +52,10 @@ class TopicRelevanceLLM(Validator):
Validates whether a user message is within the defined topic scope
using a direct LLM call via litellm.

The caller supplies the topic configuration as ``system_prompt``, which
becomes the system message. Scoring and response-format instructions are
loaded from a versioned prompt template (v1/v2/v3) and injected as the
user message alongside the query.
The caller supplies the topic configuration as ``system_prompt``. Scoring
and response-format instructions are loaded from a versioned prompt template
(v1/v2/v3) and appended to the system message. The user message contains
only the raw query.

Scores 1–3 where 3 = clearly in scope, 2 = partially related,
1 = outside scope. Passes when score >= threshold (default 2).
Expand All @@ -87,20 +80,19 @@ def __init__(
self.threshold = threshold
self._invalid_config_reason: Optional[str] = None
self._system_prompt: Optional[str] = None
self._user_message_template: Optional[str] = None
self._supports_response_format: bool = False

if not system_prompt or not system_prompt.strip():
self._invalid_config_reason = "system_prompt is blank or missing"
return

try:
self._user_message_template = _load_prompt_template(prompt_schema_version)
scoring_rules = _load_prompt_template(prompt_schema_version)
except ValueError as e:
self._invalid_config_reason = str(e)
return

self._system_prompt = system_prompt.strip()
self._system_prompt = f"{system_prompt.strip()}\n\n{scoring_rules}"
self._supports_response_format = supports_response_format(llm_callable)

def _validate(
Expand All @@ -112,16 +104,12 @@ def _validate(
if not value or not value.strip():
return FailResult(error_message=EMPTY_MESSAGE_ERROR)

user_message = self._user_message_template.replace(
_USER_PROMPT_PLACEHOLDER, value
)

try:
kwargs = {
"model": self.llm_callable,
"messages": [
{"role": "system", "content": self._system_prompt},
{"role": "user", "content": user_message},
{"role": "user", "content": value},
],
"max_tokens": _MAX_TOKENS,
}
Expand Down
32 changes: 9 additions & 23 deletions backend/app/tests/validators/test_topic_relevance_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
import pytest
from guardrails.validators import FailResult, PassResult

from app.core.validators.topic_relevance_llm import (
TopicRelevanceLLM,
_USER_PROMPT_PLACEHOLDER,
)
from app.core.validators.topic_relevance_llm import TopicRelevanceLLM

TOPIC_CONFIG = "Only answer questions about cooking and recipes."

Expand Down Expand Up @@ -234,16 +231,15 @@ def test_fails_when_score_is_boolean(validator):
# ---------------------------------------------------------------------------


def test_user_message_contains_query_not_placeholder(validator):
def test_user_message_is_exactly_the_query(validator):
query = "How do I make pasta?"
with patch("app.core.validators.topic_relevance_llm.completion") as mock_llm:
mock_llm.return_value = _make_llm_response('{"scope_violation": 3}')
validator._validate(query)

_, kwargs = mock_llm.call_args
user_message = kwargs["messages"][1]["content"]
assert query in user_message
assert _USER_PROMPT_PLACEHOLDER not in user_message
assert user_message == query


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -311,25 +307,15 @@ def test_system_prompt_contains_topic_config():
assert TOPIC_CONFIG in validator._system_prompt


def test_user_message_template_contains_json_instruction():
def test_system_prompt_contains_json_instruction():
with patch(
"app.core.validators.llm_utils.get_supported_openai_params",
return_value=[],
):
validator = TopicRelevanceLLM(system_prompt=TOPIC_CONFIG)

assert "scope_violation" in validator._user_message_template
assert "JSON" in validator._user_message_template


def test_user_message_template_contains_user_prompt_placeholder():
with patch(
"app.core.validators.llm_utils.get_supported_openai_params",
return_value=[],
):
validator = TopicRelevanceLLM(system_prompt=TOPIC_CONFIG)

assert _USER_PROMPT_PLACEHOLDER in validator._user_message_template
assert "scope_violation" in validator._system_prompt
assert "JSON" in validator._system_prompt


def test_prompt_schema_version_v2_loads_forbidden_template():
Expand All @@ -341,7 +327,7 @@ def test_prompt_schema_version_v2_loads_forbidden_template():
system_prompt=TOPIC_CONFIG, prompt_schema_version=2
)

assert "forbidden" in v2_validator._user_message_template.lower()
assert "forbidden" in v2_validator._system_prompt.lower()


def test_prompt_schema_version_v3_loads_combined_template():
Expand All @@ -353,8 +339,8 @@ def test_prompt_schema_version_v3_loads_combined_template():
system_prompt=TOPIC_CONFIG, prompt_schema_version=3
)

assert "forbidden" in v3_validator._user_message_template.lower()
assert "allowed" in v3_validator._user_message_template.lower()
assert "forbidden" in v3_validator._system_prompt.lower()
assert "allowed" in v3_validator._system_prompt.lower()


def test_invalid_prompt_schema_version_returns_fail():
Expand Down