From 01674716a96ebd9d527059974dc560720bb28fe5 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Mon, 1 Jun 2026 12:52:19 +0530 Subject: [PATCH 1/5] added open ai topic relevance guardrail --- backend/app/api/routes/guardrails.py | 11 +- backend/app/core/enum.py | 1 + ...elevance_openai_safety_validator_config.py | 29 ++ .../core/validators/topic_relevance_openai.py | 140 +++++++++ backend/app/core/validators/validators.json | 5 + .../evaluation/topic_relevance_openai/run.py | 127 ++++++++ backend/app/schemas/guardrail_config.py | 4 + backend/app/tests/test_llm_validators.py | 76 +++++ backend/app/tests/test_validate_with_guard.py | 73 +++++ .../validators/test_topic_relevance_openai.py | 289 ++++++++++++++++++ backend/scripts/run_all_evaluations.sh | 1 + 11 files changed, 755 insertions(+), 1 deletion(-) create mode 100644 backend/app/core/validators/config/topic_relevance_openai_safety_validator_config.py create mode 100644 backend/app/core/validators/topic_relevance_openai.py create mode 100644 backend/app/evaluation/topic_relevance_openai/run.py create mode 100644 backend/app/tests/validators/test_topic_relevance_openai.py diff --git a/backend/app/api/routes/guardrails.py b/backend/app/api/routes/guardrails.py index 7281718..4d0483b 100644 --- a/backend/app/api/routes/guardrails.py +++ b/backend/app/api/routes/guardrails.py @@ -26,6 +26,9 @@ from app.core.validators.config.topic_relevance_safety_validator_config import ( TopicRelevanceSafetyValidatorConfig, ) +from app.core.validators.config.topic_relevance_openai_safety_validator_config import ( + TopicRelevanceOpenAISafetyValidatorConfig, +) from app.schemas.guardrail_config import GuardrailRequest, GuardrailResponse from app.models.logging.request_log import RequestLogUpdate, RequestStatus from app.models.logging.validator_log import ValidatorLog, ValidatorOutcome @@ -115,7 +118,13 @@ def _resolve_validator_configs(payload: GuardrailRequest, session: Session) -> N ) validator.banned_words = ban_list.banned_words - elif isinstance(validator, TopicRelevanceSafetyValidatorConfig): + elif isinstance( + validator, + ( + TopicRelevanceSafetyValidatorConfig, + TopicRelevanceOpenAISafetyValidatorConfig, + ), + ): if validator.topic_relevance_config_id is not None: config = topic_relevance_crud.get( session=session, diff --git a/backend/app/core/enum.py b/backend/app/core/enum.py index ff653c5..70eff2d 100644 --- a/backend/app/core/enum.py +++ b/backend/app/core/enum.py @@ -32,6 +32,7 @@ class ValidatorType(Enum): GenderAssumptionBias = "gender_assumption_bias" BanList = "ban_list" TopicRelevance = "topic_relevance" + TopicRelevanceOpenAI = "topic_relevance_openai" LLMCritic = "llm_critic" LlamaGuard7B = "llamaguard_7b" ProfanityFree = "profanity_free" diff --git a/backend/app/core/validators/config/topic_relevance_openai_safety_validator_config.py b/backend/app/core/validators/config/topic_relevance_openai_safety_validator_config.py new file mode 100644 index 0000000..73c6af7 --- /dev/null +++ b/backend/app/core/validators/config/topic_relevance_openai_safety_validator_config.py @@ -0,0 +1,29 @@ +from typing import Literal, Optional +from uuid import UUID + +from app.core.validators.topic_relevance_openai import TopicRelevanceOpenAI +from app.core.validators.config.base_validator_config import BaseValidatorConfig +from app.core.config import settings + + +class TopicRelevanceOpenAISafetyValidatorConfig(BaseValidatorConfig): + type: Literal["topic_relevance_openai"] + configuration: Optional[str] = None + prompt_schema_version: Optional[int] = None + llm_callable: str = "gpt-4o-mini" + threshold: int = 2 + topic_relevance_config_id: Optional[UUID] = None + + def build(self): + if not settings.OPENAI_API_KEY: + raise ValueError( + "OPENAI_API_KEY is not configured. " + "Topic relevance (OpenAI) validation requires an OpenAI API key." + ) + return TopicRelevanceOpenAI( + topic_config=self.configuration or " ", + prompt_schema_version=self.prompt_schema_version or 1, + llm_callable=self.llm_callable, + threshold=self.threshold, + on_fail=self.resolve_on_fail(), + ) diff --git a/backend/app/core/validators/topic_relevance_openai.py b/backend/app/core/validators/topic_relevance_openai.py new file mode 100644 index 0000000..c53a2cd --- /dev/null +++ b/backend/app/core/validators/topic_relevance_openai.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import json +from functools import lru_cache +from pathlib import Path +from typing import Callable, Optional + +from litellm import completion, get_supported_openai_params +from guardrails import OnFailAction +from guardrails.validators import ( + FailResult, + PassResult, + ValidationResult, + Validator, + register_validator, +) + + +_PROMPT_PLACEHOLDER = "{{TOPIC_CONFIGURATION}}" +_PROMPTS_DIR = Path(__file__).parent / "prompts" / "topic_relevance" + +_JSON_INSTRUCTION = ( + "\n\nRespond ONLY with a JSON object in this exact format: " + '{"scope_violation": } where is the integer 1, 2, or 3.' +) + + +@lru_cache(maxsize=8) +def _load_prompt_template(prompt_schema_version: int) -> str: + if prompt_schema_version < 1: + raise ValueError("prompt_schema_version must be a positive integer") + prompt_file = _PROMPTS_DIR / f"v{prompt_schema_version}.md" + if not prompt_file.exists(): + raise ValueError( + f"Topic relevance prompt template for version {prompt_schema_version} not found" + ) + template = prompt_file.read_text(encoding="utf-8") + if _PROMPT_PLACEHOLDER not in template: + raise ValueError( + f"Prompt template v{prompt_schema_version} must contain {_PROMPT_PLACEHOLDER}" + ) + return template + + +def _build_system_prompt(prompt_schema_version: int, topic_config: str) -> str: + scope_text = topic_config.strip() + if not scope_text: + raise ValueError("topic_config cannot be empty") + template = _load_prompt_template(prompt_schema_version) + return template.replace(_PROMPT_PLACEHOLDER, scope_text) + _JSON_INSTRUCTION + + +@register_validator(name="topic-relevance-openai", data_type="string") +class TopicRelevanceOpenAI(Validator): + """ + Validates whether a user message is within the defined topic scope + using a direct OpenAI/litellm call. + + Scores 1–3 where 3 = clearly in scope, 2 = ambiguous, 1 = outside scope. + Passes when score >= threshold (default 2). + """ + + def __init__( + self, + topic_config: str, + prompt_schema_version: int = 1, + llm_callable: str = "gpt-4o-mini", + threshold: int = 2, + on_fail: Optional[Callable] = OnFailAction.NOOP, + ): + super().__init__(on_fail=on_fail) + + self.topic_config = topic_config + self.prompt_schema_version = prompt_schema_version + self.llm_callable = llm_callable + self.threshold = threshold + self._invalid_config_reason: Optional[str] = None + self._system_prompt: Optional[str] = None + self._supports_response_format: bool = False + + if not topic_config or not topic_config.strip(): + self._invalid_config_reason = "topic_config is blank or missing" + return + + try: + self._system_prompt = _build_system_prompt( + prompt_schema_version, topic_config + ) + except ValueError as e: + self._invalid_config_reason = str(e) + return + + try: + self._supports_response_format = "response_format" in ( + get_supported_openai_params(model=llm_callable) or [] + ) + except Exception: + self._supports_response_format = False + + def _validate(self, value: str, metadata: dict = None) -> ValidationResult: + if self._invalid_config_reason: + return FailResult(error_message=self._invalid_config_reason) + + if not value or not value.strip(): + return FailResult(error_message="Empty message.") + + try: + kwargs = dict( + model=self.llm_callable, + messages=[ + {"role": "system", "content": self._system_prompt}, + {"role": "user", "content": value}, + ], + max_tokens=50, + ) + if self._supports_response_format: + kwargs["response_format"] = {"type": "json_object"} + + response = completion(**kwargs) + content = response.choices[0].message.content.strip() + except Exception as e: + return FailResult(error_message=f"LLM call failed: {e}") + + try: + data = json.loads(content) + score = data.get("scope_violation") + if not isinstance(score, int) or score not in (1, 2, 3): + raise ValueError(f"unexpected score value: {score!r}") + except Exception as e: + return FailResult( + error_message=f"LLM returned an unparseable response: {e}. Raw: {content!r}" + ) + + if score >= self.threshold: + return PassResult(value=value, metadata={"scope_score": score}) + + return FailResult( + error_message="Input is outside the allowed topic scope.", + metadata={"scope_score": score}, + ) diff --git a/backend/app/core/validators/validators.json b/backend/app/core/validators/validators.json index c6c0fd6..9823bf4 100644 --- a/backend/app/core/validators/validators.json +++ b/backend/app/core/validators/validators.json @@ -30,6 +30,11 @@ "version": "0.1.0", "source": "local" }, + { + "type": "topic_relevance_openai", + "version": "0.1.0", + "source": "local" + }, { "type": "llamaguard_7b", "version": "0.1.0", diff --git a/backend/app/evaluation/topic_relevance_openai/run.py b/backend/app/evaluation/topic_relevance_openai/run.py new file mode 100644 index 0000000..8b2c996 --- /dev/null +++ b/backend/app/evaluation/topic_relevance_openai/run.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from pathlib import Path + +import pandas as pd +from guardrails.validators import FailResult + +from app.core.validators.topic_relevance_openai import TopicRelevanceOpenAI +from app.evaluation.common.helper import ( + Profiler, + build_evaluation_report, + compute_binary_metrics, + write_csv, + write_json, +) + +BASE_DIR = Path(__file__).resolve().parent.parent +DATASETS_DIR = BASE_DIR / "datasets" / "topic_relevance" +OUT_DIR = BASE_DIR / "outputs" / "topic_relevance_openai" + +DEFAULT_CONFIG = { + "llm_callable": "gpt-4o-mini", + "prompt_schema_version": 3, + "threshold": 2, +} + +EVALUATIONS = [ + { + "domain": "education", + "dataset": "education-topic-relevance-dataset.csv", + "topic_config": "education_topic_config.txt", + }, + { + "domain": "healthcare", + "dataset": "healthcare-topic-relevance-dataset.csv", + "topic_config": "healthcare_topic_config.txt", + }, +] + + +def run_evaluation(config: dict) -> None: + """ + Run the topic relevance (OpenAI) evaluation for a single domain config. + Loads dataset and topic config, runs each input through TopicRelevanceOpenAI, + computes binary and per-category metrics, and writes CSV and JSON to outputs. + """ + domain = config["domain"] + + dataset_path = DATASETS_DIR / config["dataset"] + topic_config_path = DATASETS_DIR / config["topic_config"] + topic_config = topic_config_path.read_text() + + print(f"\nRunning topic relevance (OpenAI) evaluation: {domain}") + + df = pd.read_csv(dataset_path) + + validator = TopicRelevanceOpenAI( + topic_config=topic_config, + prompt_schema_version=DEFAULT_CONFIG["prompt_schema_version"], + llm_callable=DEFAULT_CONFIG["llm_callable"], + threshold=DEFAULT_CONFIG["threshold"], + ) + + normalized_df = pd.DataFrame( + { + "input": df["input"].astype(str), + "category": df["category"].astype(str), + "in_scope": df["scope"].apply(lambda x: 1 if x == "IN_SCOPE" else 0), + } + ) + + normalized_df["y_true"] = (1 - normalized_df["in_scope"]).astype(int) + + with Profiler() as p: + results = normalized_df["input"].apply( + lambda x: p.record(lambda t: validator.validate(t, metadata=None), x) + ) + + normalized_df["y_pred"] = results.apply(lambda r: int(isinstance(r, FailResult))) + normalized_df["scope_score"] = results.apply( + lambda r: r.metadata.get("scope_score") + if getattr(r, "metadata", None) + else None + ) + normalized_df["error_message"] = results.apply( + lambda r: r.error_message if isinstance(r, FailResult) else "" + ) + + metrics = compute_binary_metrics(normalized_df["y_true"], normalized_df["y_pred"]) + + metrics["category_metrics"] = { + str(cat): { + "num_samples": int(len(g)), + **compute_binary_metrics(g["y_true"], g["y_pred"]), + } + for cat, g in normalized_df.groupby("category", dropna=False) + } + + OUT_DIR.mkdir(parents=True, exist_ok=True) + + write_csv(normalized_df, OUT_DIR / f"{domain}-predictions.csv") + + write_json( + build_evaluation_report( + guardrail="topic_relevance_openai", + num_samples=len(normalized_df), + profiler=p, + dataset=str(dataset_path), + llm_callable=DEFAULT_CONFIG["llm_callable"], + prompt_schema_version=DEFAULT_CONFIG["prompt_schema_version"], + threshold=DEFAULT_CONFIG["threshold"], + metrics=metrics, + ), + OUT_DIR / f"{domain}-metrics.json", + ) + + print(f"Completed {domain} evaluation") + + +def main() -> None: + """Iterate over all entries in EVALUATIONS and run each domain evaluation in sequence.""" + for config in EVALUATIONS: + run_evaluation(config) + + +if __name__ == "__main__": + main() diff --git a/backend/app/schemas/guardrail_config.py b/backend/app/schemas/guardrail_config.py index 968c260..ab179ad 100644 --- a/backend/app/schemas/guardrail_config.py +++ b/backend/app/schemas/guardrail_config.py @@ -24,6 +24,9 @@ from app.core.validators.config.topic_relevance_safety_validator_config import ( TopicRelevanceSafetyValidatorConfig, ) +from app.core.validators.config.topic_relevance_openai_safety_validator_config import ( + TopicRelevanceOpenAISafetyValidatorConfig, +) from app.core.validators.config.llamaguard_7b_safety_validator_config import ( LlamaGuard7BSafetyValidatorConfig, ) @@ -45,6 +48,7 @@ NSFWTextSafetyValidatorConfig, ProfanityFreeSafetyValidatorConfig, TopicRelevanceSafetyValidatorConfig, + TopicRelevanceOpenAISafetyValidatorConfig, ], Field(discriminator="type"), ] diff --git a/backend/app/tests/test_llm_validators.py b/backend/app/tests/test_llm_validators.py index 5834843..9cec0b0 100644 --- a/backend/app/tests/test_llm_validators.py +++ b/backend/app/tests/test_llm_validators.py @@ -6,6 +6,9 @@ from app.core.validators.config.topic_relevance_safety_validator_config import ( TopicRelevanceSafetyValidatorConfig, ) +from app.core.validators.config.topic_relevance_openai_safety_validator_config import ( + TopicRelevanceOpenAISafetyValidatorConfig, +) from app.core.validators.config.llm_critic_safety_validator_config import ( LLMCriticSafetyValidatorConfig, ) @@ -61,6 +64,79 @@ def test_topic_relevance_blank_config_returns_fail_result(): assert "blank" in result.error_message +_SAMPLE_OPENAI_TOPIC_CONFIG = dict( + type="topic_relevance_openai", + configuration="Only answer about cooking.", + llm_callable="gpt-4o-mini", +) + +_TOPIC_RELEVANCE_OPENAI_SETTINGS_PATH = ( + "app.core.validators.config.topic_relevance_openai_safety_validator_config.settings" +) + + +def test_topic_relevance_openai_build_raises_when_openai_key_missing(): + config = TopicRelevanceOpenAISafetyValidatorConfig(**_SAMPLE_OPENAI_TOPIC_CONFIG) + + with patch(_TOPIC_RELEVANCE_OPENAI_SETTINGS_PATH) as mock_settings: + mock_settings.OPENAI_API_KEY = None + + with pytest.raises(ValueError) as exc: + config.build() + + assert "OPENAI_API_KEY" in str(exc.value) + assert "not configured" in str(exc.value) + + +def test_topic_relevance_openai_build_proceeds_when_openai_key_present(): + config = TopicRelevanceOpenAISafetyValidatorConfig(**_SAMPLE_OPENAI_TOPIC_CONFIG) + + with patch(_TOPIC_RELEVANCE_OPENAI_SETTINGS_PATH) as mock_settings, patch( + "app.core.validators.config.topic_relevance_openai_safety_validator_config.TopicRelevanceOpenAI" + ) as mock_validator: + mock_settings.OPENAI_API_KEY = "sk-test-key" + config.build() + + mock_validator.assert_called_once() + + +def test_topic_relevance_openai_blank_config_returns_fail_result(): + config = TopicRelevanceOpenAISafetyValidatorConfig( + **{**_SAMPLE_OPENAI_TOPIC_CONFIG, "configuration": None} + ) + + with patch(_TOPIC_RELEVANCE_OPENAI_SETTINGS_PATH) as mock_settings, patch( + "app.core.validators.topic_relevance_openai.get_supported_openai_params", + return_value=[], + ): + mock_settings.OPENAI_API_KEY = "sk-test-key" + validator = config.build() + + result = validator._validate("some input") + assert isinstance(result, FailResult) + assert "blank" in result.error_message + + +def test_topic_relevance_openai_default_threshold_is_2(): + config = TopicRelevanceOpenAISafetyValidatorConfig(**_SAMPLE_OPENAI_TOPIC_CONFIG) + assert config.threshold == 2 + + +def test_topic_relevance_openai_custom_threshold_forwarded_to_validator(): + config = TopicRelevanceOpenAISafetyValidatorConfig( + **{**_SAMPLE_OPENAI_TOPIC_CONFIG, "threshold": 3} + ) + + with patch(_TOPIC_RELEVANCE_OPENAI_SETTINGS_PATH) as mock_settings, patch( + "app.core.validators.config.topic_relevance_openai_safety_validator_config.TopicRelevanceOpenAI" + ) as mock_validator: + mock_settings.OPENAI_API_KEY = "sk-test-key" + config.build() + + call_kwargs = mock_validator.call_args[1] + assert call_kwargs["threshold"] == 3 + + _SAMPLE_CONFIG = dict( type="llm_critic", metrics={ diff --git a/backend/app/tests/test_validate_with_guard.py b/backend/app/tests/test_validate_with_guard.py index 2956512..ec91791 100644 --- a/backend/app/tests/test_validate_with_guard.py +++ b/backend/app/tests/test_validate_with_guard.py @@ -270,6 +270,79 @@ def test_resolve_validator_configs_uses_inline_topic_relevance_without_lookup(): mock_get.assert_not_called() +def test_resolve_validator_configs_topic_relevance_openai_from_config_id(): + topic_relevance_id = str(uuid4()) + payload = GuardrailRequest( + request_id=str(uuid4()), + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + input="test", + validators=[ + { + "type": "topic_relevance_openai", + "topic_relevance_config_id": topic_relevance_id, + } + ], + ) + mock_session = MagicMock() + + with patch("app.api.routes.guardrails.topic_relevance_crud.get") as mock_get: + mock_get.return_value = MagicMock( + configuration="Healthcare topic scope text", + prompt_schema_version=3, + ) + _resolve_validator_configs(payload, mock_session) + + validator = payload.validators[0] + assert validator.configuration == "Healthcare topic scope text" + assert validator.prompt_schema_version == 3 + mock_get.assert_called_once_with( + session=mock_session, + id=validator.topic_relevance_config_id, + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + ) + + +def test_resolve_validator_configs_skips_topic_relevance_openai_lookup_when_no_config_id(): + payload = GuardrailRequest( + request_id=str(uuid4()), + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + input="test", + validators=[{"type": "topic_relevance_openai"}], + ) + mock_session = MagicMock() + + with patch("app.api.routes.guardrails.topic_relevance_crud.get") as mock_get: + _resolve_validator_configs(payload, mock_session) + + mock_get.assert_not_called() + + +def test_resolve_validator_configs_uses_inline_topic_relevance_openai_without_lookup(): + payload = GuardrailRequest( + request_id=str(uuid4()), + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + input="test", + validators=[ + { + "type": "topic_relevance_openai", + "configuration": "inline openai config", + } + ], + ) + mock_session = MagicMock() + + with patch("app.api.routes.guardrails.topic_relevance_crud.get") as mock_get: + _resolve_validator_configs(payload, mock_session) + + validator = payload.validators[0] + assert validator.configuration == "inline openai config" + mock_get.assert_not_called() + + def _build_mock_guard_with_fail_result(validator_name: str, error_message: str): mock_log = MagicMock() mock_log.validator_name = validator_name diff --git a/backend/app/tests/validators/test_topic_relevance_openai.py b/backend/app/tests/validators/test_topic_relevance_openai.py new file mode 100644 index 0000000..93646ce --- /dev/null +++ b/backend/app/tests/validators/test_topic_relevance_openai.py @@ -0,0 +1,289 @@ +from unittest.mock import MagicMock, patch + +import pytest +from guardrails.validators import FailResult, PassResult + +from app.core.validators.topic_relevance_openai import TopicRelevanceOpenAI + +TOPIC_CONFIG = "Only answer questions about cooking and recipes." + + +def _make_llm_response(json_text: str) -> MagicMock: + choice = MagicMock() + choice.message.content = json_text + result = MagicMock() + result.choices = [choice] + return result + + +@pytest.fixture +def validator(): + with patch( + "app.core.validators.topic_relevance_openai.get_supported_openai_params", + return_value=["response_format"], + ): + return TopicRelevanceOpenAI(topic_config=TOPIC_CONFIG) + + +# --------------------------------------------------------------------------- +# PassResult — score >= threshold (2) +# --------------------------------------------------------------------------- + + +def test_passes_when_score_is_3(validator): + with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: + mock_llm.return_value = _make_llm_response('{"scope_violation": 3}') + result = validator._validate("How do I make pasta?") + + assert isinstance(result, PassResult) + assert result.metadata["scope_score"] == 3 + + +def test_passes_when_score_equals_threshold(validator): + with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: + mock_llm.return_value = _make_llm_response('{"scope_violation": 2}') + result = validator._validate("What is cooking roughly about?") + + assert isinstance(result, PassResult) + assert result.metadata["scope_score"] == 2 + + +# --------------------------------------------------------------------------- +# FailResult — score < threshold (1) +# --------------------------------------------------------------------------- + + +def test_fails_when_score_is_1(validator): + with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: + mock_llm.return_value = _make_llm_response('{"scope_violation": 1}') + result = validator._validate("What is the latest cricket score?") + + assert isinstance(result, FailResult) + assert "outside the allowed topic scope" in result.error_message + assert result.metadata["scope_score"] == 1 + + +# --------------------------------------------------------------------------- +# Custom threshold +# --------------------------------------------------------------------------- + + +def test_custom_threshold_of_3_fails_on_score_2(): + with patch( + "app.core.validators.topic_relevance_openai.get_supported_openai_params", + return_value=[], + ): + v = TopicRelevanceOpenAI(topic_config=TOPIC_CONFIG, threshold=3) + + with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: + mock_llm.return_value = _make_llm_response('{"scope_violation": 2}') + result = v._validate("Something vaguely food related") + + assert isinstance(result, FailResult) + assert result.metadata["scope_score"] == 2 + + +def test_custom_threshold_of_1_passes_on_score_1(): + with patch( + "app.core.validators.topic_relevance_openai.get_supported_openai_params", + return_value=[], + ): + v = TopicRelevanceOpenAI(topic_config=TOPIC_CONFIG, threshold=1) + + with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: + mock_llm.return_value = _make_llm_response('{"scope_violation": 1}') + result = v._validate("Cricket scores") + + assert isinstance(result, PassResult) + + +# --------------------------------------------------------------------------- +# Guard inputs +# --------------------------------------------------------------------------- + + +def test_fails_when_value_is_empty(validator): + result = validator._validate("") + + assert isinstance(result, FailResult) + assert "Empty message" in result.error_message + + +def test_fails_when_value_is_whitespace(validator): + result = validator._validate(" ") + + assert isinstance(result, FailResult) + assert "Empty message" in result.error_message + + +def test_fails_when_topic_config_is_blank(): + with patch( + "app.core.validators.topic_relevance_openai.get_supported_openai_params", + return_value=[], + ): + v = TopicRelevanceOpenAI(topic_config="") + + result = v._validate("Some input") + + assert isinstance(result, FailResult) + assert "blank" in result.error_message + + +def test_fails_when_topic_config_is_whitespace_only(): + with patch( + "app.core.validators.topic_relevance_openai.get_supported_openai_params", + return_value=[], + ): + v = TopicRelevanceOpenAI(topic_config=" ") + + result = v._validate("Some input") + + assert isinstance(result, FailResult) + assert "blank" in result.error_message + + +# --------------------------------------------------------------------------- +# Prompt version +# --------------------------------------------------------------------------- + + +def test_invalid_prompt_version_causes_fail_on_validate(): + with patch( + "app.core.validators.topic_relevance_openai.get_supported_openai_params", + return_value=[], + ): + v = TopicRelevanceOpenAI(topic_config=TOPIC_CONFIG, prompt_schema_version=999) + + result = v._validate("Some input") + + assert isinstance(result, FailResult) + assert "999" in result.error_message + + +# --------------------------------------------------------------------------- +# LLM error handling +# --------------------------------------------------------------------------- + + +def test_fails_gracefully_when_llm_raises(validator): + with patch( + "app.core.validators.topic_relevance_openai.completion", + side_effect=Exception("network timeout"), + ): + result = validator._validate("How do I bake bread?") + + assert isinstance(result, FailResult) + assert "LLM call failed" in result.error_message + assert "network timeout" in result.error_message + + +def test_fails_gracefully_when_response_is_not_json(validator): + with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: + mock_llm.return_value = _make_llm_response("Sure, this looks great!") + result = validator._validate("How do I bake bread?") + + assert isinstance(result, FailResult) + assert "unparseable" in result.error_message + + +def test_fails_gracefully_when_score_key_is_missing(validator): + with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: + mock_llm.return_value = _make_llm_response('{"result": "yes"}') + result = validator._validate("How do I bake bread?") + + assert isinstance(result, FailResult) + assert "unparseable" in result.error_message + + +def test_fails_gracefully_when_score_is_out_of_range(validator): + with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: + mock_llm.return_value = _make_llm_response('{"scope_violation": 5}') + result = validator._validate("How do I bake bread?") + + assert isinstance(result, FailResult) + assert "unparseable" in result.error_message + + +def test_fails_gracefully_when_score_is_a_string(validator): + with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: + mock_llm.return_value = _make_llm_response('{"scope_violation": "high"}') + result = validator._validate("How do I bake bread?") + + assert isinstance(result, FailResult) + assert "unparseable" in result.error_message + + +# --------------------------------------------------------------------------- +# response_format forwarding +# --------------------------------------------------------------------------- + + +def test_response_format_passed_when_supported(): + with patch( + "app.core.validators.topic_relevance_openai.get_supported_openai_params", + return_value=["response_format"], + ): + v = TopicRelevanceOpenAI(topic_config=TOPIC_CONFIG) + + with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: + mock_llm.return_value = _make_llm_response('{"scope_violation": 3}') + v._validate("How do I make pasta?") + + _, kwargs = mock_llm.call_args + assert kwargs.get("response_format") == {"type": "json_object"} + + +def test_response_format_omitted_when_not_supported(): + with patch( + "app.core.validators.topic_relevance_openai.get_supported_openai_params", + return_value=[], + ): + v = TopicRelevanceOpenAI(topic_config=TOPIC_CONFIG) + + with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: + mock_llm.return_value = _make_llm_response('{"scope_violation": 3}') + v._validate("How do I make pasta?") + + _, kwargs = mock_llm.call_args + assert "response_format" not in kwargs + + +def test_response_format_omitted_when_litellm_check_fails(): + with patch( + "app.core.validators.topic_relevance_openai.get_supported_openai_params", + side_effect=Exception("litellm unavailable"), + ): + v = TopicRelevanceOpenAI(topic_config=TOPIC_CONFIG) + + with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: + mock_llm.return_value = _make_llm_response('{"scope_violation": 3}') + v._validate("How do I make pasta?") + + _, kwargs = mock_llm.call_args + assert "response_format" not in kwargs + + +# --------------------------------------------------------------------------- +# Prompt template content +# --------------------------------------------------------------------------- + + +def test_system_prompt_contains_topic_config(): + with patch( + "app.core.validators.topic_relevance_openai.get_supported_openai_params", + return_value=[], + ): + v = TopicRelevanceOpenAI(topic_config=TOPIC_CONFIG) + + assert TOPIC_CONFIG in v._system_prompt + + +def test_system_prompt_contains_json_instruction(): + with patch( + "app.core.validators.topic_relevance_openai.get_supported_openai_params", + return_value=[], + ): + v = TopicRelevanceOpenAI(topic_config=TOPIC_CONFIG) + + assert "scope_violation" in v._system_prompt + assert "JSON" in v._system_prompt diff --git a/backend/scripts/run_all_evaluations.sh b/backend/scripts/run_all_evaluations.sh index 0da2402..d344c91 100755 --- a/backend/scripts/run_all_evaluations.sh +++ b/backend/scripts/run_all_evaluations.sh @@ -11,6 +11,7 @@ RUNNERS=( "$EVAL_DIR/gender_assumption_bias/run.py" "$EVAL_DIR/ban_list/run.py" "$EVAL_DIR/topic_relevance/run.py" + "$EVAL_DIR/topic_relevance_openai/run.py" "$EVAL_DIR/toxicity/run.py" ) From 0394bfcac1464e0a6e5b316a4e5efd2d412cd6da Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Tue, 2 Jun 2026 16:39:52 +0530 Subject: [PATCH 2/5] updates --- backend/app/api/routes/guardrails.py | 19 ++- backend/app/core/config.py | 1 + ...elevance_openai_safety_validator_config.py | 6 +- ...topic_relevance_safety_validator_config.py | 2 +- .../app/core/validators/topic_relevance.py | 4 +- .../core/validators/topic_relevance_openai.py | 60 +++------ backend/app/evaluation/topic_relevance/run.py | 89 ++++++------ .../evaluation/topic_relevance_openai/run.py | 127 ------------------ backend/app/tests/test_validate_with_guard.py | 2 - .../validators/test_topic_relevance_openai.py | 42 ++---- 10 files changed, 94 insertions(+), 258 deletions(-) delete mode 100644 backend/app/evaluation/topic_relevance_openai/run.py diff --git a/backend/app/api/routes/guardrails.py b/backend/app/api/routes/guardrails.py index 4d0483b..4650141 100644 --- a/backend/app/api/routes/guardrails.py +++ b/backend/app/api/routes/guardrails.py @@ -106,6 +106,7 @@ def _resolve_validator_configs(payload: GuardrailRequest, session: Session) -> N Resolves config-backed references for all validators in-place before guard execution: - BanList: fetches banned_words from the stored BanList when not provided inline. - TopicRelevance: fetches configuration and prompt_schema_version from stored config. + - TopicRelevanceOpenAI: fetches configuration from stored config. """ for validator in payload.validators: if isinstance(validator, BanListSafetyValidatorConfig): @@ -118,13 +119,7 @@ def _resolve_validator_configs(payload: GuardrailRequest, session: Session) -> N ) validator.banned_words = ban_list.banned_words - elif isinstance( - validator, - ( - TopicRelevanceSafetyValidatorConfig, - TopicRelevanceOpenAISafetyValidatorConfig, - ), - ): + elif isinstance(validator, TopicRelevanceSafetyValidatorConfig): if validator.topic_relevance_config_id is not None: config = topic_relevance_crud.get( session=session, @@ -135,6 +130,16 @@ def _resolve_validator_configs(payload: GuardrailRequest, session: Session) -> N validator.configuration = config.configuration validator.prompt_schema_version = config.prompt_schema_version + elif isinstance(validator, TopicRelevanceOpenAISafetyValidatorConfig): + if validator.topic_relevance_config_id is not None: + config = topic_relevance_crud.get( + session=session, + id=validator.topic_relevance_config_id, + organization_id=payload.organization_id, + project_id=payload.project_id, + ) + validator.configuration = config.configuration + def _validate_with_guard( payload: GuardrailRequest, diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 6d4ae94..85800d1 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -45,6 +45,7 @@ class Settings(BaseSettings): KAAPI_AUTH_TIMEOUT: int CORE_DIR: ClassVar[Path] = Path(__file__).resolve().parent OPENAI_API_KEY: str | None = None + DEFAULT_LLM_CALLABLE: str = "gpt-4o-mini" SLUR_LIST_FILENAME: ClassVar[str] = "curated_slurlist_hi_en.csv" diff --git a/backend/app/core/validators/config/topic_relevance_openai_safety_validator_config.py b/backend/app/core/validators/config/topic_relevance_openai_safety_validator_config.py index 73c6af7..875ecc1 100644 --- a/backend/app/core/validators/config/topic_relevance_openai_safety_validator_config.py +++ b/backend/app/core/validators/config/topic_relevance_openai_safety_validator_config.py @@ -9,8 +9,7 @@ class TopicRelevanceOpenAISafetyValidatorConfig(BaseValidatorConfig): type: Literal["topic_relevance_openai"] configuration: Optional[str] = None - prompt_schema_version: Optional[int] = None - llm_callable: str = "gpt-4o-mini" + llm_callable: str = settings.DEFAULT_LLM_CALLABLE threshold: int = 2 topic_relevance_config_id: Optional[UUID] = None @@ -21,8 +20,7 @@ def build(self): "Topic relevance (OpenAI) validation requires an OpenAI API key." ) return TopicRelevanceOpenAI( - topic_config=self.configuration or " ", - prompt_schema_version=self.prompt_schema_version or 1, + system_prompt=self.configuration or "", llm_callable=self.llm_callable, threshold=self.threshold, on_fail=self.resolve_on_fail(), diff --git a/backend/app/core/validators/config/topic_relevance_safety_validator_config.py b/backend/app/core/validators/config/topic_relevance_safety_validator_config.py index 53023a9..4dff8d3 100644 --- a/backend/app/core/validators/config/topic_relevance_safety_validator_config.py +++ b/backend/app/core/validators/config/topic_relevance_safety_validator_config.py @@ -12,7 +12,7 @@ class TopicRelevanceSafetyValidatorConfig(BaseValidatorConfig): type: Literal["topic_relevance"] configuration: Optional[str] = None prompt_schema_version: Optional[int] = None - llm_callable: str = "gpt-4o-mini" + llm_callable: str = settings.DEFAULT_LLM_CALLABLE topic_relevance_config_id: Optional[UUID] = None def build(self): diff --git a/backend/app/core/validators/topic_relevance.py b/backend/app/core/validators/topic_relevance.py index 22d2bcc..7241972 100644 --- a/backend/app/core/validators/topic_relevance.py +++ b/backend/app/core/validators/topic_relevance.py @@ -13,6 +13,8 @@ ) from guardrails.validators import FailResult, PassResult +from app.core.config import settings + # This should be present in all prompt templates to indicate where the topic configuration will be inserted _PROMPT_PLACEHOLDER = "{{TOPIC_CONFIGURATION}}" @@ -62,7 +64,7 @@ def __init__( self, topic_config: str, prompt_schema_version: int = 1, - llm_callable: str = "gpt-4o-mini", + llm_callable: str = settings.DEFAULT_LLM_CALLABLE, on_fail: Optional[Callable] = OnFailAction.NOOP, ): """Build the LLMCritic with a scope_violation metric from the topic configuration.""" diff --git a/backend/app/core/validators/topic_relevance_openai.py b/backend/app/core/validators/topic_relevance_openai.py index c53a2cd..5cf2802 100644 --- a/backend/app/core/validators/topic_relevance_openai.py +++ b/backend/app/core/validators/topic_relevance_openai.py @@ -1,8 +1,6 @@ from __future__ import annotations import json -from functools import lru_cache -from pathlib import Path from typing import Callable, Optional from litellm import completion, get_supported_openai_params @@ -15,80 +13,52 @@ register_validator, ) +from app.core.config import settings -_PROMPT_PLACEHOLDER = "{{TOPIC_CONFIGURATION}}" -_PROMPTS_DIR = Path(__file__).parent / "prompts" / "topic_relevance" -_JSON_INSTRUCTION = ( - "\n\nRespond ONLY with a JSON object in this exact format: " +_SCORING_INSTRUCTIONS = ( + "\n\nScore using:\n" + "3 = clearly within scope (directly matches a topic description)\n" + "2 = partially related (tangentially related or implicitly within scope)\n" + "1 = clearly outside scope (no relation to any listed topic)\n" + "\nRespond ONLY with a JSON object in this exact format: " '{"scope_violation": } where is the integer 1, 2, or 3.' ) -@lru_cache(maxsize=8) -def _load_prompt_template(prompt_schema_version: int) -> str: - if prompt_schema_version < 1: - raise ValueError("prompt_schema_version must be a positive integer") - prompt_file = _PROMPTS_DIR / f"v{prompt_schema_version}.md" - if not prompt_file.exists(): - raise ValueError( - f"Topic relevance prompt template for version {prompt_schema_version} not found" - ) - template = prompt_file.read_text(encoding="utf-8") - if _PROMPT_PLACEHOLDER not in template: - raise ValueError( - f"Prompt template v{prompt_schema_version} must contain {_PROMPT_PLACEHOLDER}" - ) - return template - - -def _build_system_prompt(prompt_schema_version: int, topic_config: str) -> str: - scope_text = topic_config.strip() - if not scope_text: - raise ValueError("topic_config cannot be empty") - template = _load_prompt_template(prompt_schema_version) - return template.replace(_PROMPT_PLACEHOLDER, scope_text) + _JSON_INSTRUCTION - - @register_validator(name="topic-relevance-openai", data_type="string") class TopicRelevanceOpenAI(Validator): """ Validates whether a user message is within the defined topic scope using a direct OpenAI/litellm call. + The caller supplies the full system prompt. The validator appends + hardcoded scoring and response-format instructions. + Scores 1–3 where 3 = clearly in scope, 2 = ambiguous, 1 = outside scope. Passes when score >= threshold (default 2). """ def __init__( self, - topic_config: str, - prompt_schema_version: int = 1, - llm_callable: str = "gpt-4o-mini", + system_prompt: str, + llm_callable: str = settings.DEFAULT_LLM_CALLABLE, threshold: int = 2, on_fail: Optional[Callable] = OnFailAction.NOOP, ): super().__init__(on_fail=on_fail) - self.topic_config = topic_config - self.prompt_schema_version = prompt_schema_version self.llm_callable = llm_callable self.threshold = threshold self._invalid_config_reason: Optional[str] = None self._system_prompt: Optional[str] = None self._supports_response_format: bool = False - if not topic_config or not topic_config.strip(): - self._invalid_config_reason = "topic_config is blank or missing" + if not system_prompt or not system_prompt.strip(): + self._invalid_config_reason = "system_prompt is blank or missing" return - try: - self._system_prompt = _build_system_prompt( - prompt_schema_version, topic_config - ) - except ValueError as e: - self._invalid_config_reason = str(e) - return + self._system_prompt = system_prompt.strip() + _SCORING_INSTRUCTIONS try: self._supports_response_format = "response_format" in ( diff --git a/backend/app/evaluation/topic_relevance/run.py b/backend/app/evaluation/topic_relevance/run.py index f1be3cc..15e3b94 100644 --- a/backend/app/evaluation/topic_relevance/run.py +++ b/backend/app/evaluation/topic_relevance/run.py @@ -5,7 +5,9 @@ import pandas as pd from guardrails.validators import FailResult +from app.core.config import settings from app.core.validators.topic_relevance import TopicRelevance +from app.core.validators.topic_relevance_openai import TopicRelevanceOpenAI from app.evaluation.common.helper import ( Profiler, build_evaluation_report, @@ -16,15 +18,9 @@ BASE_DIR = Path(__file__).resolve().parent.parent DATASETS_DIR = BASE_DIR / "datasets" / "topic_relevance" -OUT_DIR = BASE_DIR / "outputs" / "topic_relevance" +OUTPUTS_DIR = BASE_DIR / "outputs" -DEFAULT_CONFIG = { - "llm_callable": "gpt-4o-mini", - "prompt_schema_version": 1, -} - -# All evaluations defined here -EVALUATIONS = [ +DATASETS = [ { "domain": "education", "dataset": "education-topic-relevance-dataset.csv", @@ -37,30 +33,46 @@ }, ] +BACKENDS = [ + { + "name": "topic_relevance", + "out_dir": OUTPUTS_DIR / "topic_relevance", + "build": lambda tc: TopicRelevance( + topic_config=tc, + prompt_schema_version=1, + llm_callable=settings.DEFAULT_LLM_CALLABLE, + ), + "report_extra": { + "llm_callable": settings.DEFAULT_LLM_CALLABLE, + "prompt_schema_version": 1, + }, + }, + { + "name": "topic_relevance_openai", + "out_dir": OUTPUTS_DIR / "topic_relevance_openai", + "build": lambda tc: TopicRelevanceOpenAI( + system_prompt=tc, + llm_callable=settings.DEFAULT_LLM_CALLABLE, + threshold=2, + ), + "report_extra": { + "llm_callable": settings.DEFAULT_LLM_CALLABLE, + "threshold": 2, + }, + }, +] -def run_evaluation(config: dict) -> None: - """ - Run the topic relevance evaluation for a single domain config. - Loads the dataset and topic config (the plain-text scope definition describing allowed topics, - distinct from DEFAULT_CONFIG which holds model and prompt version settings), runs each input - through the TopicRelevance validator, computes binary and per-category metrics, and writes - prediction CSV and metrics JSON to the output directory. - """ - domain = config["domain"] - dataset_path = DATASETS_DIR / config["dataset"] - topic_config_path = DATASETS_DIR / config["topic_config"] - topic_config = topic_config_path.read_text() +def run_evaluation(dataset: dict, backend: dict) -> None: + domain = dataset["domain"] + topic_config = (DATASETS_DIR / dataset["topic_config"]).read_text() + dataset_path = DATASETS_DIR / dataset["dataset"] + out_dir: Path = backend["out_dir"] - print(f"\nRunning topic relevance evaluation: {domain}") + print(f"\nRunning {backend['name']} evaluation: {domain}") df = pd.read_csv(dataset_path) - - validator = TopicRelevance( - topic_config=topic_config, - prompt_schema_version=DEFAULT_CONFIG["prompt_schema_version"], - llm_callable=DEFAULT_CONFIG["llm_callable"], - ) + validator = backend["build"](topic_config) normalized_df = pd.DataFrame( { @@ -69,7 +81,6 @@ def run_evaluation(config: dict) -> None: "in_scope": df["scope"].apply(lambda x: 1 if x == "IN_SCOPE" else 0), } ) - normalized_df["y_true"] = (1 - normalized_df["in_scope"]).astype(int) with Profiler() as p: @@ -88,7 +99,6 @@ def run_evaluation(config: dict) -> None: ) metrics = compute_binary_metrics(normalized_df["y_true"], normalized_df["y_pred"]) - metrics["category_metrics"] = { str(cat): { "num_samples": int(len(g)), @@ -97,30 +107,27 @@ def run_evaluation(config: dict) -> None: for cat, g in normalized_df.groupby("category", dropna=False) } - OUT_DIR.mkdir(parents=True, exist_ok=True) - - write_csv(normalized_df, OUT_DIR / f"{domain}-predictions.csv") - + out_dir.mkdir(parents=True, exist_ok=True) + write_csv(normalized_df, out_dir / f"{domain}-predictions.csv") write_json( build_evaluation_report( - guardrail="topic_relevance", + guardrail=backend["name"], num_samples=len(normalized_df), profiler=p, dataset=str(dataset_path), - llm_callable=DEFAULT_CONFIG["llm_callable"], - prompt_schema_version=DEFAULT_CONFIG["prompt_schema_version"], + **backend["report_extra"], metrics=metrics, ), - OUT_DIR / f"{domain}-metrics.json", + out_dir / f"{domain}-metrics.json", ) - print(f"Completed {domain} evaluation") + print(f"Completed {backend['name']} {domain} evaluation") def main() -> None: - """Iterate over all entries in EVALUATIONS and run each domain evaluation in sequence.""" - for config in EVALUATIONS: - run_evaluation(config) + for backend in BACKENDS: + for dataset in DATASETS: + run_evaluation(dataset, backend) if __name__ == "__main__": diff --git a/backend/app/evaluation/topic_relevance_openai/run.py b/backend/app/evaluation/topic_relevance_openai/run.py deleted file mode 100644 index 8b2c996..0000000 --- a/backend/app/evaluation/topic_relevance_openai/run.py +++ /dev/null @@ -1,127 +0,0 @@ -from __future__ import annotations - -from pathlib import Path - -import pandas as pd -from guardrails.validators import FailResult - -from app.core.validators.topic_relevance_openai import TopicRelevanceOpenAI -from app.evaluation.common.helper import ( - Profiler, - build_evaluation_report, - compute_binary_metrics, - write_csv, - write_json, -) - -BASE_DIR = Path(__file__).resolve().parent.parent -DATASETS_DIR = BASE_DIR / "datasets" / "topic_relevance" -OUT_DIR = BASE_DIR / "outputs" / "topic_relevance_openai" - -DEFAULT_CONFIG = { - "llm_callable": "gpt-4o-mini", - "prompt_schema_version": 3, - "threshold": 2, -} - -EVALUATIONS = [ - { - "domain": "education", - "dataset": "education-topic-relevance-dataset.csv", - "topic_config": "education_topic_config.txt", - }, - { - "domain": "healthcare", - "dataset": "healthcare-topic-relevance-dataset.csv", - "topic_config": "healthcare_topic_config.txt", - }, -] - - -def run_evaluation(config: dict) -> None: - """ - Run the topic relevance (OpenAI) evaluation for a single domain config. - Loads dataset and topic config, runs each input through TopicRelevanceOpenAI, - computes binary and per-category metrics, and writes CSV and JSON to outputs. - """ - domain = config["domain"] - - dataset_path = DATASETS_DIR / config["dataset"] - topic_config_path = DATASETS_DIR / config["topic_config"] - topic_config = topic_config_path.read_text() - - print(f"\nRunning topic relevance (OpenAI) evaluation: {domain}") - - df = pd.read_csv(dataset_path) - - validator = TopicRelevanceOpenAI( - topic_config=topic_config, - prompt_schema_version=DEFAULT_CONFIG["prompt_schema_version"], - llm_callable=DEFAULT_CONFIG["llm_callable"], - threshold=DEFAULT_CONFIG["threshold"], - ) - - normalized_df = pd.DataFrame( - { - "input": df["input"].astype(str), - "category": df["category"].astype(str), - "in_scope": df["scope"].apply(lambda x: 1 if x == "IN_SCOPE" else 0), - } - ) - - normalized_df["y_true"] = (1 - normalized_df["in_scope"]).astype(int) - - with Profiler() as p: - results = normalized_df["input"].apply( - lambda x: p.record(lambda t: validator.validate(t, metadata=None), x) - ) - - normalized_df["y_pred"] = results.apply(lambda r: int(isinstance(r, FailResult))) - normalized_df["scope_score"] = results.apply( - lambda r: r.metadata.get("scope_score") - if getattr(r, "metadata", None) - else None - ) - normalized_df["error_message"] = results.apply( - lambda r: r.error_message if isinstance(r, FailResult) else "" - ) - - metrics = compute_binary_metrics(normalized_df["y_true"], normalized_df["y_pred"]) - - metrics["category_metrics"] = { - str(cat): { - "num_samples": int(len(g)), - **compute_binary_metrics(g["y_true"], g["y_pred"]), - } - for cat, g in normalized_df.groupby("category", dropna=False) - } - - OUT_DIR.mkdir(parents=True, exist_ok=True) - - write_csv(normalized_df, OUT_DIR / f"{domain}-predictions.csv") - - write_json( - build_evaluation_report( - guardrail="topic_relevance_openai", - num_samples=len(normalized_df), - profiler=p, - dataset=str(dataset_path), - llm_callable=DEFAULT_CONFIG["llm_callable"], - prompt_schema_version=DEFAULT_CONFIG["prompt_schema_version"], - threshold=DEFAULT_CONFIG["threshold"], - metrics=metrics, - ), - OUT_DIR / f"{domain}-metrics.json", - ) - - print(f"Completed {domain} evaluation") - - -def main() -> None: - """Iterate over all entries in EVALUATIONS and run each domain evaluation in sequence.""" - for config in EVALUATIONS: - run_evaluation(config) - - -if __name__ == "__main__": - main() diff --git a/backend/app/tests/test_validate_with_guard.py b/backend/app/tests/test_validate_with_guard.py index ec91791..9bdf4f1 100644 --- a/backend/app/tests/test_validate_with_guard.py +++ b/backend/app/tests/test_validate_with_guard.py @@ -289,13 +289,11 @@ def test_resolve_validator_configs_topic_relevance_openai_from_config_id(): with patch("app.api.routes.guardrails.topic_relevance_crud.get") as mock_get: mock_get.return_value = MagicMock( configuration="Healthcare topic scope text", - prompt_schema_version=3, ) _resolve_validator_configs(payload, mock_session) validator = payload.validators[0] assert validator.configuration == "Healthcare topic scope text" - assert validator.prompt_schema_version == 3 mock_get.assert_called_once_with( session=mock_session, id=validator.topic_relevance_config_id, diff --git a/backend/app/tests/validators/test_topic_relevance_openai.py b/backend/app/tests/validators/test_topic_relevance_openai.py index 93646ce..a8de8b7 100644 --- a/backend/app/tests/validators/test_topic_relevance_openai.py +++ b/backend/app/tests/validators/test_topic_relevance_openai.py @@ -22,7 +22,7 @@ def validator(): "app.core.validators.topic_relevance_openai.get_supported_openai_params", return_value=["response_format"], ): - return TopicRelevanceOpenAI(topic_config=TOPIC_CONFIG) + return TopicRelevanceOpenAI(system_prompt=TOPIC_CONFIG) # --------------------------------------------------------------------------- @@ -73,7 +73,7 @@ def test_custom_threshold_of_3_fails_on_score_2(): "app.core.validators.topic_relevance_openai.get_supported_openai_params", return_value=[], ): - v = TopicRelevanceOpenAI(topic_config=TOPIC_CONFIG, threshold=3) + v = TopicRelevanceOpenAI(system_prompt=TOPIC_CONFIG, threshold=3) with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: mock_llm.return_value = _make_llm_response('{"scope_violation": 2}') @@ -88,7 +88,7 @@ def test_custom_threshold_of_1_passes_on_score_1(): "app.core.validators.topic_relevance_openai.get_supported_openai_params", return_value=[], ): - v = TopicRelevanceOpenAI(topic_config=TOPIC_CONFIG, threshold=1) + v = TopicRelevanceOpenAI(system_prompt=TOPIC_CONFIG, threshold=1) with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: mock_llm.return_value = _make_llm_response('{"scope_violation": 1}') @@ -116,12 +116,12 @@ def test_fails_when_value_is_whitespace(validator): assert "Empty message" in result.error_message -def test_fails_when_topic_config_is_blank(): +def test_fails_when_system_prompt_is_blank(): with patch( "app.core.validators.topic_relevance_openai.get_supported_openai_params", return_value=[], ): - v = TopicRelevanceOpenAI(topic_config="") + v = TopicRelevanceOpenAI(system_prompt="") result = v._validate("Some input") @@ -129,12 +129,12 @@ def test_fails_when_topic_config_is_blank(): assert "blank" in result.error_message -def test_fails_when_topic_config_is_whitespace_only(): +def test_fails_when_system_prompt_is_whitespace_only(): with patch( "app.core.validators.topic_relevance_openai.get_supported_openai_params", return_value=[], ): - v = TopicRelevanceOpenAI(topic_config=" ") + v = TopicRelevanceOpenAI(system_prompt=" ") result = v._validate("Some input") @@ -142,24 +142,6 @@ def test_fails_when_topic_config_is_whitespace_only(): assert "blank" in result.error_message -# --------------------------------------------------------------------------- -# Prompt version -# --------------------------------------------------------------------------- - - -def test_invalid_prompt_version_causes_fail_on_validate(): - with patch( - "app.core.validators.topic_relevance_openai.get_supported_openai_params", - return_value=[], - ): - v = TopicRelevanceOpenAI(topic_config=TOPIC_CONFIG, prompt_schema_version=999) - - result = v._validate("Some input") - - assert isinstance(result, FailResult) - assert "999" in result.error_message - - # --------------------------------------------------------------------------- # LLM error handling # --------------------------------------------------------------------------- @@ -223,7 +205,7 @@ def test_response_format_passed_when_supported(): "app.core.validators.topic_relevance_openai.get_supported_openai_params", return_value=["response_format"], ): - v = TopicRelevanceOpenAI(topic_config=TOPIC_CONFIG) + v = TopicRelevanceOpenAI(system_prompt=TOPIC_CONFIG) with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: mock_llm.return_value = _make_llm_response('{"scope_violation": 3}') @@ -238,7 +220,7 @@ def test_response_format_omitted_when_not_supported(): "app.core.validators.topic_relevance_openai.get_supported_openai_params", return_value=[], ): - v = TopicRelevanceOpenAI(topic_config=TOPIC_CONFIG) + v = TopicRelevanceOpenAI(system_prompt=TOPIC_CONFIG) with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: mock_llm.return_value = _make_llm_response('{"scope_violation": 3}') @@ -253,7 +235,7 @@ def test_response_format_omitted_when_litellm_check_fails(): "app.core.validators.topic_relevance_openai.get_supported_openai_params", side_effect=Exception("litellm unavailable"), ): - v = TopicRelevanceOpenAI(topic_config=TOPIC_CONFIG) + v = TopicRelevanceOpenAI(system_prompt=TOPIC_CONFIG) with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: mock_llm.return_value = _make_llm_response('{"scope_violation": 3}') @@ -273,7 +255,7 @@ def test_system_prompt_contains_topic_config(): "app.core.validators.topic_relevance_openai.get_supported_openai_params", return_value=[], ): - v = TopicRelevanceOpenAI(topic_config=TOPIC_CONFIG) + v = TopicRelevanceOpenAI(system_prompt=TOPIC_CONFIG) assert TOPIC_CONFIG in v._system_prompt @@ -283,7 +265,7 @@ def test_system_prompt_contains_json_instruction(): "app.core.validators.topic_relevance_openai.get_supported_openai_params", return_value=[], ): - v = TopicRelevanceOpenAI(topic_config=TOPIC_CONFIG) + v = TopicRelevanceOpenAI(system_prompt=TOPIC_CONFIG) assert "scope_violation" in v._system_prompt assert "JSON" in v._system_prompt From a3ca650a24f2a53176b14530973fedfe77d2a5d3 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Tue, 2 Jun 2026 16:41:48 +0530 Subject: [PATCH 3/5] added threshold to settings --- backend/app/core/config.py | 1 + .../config/topic_relevance_openai_safety_validator_config.py | 2 +- backend/app/core/validators/topic_relevance_openai.py | 2 +- backend/app/evaluation/topic_relevance/run.py | 4 ++-- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 85800d1..1f001f5 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -46,6 +46,7 @@ class Settings(BaseSettings): CORE_DIR: ClassVar[Path] = Path(__file__).resolve().parent OPENAI_API_KEY: str | None = None DEFAULT_LLM_CALLABLE: str = "gpt-4o-mini" + TOPIC_RELEVANCE_OPENAI_THRESHOLD: int = 2 SLUR_LIST_FILENAME: ClassVar[str] = "curated_slurlist_hi_en.csv" diff --git a/backend/app/core/validators/config/topic_relevance_openai_safety_validator_config.py b/backend/app/core/validators/config/topic_relevance_openai_safety_validator_config.py index 875ecc1..d165dff 100644 --- a/backend/app/core/validators/config/topic_relevance_openai_safety_validator_config.py +++ b/backend/app/core/validators/config/topic_relevance_openai_safety_validator_config.py @@ -10,7 +10,7 @@ class TopicRelevanceOpenAISafetyValidatorConfig(BaseValidatorConfig): type: Literal["topic_relevance_openai"] configuration: Optional[str] = None llm_callable: str = settings.DEFAULT_LLM_CALLABLE - threshold: int = 2 + threshold: int = settings.TOPIC_RELEVANCE_OPENAI_THRESHOLD topic_relevance_config_id: Optional[UUID] = None def build(self): diff --git a/backend/app/core/validators/topic_relevance_openai.py b/backend/app/core/validators/topic_relevance_openai.py index 5cf2802..224820e 100644 --- a/backend/app/core/validators/topic_relevance_openai.py +++ b/backend/app/core/validators/topic_relevance_openai.py @@ -43,7 +43,7 @@ def __init__( self, system_prompt: str, llm_callable: str = settings.DEFAULT_LLM_CALLABLE, - threshold: int = 2, + threshold: int = settings.TOPIC_RELEVANCE_OPENAI_THRESHOLD, on_fail: Optional[Callable] = OnFailAction.NOOP, ): super().__init__(on_fail=on_fail) diff --git a/backend/app/evaluation/topic_relevance/run.py b/backend/app/evaluation/topic_relevance/run.py index 15e3b94..2979d4c 100644 --- a/backend/app/evaluation/topic_relevance/run.py +++ b/backend/app/evaluation/topic_relevance/run.py @@ -53,11 +53,11 @@ "build": lambda tc: TopicRelevanceOpenAI( system_prompt=tc, llm_callable=settings.DEFAULT_LLM_CALLABLE, - threshold=2, + threshold=settings.TOPIC_RELEVANCE_OPENAI_THRESHOLD, ), "report_extra": { "llm_callable": settings.DEFAULT_LLM_CALLABLE, - "threshold": 2, + "threshold": settings.TOPIC_RELEVANCE_OPENAI_THRESHOLD, }, }, ] From 7532f41969daf19fcf1a53ab967fad27c95b397d Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Tue, 2 Jun 2026 17:12:08 +0530 Subject: [PATCH 4/5] resolved comments --- .../core/validators/topic_relevance_openai.py | 9 ++++-- .../validators/test_topic_relevance_openai.py | 31 +++++++++++++++++++ backend/scripts/run_all_evaluations.sh | 1 - 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/backend/app/core/validators/topic_relevance_openai.py b/backend/app/core/validators/topic_relevance_openai.py index 224820e..bc7ff55 100644 --- a/backend/app/core/validators/topic_relevance_openai.py +++ b/backend/app/core/validators/topic_relevance_openai.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import re from typing import Callable, Optional from litellm import completion, get_supported_openai_params @@ -92,9 +93,13 @@ def _validate(self, value: str, metadata: dict = None) -> ValidationResult: return FailResult(error_message=f"LLM call failed: {e}") try: - data = json.loads(content) + text = re.sub(r"```(?:json)?\s*|\s*```", "", content).strip() + match = re.search(r"\{[^{}]*\}", text) + if not match: + raise ValueError("no JSON object found in response") + data = json.loads(match.group()) score = data.get("scope_violation") - if not isinstance(score, int) or score not in (1, 2, 3): + if type(score) is not int or score not in (1, 2, 3): raise ValueError(f"unexpected score value: {score!r}") except Exception as e: return FailResult( diff --git a/backend/app/tests/validators/test_topic_relevance_openai.py b/backend/app/tests/validators/test_topic_relevance_openai.py index a8de8b7..e4c86e0 100644 --- a/backend/app/tests/validators/test_topic_relevance_openai.py +++ b/backend/app/tests/validators/test_topic_relevance_openai.py @@ -195,6 +195,37 @@ def test_fails_gracefully_when_score_is_a_string(validator): assert "unparseable" in result.error_message +def test_passes_when_response_wrapped_in_markdown_fence(validator): + with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: + mock_llm.return_value = _make_llm_response( + '```json\n{"scope_violation": 3}\n```' + ) + result = validator._validate("How do I make pasta?") + + assert isinstance(result, PassResult) + assert result.metadata["scope_score"] == 3 + + +def test_passes_when_response_has_surrounding_prose(validator): + with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: + mock_llm.return_value = _make_llm_response( + 'Sure! Here is my evaluation: {"scope_violation": 2}' + ) + result = validator._validate("Something vaguely food related") + + assert isinstance(result, PassResult) + assert result.metadata["scope_score"] == 2 + + +def test_fails_when_score_is_boolean(validator): + with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: + mock_llm.return_value = _make_llm_response('{"scope_violation": true}') + result = validator._validate("How do I bake bread?") + + assert isinstance(result, FailResult) + assert "unparseable" in result.error_message + + # --------------------------------------------------------------------------- # response_format forwarding # --------------------------------------------------------------------------- diff --git a/backend/scripts/run_all_evaluations.sh b/backend/scripts/run_all_evaluations.sh index d344c91..0da2402 100755 --- a/backend/scripts/run_all_evaluations.sh +++ b/backend/scripts/run_all_evaluations.sh @@ -11,7 +11,6 @@ RUNNERS=( "$EVAL_DIR/gender_assumption_bias/run.py" "$EVAL_DIR/ban_list/run.py" "$EVAL_DIR/topic_relevance/run.py" - "$EVAL_DIR/topic_relevance_openai/run.py" "$EVAL_DIR/toxicity/run.py" ) From 80727ffab16d9df12671d1625dc5a86d48ff2825 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Wed, 3 Jun 2026 09:06:47 +0530 Subject: [PATCH 5/5] cleanup PR --- backend/app/api/routes/guardrails.py | 42 +++++++------- backend/app/core/constants.py | 4 ++ ...elevance_openai_safety_validator_config.py | 10 +++- ...topic_relevance_safety_validator_config.py | 4 +- backend/app/core/validators/llm_utils.py | 15 +++++ .../app/core/validators/topic_relevance.py | 34 ++++++----- .../core/validators/topic_relevance_openai.py | 56 +++++++++++-------- backend/app/tests/test_llm_validators.py | 2 +- .../validators/test_topic_relevance_openai.py | 20 +++---- 9 files changed, 106 insertions(+), 81 deletions(-) create mode 100644 backend/app/core/validators/llm_utils.py diff --git a/backend/app/api/routes/guardrails.py b/backend/app/api/routes/guardrails.py index 4650141..fd24954 100644 --- a/backend/app/api/routes/guardrails.py +++ b/backend/app/api/routes/guardrails.py @@ -1,5 +1,5 @@ -from uuid import UUID import uuid +from uuid import UUID from fastapi import APIRouter from guardrails.guard import Guard @@ -14,24 +14,24 @@ REPHRASE_ON_FAIL_PREFIX, ) from app.core.enum import ValidatorType -from app.core.guardrail_controller import build_guard, get_validator_config_models from app.core.exception_handlers import _safe_error_message +from app.core.guardrail_controller import build_guard, get_validator_config_models from app.core.validators.config.ban_list_safety_validator_config import ( BanListSafetyValidatorConfig, ) -from app.crud.ban_list import ban_list_crud -from app.crud.topic_relevance import topic_relevance_crud -from app.crud.request_log import RequestLogCrud -from app.crud.validator_log import ValidatorLogCrud -from app.core.validators.config.topic_relevance_safety_validator_config import ( - TopicRelevanceSafetyValidatorConfig, -) from app.core.validators.config.topic_relevance_openai_safety_validator_config import ( TopicRelevanceOpenAISafetyValidatorConfig, ) -from app.schemas.guardrail_config import GuardrailRequest, GuardrailResponse +from app.core.validators.config.topic_relevance_safety_validator_config import ( + TopicRelevanceSafetyValidatorConfig, +) +from app.crud.ban_list import ban_list_crud +from app.crud.request_log import RequestLogCrud +from app.crud.topic_relevance import topic_relevance_crud +from app.crud.validator_log import ValidatorLogCrud from app.models.logging.request_log import RequestLogUpdate, RequestStatus from app.models.logging.validator_log import ValidatorLog, ValidatorOutcome +from app.schemas.guardrail_config import GuardrailRequest, GuardrailResponse from app.utils import APIResponse, load_description router = APIRouter(prefix="/guardrails", tags=["guardrails"]) @@ -119,18 +119,13 @@ def _resolve_validator_configs(payload: GuardrailRequest, session: Session) -> N ) validator.banned_words = ban_list.banned_words - elif isinstance(validator, TopicRelevanceSafetyValidatorConfig): - if validator.topic_relevance_config_id is not None: - config = topic_relevance_crud.get( - session=session, - id=validator.topic_relevance_config_id, - organization_id=payload.organization_id, - project_id=payload.project_id, - ) - validator.configuration = config.configuration - validator.prompt_schema_version = config.prompt_schema_version - - elif isinstance(validator, TopicRelevanceOpenAISafetyValidatorConfig): + elif isinstance( + validator, + ( + TopicRelevanceSafetyValidatorConfig, + TopicRelevanceOpenAISafetyValidatorConfig, + ), + ): if validator.topic_relevance_config_id is not None: config = topic_relevance_crud.get( session=session, @@ -139,6 +134,9 @@ def _resolve_validator_configs(payload: GuardrailRequest, session: Session) -> N project_id=payload.project_id, ) validator.configuration = config.configuration + # Only the LLMCritic-backed variant carries a prompt schema version. + if isinstance(validator, TopicRelevanceSafetyValidatorConfig): + validator.prompt_schema_version = config.prompt_schema_version def _validate_with_guard( diff --git a/backend/app/core/constants.py b/backend/app/core/constants.py index 1b83095..5fa2273 100644 --- a/backend/app/core/constants.py +++ b/backend/app/core/constants.py @@ -11,6 +11,10 @@ f"{LLM_CRITIC_ERROR_MESSAGE} Please rephrase without unsafe content." ) +# Topic relevance validators (shared by the LLMCritic- and litellm-backed variants) +EMPTY_MESSAGE_ERROR = "Empty message." +TOPIC_OUT_OF_SCOPE_ERROR = "Input is outside the allowed topic scope." + VALIDATOR_CONFIG_SYSTEM_FIELDS = { "organization_id", "project_id", diff --git a/backend/app/core/validators/config/topic_relevance_openai_safety_validator_config.py b/backend/app/core/validators/config/topic_relevance_openai_safety_validator_config.py index d165dff..5859bd9 100644 --- a/backend/app/core/validators/config/topic_relevance_openai_safety_validator_config.py +++ b/backend/app/core/validators/config/topic_relevance_openai_safety_validator_config.py @@ -1,16 +1,20 @@ from typing import Literal, Optional from uuid import UUID -from app.core.validators.topic_relevance_openai import TopicRelevanceOpenAI -from app.core.validators.config.base_validator_config import BaseValidatorConfig +from pydantic import Field + from app.core.config import settings +from app.core.validators.config.base_validator_config import BaseValidatorConfig +from app.core.validators.topic_relevance_openai import TopicRelevanceOpenAI class TopicRelevanceOpenAISafetyValidatorConfig(BaseValidatorConfig): type: Literal["topic_relevance_openai"] configuration: Optional[str] = None llm_callable: str = settings.DEFAULT_LLM_CALLABLE - threshold: int = settings.TOPIC_RELEVANCE_OPENAI_THRESHOLD + threshold: int = Field( + default=settings.TOPIC_RELEVANCE_OPENAI_THRESHOLD, ge=1, le=3 + ) topic_relevance_config_id: Optional[UUID] = None def build(self): diff --git a/backend/app/core/validators/config/topic_relevance_safety_validator_config.py b/backend/app/core/validators/config/topic_relevance_safety_validator_config.py index 4dff8d3..743e053 100644 --- a/backend/app/core/validators/config/topic_relevance_safety_validator_config.py +++ b/backend/app/core/validators/config/topic_relevance_safety_validator_config.py @@ -1,11 +1,9 @@ from typing import Literal, Optional from uuid import UUID -from pydantic import model_validator - from app.core.config import settings -from app.core.validators.topic_relevance import TopicRelevance from app.core.validators.config.base_validator_config import BaseValidatorConfig +from app.core.validators.topic_relevance import TopicRelevance class TopicRelevanceSafetyValidatorConfig(BaseValidatorConfig): diff --git a/backend/app/core/validators/llm_utils.py b/backend/app/core/validators/llm_utils.py new file mode 100644 index 0000000..3db3455 --- /dev/null +++ b/backend/app/core/validators/llm_utils.py @@ -0,0 +1,15 @@ +from litellm import get_supported_openai_params + +# Passed to litellm/OpenAI to force a strict JSON object response. +JSON_OBJECT_RESPONSE_FORMAT = {"type": "json_object"} + + +def supports_response_format(model: str) -> bool: + """Return True if the given model supports the OpenAI ``response_format`` param. + + Falls back to False if litellm cannot resolve the model's capabilities. + """ + try: + return "response_format" in (get_supported_openai_params(model=model) or []) + except Exception: + return False diff --git a/backend/app/core/validators/topic_relevance.py b/backend/app/core/validators/topic_relevance.py index 7241972..c98d446 100644 --- a/backend/app/core/validators/topic_relevance.py +++ b/backend/app/core/validators/topic_relevance.py @@ -4,17 +4,22 @@ from pathlib import Path from typing import Callable, Optional -from guardrails.hub import LLMCritic from guardrails import OnFailAction +from guardrails.hub import LLMCritic from guardrails.validators import ( + FailResult, + PassResult, + ValidationResult, Validator, register_validator, - ValidationResult, ) -from guardrails.validators import FailResult, PassResult from app.core.config import settings - +from app.core.constants import EMPTY_MESSAGE_ERROR, TOPIC_OUT_OF_SCOPE_ERROR +from app.core.validators.llm_utils import ( + JSON_OBJECT_RESPONSE_FORMAT, + supports_response_format, +) # This should be present in all prompt templates to indicate where the topic configuration will be inserted _PROMPT_PLACEHOLDER = "{{TOPIC_CONFIGURATION}}" @@ -80,15 +85,6 @@ def __init__( self._critic = None return - try: - from litellm import get_supported_openai_params - - supports_response_format = "response_format" in ( - get_supported_openai_params(model=llm_callable) or [] - ) - except Exception: - supports_response_format = False - self._critic = LLMCritic( metrics={ "scope_violation": { @@ -103,19 +99,21 @@ def __init__( llm_callable=llm_callable, on_fail=on_fail, **( - {"llm_kwargs": {"response_format": {"type": "json_object"}}} - if supports_response_format + {"llm_kwargs": {"response_format": JSON_OBJECT_RESPONSE_FORMAT}} + if supports_response_format(llm_callable) else {} ), ) - def _validate(self, value: str, metadata: dict = None) -> ValidationResult: + def _validate( + self, value: str, metadata: Optional[dict] = None + ) -> ValidationResult: """Run the LLMCritic and return a PassResult or FailResult with the scope score.""" if self._invalid_config_reason: return FailResult(error_message=self._invalid_config_reason) if not value or not value.strip(): - return FailResult(error_message="Empty message.") + return FailResult(error_message=EMPTY_MESSAGE_ERROR) try: result = self._critic.validate(value, metadata) @@ -129,7 +127,7 @@ def _validate(self, value: str, metadata: dict = None) -> ValidationResult: if isinstance(result, FailResult): return FailResult( - error_message="Input is outside the allowed topic scope.", + error_message=TOPIC_OUT_OF_SCOPE_ERROR, metadata={"scope_score": score}, ) diff --git a/backend/app/core/validators/topic_relevance_openai.py b/backend/app/core/validators/topic_relevance_openai.py index bc7ff55..5e528fc 100644 --- a/backend/app/core/validators/topic_relevance_openai.py +++ b/backend/app/core/validators/topic_relevance_openai.py @@ -4,7 +4,6 @@ import re from typing import Callable, Optional -from litellm import completion, get_supported_openai_params from guardrails import OnFailAction from guardrails.validators import ( FailResult, @@ -13,17 +12,28 @@ Validator, register_validator, ) +from litellm import completion from app.core.config import settings +from app.core.constants import EMPTY_MESSAGE_ERROR, TOPIC_OUT_OF_SCOPE_ERROR +from app.core.validators.llm_utils import ( + JSON_OBJECT_RESPONSE_FORMAT, + supports_response_format, +) +# Valid scope scores returned by the model; the highest means "clearly in scope". +_VALID_SCORES = (1, 2, 3) +# Cap the response: a single ``{"scope_violation": }`` object is tiny. +_MAX_TOKENS = 50 _SCORING_INSTRUCTIONS = ( "\n\nScore using:\n" - "3 = clearly within scope (directly matches a topic description)\n" - "2 = partially related (tangentially related or implicitly within scope)\n" - "1 = clearly outside scope (no relation to any listed topic)\n" + f"{_VALID_SCORES[2]} = clearly within scope (directly matches a topic description)\n" + f"{_VALID_SCORES[1]} = partially related (tangentially related or implicitly within scope)\n" + f"{_VALID_SCORES[0]} = clearly outside scope (no relation to any listed topic)\n" "\nRespond ONLY with a JSON object in this exact format: " - '{"scope_violation": } where is the integer 1, 2, or 3.' + '{"scope_violation": } where is the integer ' + f"{_VALID_SCORES[0]}, {_VALID_SCORES[1]}, or {_VALID_SCORES[2]}." ) @@ -36,8 +46,8 @@ class TopicRelevanceOpenAI(Validator): The caller supplies the full system prompt. The validator appends hardcoded scoring and response-format instructions. - Scores 1–3 where 3 = clearly in scope, 2 = ambiguous, 1 = outside scope. - Passes when score >= threshold (default 2). + Scores 1–3 where 3 = clearly in scope, 2 = partially related, + 1 = outside scope. Passes when score >= threshold (default 2). """ def __init__( @@ -60,32 +70,28 @@ def __init__( return self._system_prompt = system_prompt.strip() + _SCORING_INSTRUCTIONS + self._supports_response_format = supports_response_format(llm_callable) - try: - self._supports_response_format = "response_format" in ( - get_supported_openai_params(model=llm_callable) or [] - ) - except Exception: - self._supports_response_format = False - - def _validate(self, value: str, metadata: dict = None) -> ValidationResult: + def _validate( + self, value: str, metadata: Optional[dict] = None + ) -> ValidationResult: if self._invalid_config_reason: return FailResult(error_message=self._invalid_config_reason) if not value or not value.strip(): - return FailResult(error_message="Empty message.") + return FailResult(error_message=EMPTY_MESSAGE_ERROR) try: - kwargs = dict( - model=self.llm_callable, - messages=[ + kwargs = { + "model": self.llm_callable, + "messages": [ {"role": "system", "content": self._system_prompt}, {"role": "user", "content": value}, ], - max_tokens=50, - ) + "max_tokens": _MAX_TOKENS, + } if self._supports_response_format: - kwargs["response_format"] = {"type": "json_object"} + kwargs["response_format"] = JSON_OBJECT_RESPONSE_FORMAT response = completion(**kwargs) content = response.choices[0].message.content.strip() @@ -99,7 +105,9 @@ def _validate(self, value: str, metadata: dict = None) -> ValidationResult: raise ValueError("no JSON object found in response") data = json.loads(match.group()) score = data.get("scope_violation") - if type(score) is not int or score not in (1, 2, 3): + # `type(score) is not int` (not isinstance) deliberately rejects bool, + # which is an int subclass, so `true`/`false` are treated as invalid. + if type(score) is not int or score not in _VALID_SCORES: raise ValueError(f"unexpected score value: {score!r}") except Exception as e: return FailResult( @@ -110,6 +118,6 @@ def _validate(self, value: str, metadata: dict = None) -> ValidationResult: return PassResult(value=value, metadata={"scope_score": score}) return FailResult( - error_message="Input is outside the allowed topic scope.", + error_message=TOPIC_OUT_OF_SCOPE_ERROR, metadata={"scope_score": score}, ) diff --git a/backend/app/tests/test_llm_validators.py b/backend/app/tests/test_llm_validators.py index 9cec0b0..58fb9f0 100644 --- a/backend/app/tests/test_llm_validators.py +++ b/backend/app/tests/test_llm_validators.py @@ -106,7 +106,7 @@ def test_topic_relevance_openai_blank_config_returns_fail_result(): ) with patch(_TOPIC_RELEVANCE_OPENAI_SETTINGS_PATH) as mock_settings, patch( - "app.core.validators.topic_relevance_openai.get_supported_openai_params", + "app.core.validators.llm_utils.get_supported_openai_params", return_value=[], ): mock_settings.OPENAI_API_KEY = "sk-test-key" diff --git a/backend/app/tests/validators/test_topic_relevance_openai.py b/backend/app/tests/validators/test_topic_relevance_openai.py index e4c86e0..54ba6d7 100644 --- a/backend/app/tests/validators/test_topic_relevance_openai.py +++ b/backend/app/tests/validators/test_topic_relevance_openai.py @@ -19,7 +19,7 @@ def _make_llm_response(json_text: str) -> MagicMock: @pytest.fixture def validator(): with patch( - "app.core.validators.topic_relevance_openai.get_supported_openai_params", + "app.core.validators.llm_utils.get_supported_openai_params", return_value=["response_format"], ): return TopicRelevanceOpenAI(system_prompt=TOPIC_CONFIG) @@ -70,7 +70,7 @@ def test_fails_when_score_is_1(validator): def test_custom_threshold_of_3_fails_on_score_2(): with patch( - "app.core.validators.topic_relevance_openai.get_supported_openai_params", + "app.core.validators.llm_utils.get_supported_openai_params", return_value=[], ): v = TopicRelevanceOpenAI(system_prompt=TOPIC_CONFIG, threshold=3) @@ -85,7 +85,7 @@ def test_custom_threshold_of_3_fails_on_score_2(): def test_custom_threshold_of_1_passes_on_score_1(): with patch( - "app.core.validators.topic_relevance_openai.get_supported_openai_params", + "app.core.validators.llm_utils.get_supported_openai_params", return_value=[], ): v = TopicRelevanceOpenAI(system_prompt=TOPIC_CONFIG, threshold=1) @@ -118,7 +118,7 @@ def test_fails_when_value_is_whitespace(validator): def test_fails_when_system_prompt_is_blank(): with patch( - "app.core.validators.topic_relevance_openai.get_supported_openai_params", + "app.core.validators.llm_utils.get_supported_openai_params", return_value=[], ): v = TopicRelevanceOpenAI(system_prompt="") @@ -131,7 +131,7 @@ def test_fails_when_system_prompt_is_blank(): def test_fails_when_system_prompt_is_whitespace_only(): with patch( - "app.core.validators.topic_relevance_openai.get_supported_openai_params", + "app.core.validators.llm_utils.get_supported_openai_params", return_value=[], ): v = TopicRelevanceOpenAI(system_prompt=" ") @@ -233,7 +233,7 @@ def test_fails_when_score_is_boolean(validator): def test_response_format_passed_when_supported(): with patch( - "app.core.validators.topic_relevance_openai.get_supported_openai_params", + "app.core.validators.llm_utils.get_supported_openai_params", return_value=["response_format"], ): v = TopicRelevanceOpenAI(system_prompt=TOPIC_CONFIG) @@ -248,7 +248,7 @@ def test_response_format_passed_when_supported(): def test_response_format_omitted_when_not_supported(): with patch( - "app.core.validators.topic_relevance_openai.get_supported_openai_params", + "app.core.validators.llm_utils.get_supported_openai_params", return_value=[], ): v = TopicRelevanceOpenAI(system_prompt=TOPIC_CONFIG) @@ -263,7 +263,7 @@ def test_response_format_omitted_when_not_supported(): def test_response_format_omitted_when_litellm_check_fails(): with patch( - "app.core.validators.topic_relevance_openai.get_supported_openai_params", + "app.core.validators.llm_utils.get_supported_openai_params", side_effect=Exception("litellm unavailable"), ): v = TopicRelevanceOpenAI(system_prompt=TOPIC_CONFIG) @@ -283,7 +283,7 @@ def test_response_format_omitted_when_litellm_check_fails(): def test_system_prompt_contains_topic_config(): with patch( - "app.core.validators.topic_relevance_openai.get_supported_openai_params", + "app.core.validators.llm_utils.get_supported_openai_params", return_value=[], ): v = TopicRelevanceOpenAI(system_prompt=TOPIC_CONFIG) @@ -293,7 +293,7 @@ def test_system_prompt_contains_topic_config(): def test_system_prompt_contains_json_instruction(): with patch( - "app.core.validators.topic_relevance_openai.get_supported_openai_params", + "app.core.validators.llm_utils.get_supported_openai_params", return_value=[], ): v = TopicRelevanceOpenAI(system_prompt=TOPIC_CONFIG)