diff --git a/backend/app/core/validators/prompts/topic_relevance_llm/v1.md b/backend/app/core/validators/prompts/topic_relevance_llm/v1.md index 036d7cc..ffabc12 100644 --- a/backend/app/core/validators/prompts/topic_relevance_llm/v1.md +++ b/backend/app/core/validators/prompts/topic_relevance_llm/v1.md @@ -1,5 +1,3 @@ -{{USER_PROMPT}} - Rules: - Use semantic meaning, not keyword matching. diff --git a/backend/app/core/validators/prompts/topic_relevance_llm/v2.md b/backend/app/core/validators/prompts/topic_relevance_llm/v2.md index 60b82d9..db56047 100644 --- a/backend/app/core/validators/prompts/topic_relevance_llm/v2.md +++ b/backend/app/core/validators/prompts/topic_relevance_llm/v2.md @@ -1,5 +1,3 @@ -{{USER_PROMPT}} - Rules: - Use semantic meaning, not keyword matching. diff --git a/backend/app/core/validators/prompts/topic_relevance_llm/v3.md b/backend/app/core/validators/prompts/topic_relevance_llm/v3.md index 59cff81..47ba77c 100644 --- a/backend/app/core/validators/prompts/topic_relevance_llm/v3.md +++ b/backend/app/core/validators/prompts/topic_relevance_llm/v3.md @@ -1,5 +1,3 @@ -{{USER_PROMPT}} - Rules: - Use semantic meaning, not keyword matching. diff --git a/backend/app/core/validators/topic_relevance_llm.py b/backend/app/core/validators/topic_relevance_llm.py index f7e44d1..69700a1 100644 --- a/backend/app/core/validators/topic_relevance_llm.py +++ b/backend/app/core/validators/topic_relevance_llm.py @@ -23,8 +23,6 @@ supports_response_format, ) -# Placeholder in user-message templates marking where the user's query is injected. -_USER_PROMPT_PLACEHOLDER = "{{USER_PROMPT}}" _PROMPTS_DIR = Path(__file__).parent / "prompts" / "topic_relevance_llm" # Valid scope scores returned by the model; the highest means "clearly in scope". @@ -35,7 +33,7 @@ @lru_cache(maxsize=8) def _load_prompt_template(prompt_schema_version: int) -> str: - """Load and cache the user-message prompt template for the given schema version.""" + """Load and cache the scoring instruction block for the given schema version.""" if prompt_schema_version < 1: raise ValueError("prompt_schema_version must be a positive integer") @@ -45,12 +43,7 @@ def _load_prompt_template(prompt_schema_version: int) -> str: f"Topic relevance (LLM) prompt template for version {prompt_schema_version} not found" ) - template = prompt_file.read_text(encoding="utf-8") - if _USER_PROMPT_PLACEHOLDER not in template: - raise ValueError( - f"Prompt template v{prompt_schema_version} must contain {_USER_PROMPT_PLACEHOLDER}" - ) - return template + return prompt_file.read_text(encoding="utf-8") @register_validator(name="topic-relevance-llm", data_type="string") @@ -59,10 +52,10 @@ class TopicRelevanceLLM(Validator): Validates whether a user message is within the defined topic scope using a direct LLM call via litellm. - The caller supplies the topic configuration as ``system_prompt``, which - becomes the system message. Scoring and response-format instructions are - loaded from a versioned prompt template (v1/v2/v3) and injected as the - user message alongside the query. + The caller supplies the topic configuration as ``system_prompt``. Scoring + and response-format instructions are loaded from a versioned prompt template + (v1/v2/v3) and appended to the system message. The user message contains + only the raw query. Scores 1–3 where 3 = clearly in scope, 2 = partially related, 1 = outside scope. Passes when score >= threshold (default 2). @@ -87,7 +80,6 @@ def __init__( self.threshold = threshold self._invalid_config_reason: Optional[str] = None self._system_prompt: Optional[str] = None - self._user_message_template: Optional[str] = None self._supports_response_format: bool = False if not system_prompt or not system_prompt.strip(): @@ -95,12 +87,12 @@ def __init__( return try: - self._user_message_template = _load_prompt_template(prompt_schema_version) + scoring_rules = _load_prompt_template(prompt_schema_version) except ValueError as e: self._invalid_config_reason = str(e) return - self._system_prompt = system_prompt.strip() + self._system_prompt = f"{system_prompt.strip()}\n\n{scoring_rules}" self._supports_response_format = supports_response_format(llm_callable) def _validate( @@ -112,16 +104,12 @@ def _validate( if not value or not value.strip(): return FailResult(error_message=EMPTY_MESSAGE_ERROR) - user_message = self._user_message_template.replace( - _USER_PROMPT_PLACEHOLDER, value - ) - try: kwargs = { "model": self.llm_callable, "messages": [ {"role": "system", "content": self._system_prompt}, - {"role": "user", "content": user_message}, + {"role": "user", "content": value}, ], "max_tokens": _MAX_TOKENS, } diff --git a/backend/app/tests/validators/test_topic_relevance_llm.py b/backend/app/tests/validators/test_topic_relevance_llm.py index 697b20b..926186e 100644 --- a/backend/app/tests/validators/test_topic_relevance_llm.py +++ b/backend/app/tests/validators/test_topic_relevance_llm.py @@ -3,10 +3,7 @@ import pytest from guardrails.validators import FailResult, PassResult -from app.core.validators.topic_relevance_llm import ( - TopicRelevanceLLM, - _USER_PROMPT_PLACEHOLDER, -) +from app.core.validators.topic_relevance_llm import TopicRelevanceLLM TOPIC_CONFIG = "Only answer questions about cooking and recipes." @@ -234,7 +231,7 @@ def test_fails_when_score_is_boolean(validator): # --------------------------------------------------------------------------- -def test_user_message_contains_query_not_placeholder(validator): +def test_user_message_is_exactly_the_query(validator): query = "How do I make pasta?" with patch("app.core.validators.topic_relevance_llm.completion") as mock_llm: mock_llm.return_value = _make_llm_response('{"scope_violation": 3}') @@ -242,8 +239,7 @@ def test_user_message_contains_query_not_placeholder(validator): _, kwargs = mock_llm.call_args user_message = kwargs["messages"][1]["content"] - assert query in user_message - assert _USER_PROMPT_PLACEHOLDER not in user_message + assert user_message == query # --------------------------------------------------------------------------- @@ -311,25 +307,15 @@ def test_system_prompt_contains_topic_config(): assert TOPIC_CONFIG in validator._system_prompt -def test_user_message_template_contains_json_instruction(): +def test_system_prompt_contains_json_instruction(): with patch( "app.core.validators.llm_utils.get_supported_openai_params", return_value=[], ): validator = TopicRelevanceLLM(system_prompt=TOPIC_CONFIG) - assert "scope_violation" in validator._user_message_template - assert "JSON" in validator._user_message_template - - -def test_user_message_template_contains_user_prompt_placeholder(): - with patch( - "app.core.validators.llm_utils.get_supported_openai_params", - return_value=[], - ): - validator = TopicRelevanceLLM(system_prompt=TOPIC_CONFIG) - - assert _USER_PROMPT_PLACEHOLDER in validator._user_message_template + assert "scope_violation" in validator._system_prompt + assert "JSON" in validator._system_prompt def test_prompt_schema_version_v2_loads_forbidden_template(): @@ -341,7 +327,7 @@ def test_prompt_schema_version_v2_loads_forbidden_template(): system_prompt=TOPIC_CONFIG, prompt_schema_version=2 ) - assert "forbidden" in v2_validator._user_message_template.lower() + assert "forbidden" in v2_validator._system_prompt.lower() def test_prompt_schema_version_v3_loads_combined_template(): @@ -353,8 +339,8 @@ def test_prompt_schema_version_v3_loads_combined_template(): system_prompt=TOPIC_CONFIG, prompt_schema_version=3 ) - assert "forbidden" in v3_validator._user_message_template.lower() - assert "allowed" in v3_validator._user_message_template.lower() + assert "forbidden" in v3_validator._system_prompt.lower() + assert "allowed" in v3_validator._system_prompt.lower() def test_invalid_prompt_schema_version_returns_fail():