diff --git a/backend/app/api/routes/guardrails.py b/backend/app/api/routes/guardrails.py index 7281718..fd24954 100644 --- a/backend/app/api/routes/guardrails.py +++ b/backend/app/api/routes/guardrails.py @@ -1,5 +1,5 @@ -from uuid import UUID import uuid +from uuid import UUID from fastapi import APIRouter from guardrails.guard import Guard @@ -14,21 +14,24 @@ REPHRASE_ON_FAIL_PREFIX, ) from app.core.enum import ValidatorType -from app.core.guardrail_controller import build_guard, get_validator_config_models from app.core.exception_handlers import _safe_error_message +from app.core.guardrail_controller import build_guard, get_validator_config_models from app.core.validators.config.ban_list_safety_validator_config import ( BanListSafetyValidatorConfig, ) -from app.crud.ban_list import ban_list_crud -from app.crud.topic_relevance import topic_relevance_crud -from app.crud.request_log import RequestLogCrud -from app.crud.validator_log import ValidatorLogCrud +from app.core.validators.config.topic_relevance_openai_safety_validator_config import ( + TopicRelevanceOpenAISafetyValidatorConfig, +) from app.core.validators.config.topic_relevance_safety_validator_config import ( TopicRelevanceSafetyValidatorConfig, ) -from app.schemas.guardrail_config import GuardrailRequest, GuardrailResponse +from app.crud.ban_list import ban_list_crud +from app.crud.request_log import RequestLogCrud +from app.crud.topic_relevance import topic_relevance_crud +from app.crud.validator_log import ValidatorLogCrud from app.models.logging.request_log import RequestLogUpdate, RequestStatus from app.models.logging.validator_log import ValidatorLog, ValidatorOutcome +from app.schemas.guardrail_config import GuardrailRequest, GuardrailResponse from app.utils import APIResponse, load_description router = APIRouter(prefix="/guardrails", tags=["guardrails"]) @@ -103,6 +106,7 @@ def _resolve_validator_configs(payload: GuardrailRequest, session: Session) -> N Resolves config-backed references for all 
validators in-place before guard execution: - BanList: fetches banned_words from the stored BanList when not provided inline. - TopicRelevance: fetches configuration and prompt_schema_version from stored config. + - TopicRelevanceOpenAI: fetches configuration from stored config. """ for validator in payload.validators: if isinstance(validator, BanListSafetyValidatorConfig): @@ -115,7 +119,13 @@ def _resolve_validator_configs(payload: GuardrailRequest, session: Session) -> N ) validator.banned_words = ban_list.banned_words - elif isinstance(validator, TopicRelevanceSafetyValidatorConfig): + elif isinstance( + validator, + ( + TopicRelevanceSafetyValidatorConfig, + TopicRelevanceOpenAISafetyValidatorConfig, + ), + ): if validator.topic_relevance_config_id is not None: config = topic_relevance_crud.get( session=session, @@ -124,7 +134,9 @@ def _resolve_validator_configs(payload: GuardrailRequest, session: Session) -> N project_id=payload.project_id, ) validator.configuration = config.configuration - validator.prompt_schema_version = config.prompt_schema_version + # Only the LLMCritic-backed variant carries a prompt schema version. 
+ if isinstance(validator, TopicRelevanceSafetyValidatorConfig): + validator.prompt_schema_version = config.prompt_schema_version def _validate_with_guard( diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 6d4ae94..1f001f5 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -45,6 +45,8 @@ class Settings(BaseSettings): KAAPI_AUTH_TIMEOUT: int CORE_DIR: ClassVar[Path] = Path(__file__).resolve().parent OPENAI_API_KEY: str | None = None + DEFAULT_LLM_CALLABLE: str = "gpt-4o-mini" + TOPIC_RELEVANCE_OPENAI_THRESHOLD: int = 2 SLUR_LIST_FILENAME: ClassVar[str] = "curated_slurlist_hi_en.csv" diff --git a/backend/app/core/constants.py b/backend/app/core/constants.py index 1b83095..5fa2273 100644 --- a/backend/app/core/constants.py +++ b/backend/app/core/constants.py @@ -11,6 +11,10 @@ f"{LLM_CRITIC_ERROR_MESSAGE} Please rephrase without unsafe content." ) +# Topic relevance validators (shared by the LLMCritic- and litellm-backed variants) +EMPTY_MESSAGE_ERROR = "Empty message." +TOPIC_OUT_OF_SCOPE_ERROR = "Input is outside the allowed topic scope." 
+ VALIDATOR_CONFIG_SYSTEM_FIELDS = { "organization_id", "project_id", diff --git a/backend/app/core/enum.py b/backend/app/core/enum.py index ff653c5..70eff2d 100644 --- a/backend/app/core/enum.py +++ b/backend/app/core/enum.py @@ -32,6 +32,7 @@ class ValidatorType(Enum): GenderAssumptionBias = "gender_assumption_bias" BanList = "ban_list" TopicRelevance = "topic_relevance" + TopicRelevanceOpenAI = "topic_relevance_openai" LLMCritic = "llm_critic" LlamaGuard7B = "llamaguard_7b" ProfanityFree = "profanity_free" diff --git a/backend/app/core/validators/config/topic_relevance_openai_safety_validator_config.py b/backend/app/core/validators/config/topic_relevance_openai_safety_validator_config.py new file mode 100644 index 0000000..5859bd9 --- /dev/null +++ b/backend/app/core/validators/config/topic_relevance_openai_safety_validator_config.py @@ -0,0 +1,31 @@ +from typing import Literal, Optional +from uuid import UUID + +from pydantic import Field + +from app.core.config import settings +from app.core.validators.config.base_validator_config import BaseValidatorConfig +from app.core.validators.topic_relevance_openai import TopicRelevanceOpenAI + + +class TopicRelevanceOpenAISafetyValidatorConfig(BaseValidatorConfig): + type: Literal["topic_relevance_openai"] + configuration: Optional[str] = None + llm_callable: str = settings.DEFAULT_LLM_CALLABLE + threshold: int = Field( + default=settings.TOPIC_RELEVANCE_OPENAI_THRESHOLD, ge=1, le=3 + ) + topic_relevance_config_id: Optional[UUID] = None + + def build(self): + if not settings.OPENAI_API_KEY: + raise ValueError( + "OPENAI_API_KEY is not configured. " + "Topic relevance (OpenAI) validation requires an OpenAI API key." 
+ ) + return TopicRelevanceOpenAI( + system_prompt=self.configuration or "", + llm_callable=self.llm_callable, + threshold=self.threshold, + on_fail=self.resolve_on_fail(), + ) diff --git a/backend/app/core/validators/config/topic_relevance_safety_validator_config.py b/backend/app/core/validators/config/topic_relevance_safety_validator_config.py index 53023a9..743e053 100644 --- a/backend/app/core/validators/config/topic_relevance_safety_validator_config.py +++ b/backend/app/core/validators/config/topic_relevance_safety_validator_config.py @@ -1,18 +1,16 @@ from typing import Literal, Optional from uuid import UUID -from pydantic import model_validator - from app.core.config import settings -from app.core.validators.topic_relevance import TopicRelevance from app.core.validators.config.base_validator_config import BaseValidatorConfig +from app.core.validators.topic_relevance import TopicRelevance class TopicRelevanceSafetyValidatorConfig(BaseValidatorConfig): type: Literal["topic_relevance"] configuration: Optional[str] = None prompt_schema_version: Optional[int] = None - llm_callable: str = "gpt-4o-mini" + llm_callable: str = settings.DEFAULT_LLM_CALLABLE topic_relevance_config_id: Optional[UUID] = None def build(self): diff --git a/backend/app/core/validators/llm_utils.py b/backend/app/core/validators/llm_utils.py new file mode 100644 index 0000000..3db3455 --- /dev/null +++ b/backend/app/core/validators/llm_utils.py @@ -0,0 +1,15 @@ +from litellm import get_supported_openai_params + +# Passed to litellm/OpenAI to force a strict JSON object response. +JSON_OBJECT_RESPONSE_FORMAT = {"type": "json_object"} + + +def supports_response_format(model: str) -> bool: + """Return True if the given model supports the OpenAI ``response_format`` param. + + Falls back to False if litellm cannot resolve the model's capabilities. 
+ """ + try: + return "response_format" in (get_supported_openai_params(model=model) or []) + except Exception: + return False diff --git a/backend/app/core/validators/topic_relevance.py b/backend/app/core/validators/topic_relevance.py index 22d2bcc..c98d446 100644 --- a/backend/app/core/validators/topic_relevance.py +++ b/backend/app/core/validators/topic_relevance.py @@ -4,15 +4,22 @@ from pathlib import Path from typing import Callable, Optional -from guardrails.hub import LLMCritic from guardrails import OnFailAction +from guardrails.hub import LLMCritic from guardrails.validators import ( + FailResult, + PassResult, + ValidationResult, Validator, register_validator, - ValidationResult, ) -from guardrails.validators import FailResult, PassResult +from app.core.config import settings +from app.core.constants import EMPTY_MESSAGE_ERROR, TOPIC_OUT_OF_SCOPE_ERROR +from app.core.validators.llm_utils import ( + JSON_OBJECT_RESPONSE_FORMAT, + supports_response_format, +) # This should be present in all prompt templates to indicate where the topic configuration will be inserted _PROMPT_PLACEHOLDER = "{{TOPIC_CONFIGURATION}}" @@ -62,7 +69,7 @@ def __init__( self, topic_config: str, prompt_schema_version: int = 1, - llm_callable: str = "gpt-4o-mini", + llm_callable: str = settings.DEFAULT_LLM_CALLABLE, on_fail: Optional[Callable] = OnFailAction.NOOP, ): """Build the LLMCritic with a scope_violation metric from the topic configuration.""" @@ -78,15 +85,6 @@ def __init__( self._critic = None return - try: - from litellm import get_supported_openai_params - - supports_response_format = "response_format" in ( - get_supported_openai_params(model=llm_callable) or [] - ) - except Exception: - supports_response_format = False - self._critic = LLMCritic( metrics={ "scope_violation": { @@ -101,19 +99,21 @@ def __init__( llm_callable=llm_callable, on_fail=on_fail, **( - {"llm_kwargs": {"response_format": {"type": "json_object"}}} - if supports_response_format + {"llm_kwargs": 
{"response_format": JSON_OBJECT_RESPONSE_FORMAT}} + if supports_response_format(llm_callable) else {} ), ) - def _validate(self, value: str, metadata: dict = None) -> ValidationResult: + def _validate( + self, value: str, metadata: Optional[dict] = None + ) -> ValidationResult: """Run the LLMCritic and return a PassResult or FailResult with the scope score.""" if self._invalid_config_reason: return FailResult(error_message=self._invalid_config_reason) if not value or not value.strip(): - return FailResult(error_message="Empty message.") + return FailResult(error_message=EMPTY_MESSAGE_ERROR) try: result = self._critic.validate(value, metadata) @@ -127,7 +127,7 @@ def _validate(self, value: str, metadata: dict = None) -> ValidationResult: if isinstance(result, FailResult): return FailResult( - error_message="Input is outside the allowed topic scope.", + error_message=TOPIC_OUT_OF_SCOPE_ERROR, metadata={"scope_score": score}, ) diff --git a/backend/app/core/validators/topic_relevance_openai.py b/backend/app/core/validators/topic_relevance_openai.py new file mode 100644 index 0000000..5e528fc --- /dev/null +++ b/backend/app/core/validators/topic_relevance_openai.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import json +import re +from typing import Callable, Optional + +from guardrails import OnFailAction +from guardrails.validators import ( + FailResult, + PassResult, + ValidationResult, + Validator, + register_validator, +) +from litellm import completion + +from app.core.config import settings +from app.core.constants import EMPTY_MESSAGE_ERROR, TOPIC_OUT_OF_SCOPE_ERROR +from app.core.validators.llm_utils import ( + JSON_OBJECT_RESPONSE_FORMAT, + supports_response_format, +) + +# Valid scope scores returned by the model; the highest means "clearly in scope". +_VALID_SCORES = (1, 2, 3) +# Cap the response: a single ``{"scope_violation": }`` object is tiny. 
+_MAX_TOKENS = 50 + +_SCORING_INSTRUCTIONS = ( + "\n\nScore using:\n" + f"{_VALID_SCORES[2]} = clearly within scope (directly matches a topic description)\n" + f"{_VALID_SCORES[1]} = partially related (tangentially related or implicitly within scope)\n" + f"{_VALID_SCORES[0]} = clearly outside scope (no relation to any listed topic)\n" + "\nRespond ONLY with a JSON object in this exact format: " + '{"scope_violation": <score>} where <score> is the integer ' + f"{_VALID_SCORES[0]}, {_VALID_SCORES[1]}, or {_VALID_SCORES[2]}." +) + + +@register_validator(name="topic-relevance-openai", data_type="string") +class TopicRelevanceOpenAI(Validator): + """ + Validates whether a user message is within the defined topic scope + using a direct OpenAI/litellm call. + + The caller supplies the full system prompt. The validator appends + hardcoded scoring and response-format instructions. + + Scores 1–3 where 3 = clearly in scope, 2 = partially related, + 1 = outside scope. Passes when score >= threshold (default 2). 
+ """ + + def __init__( + self, + system_prompt: str, + llm_callable: str = settings.DEFAULT_LLM_CALLABLE, + threshold: int = settings.TOPIC_RELEVANCE_OPENAI_THRESHOLD, + on_fail: Optional[Callable] = OnFailAction.NOOP, + ): + super().__init__(on_fail=on_fail) + + self.llm_callable = llm_callable + self.threshold = threshold + self._invalid_config_reason: Optional[str] = None + self._system_prompt: Optional[str] = None + self._supports_response_format: bool = False + + if not system_prompt or not system_prompt.strip(): + self._invalid_config_reason = "system_prompt is blank or missing" + return + + self._system_prompt = system_prompt.strip() + _SCORING_INSTRUCTIONS + self._supports_response_format = supports_response_format(llm_callable) + + def _validate( + self, value: str, metadata: Optional[dict] = None + ) -> ValidationResult: + if self._invalid_config_reason: + return FailResult(error_message=self._invalid_config_reason) + + if not value or not value.strip(): + return FailResult(error_message=EMPTY_MESSAGE_ERROR) + + try: + kwargs = { + "model": self.llm_callable, + "messages": [ + {"role": "system", "content": self._system_prompt}, + {"role": "user", "content": value}, + ], + "max_tokens": _MAX_TOKENS, + } + if self._supports_response_format: + kwargs["response_format"] = JSON_OBJECT_RESPONSE_FORMAT + + response = completion(**kwargs) + content = response.choices[0].message.content.strip() + except Exception as e: + return FailResult(error_message=f"LLM call failed: {e}") + + try: + text = re.sub(r"```(?:json)?\s*|\s*```", "", content).strip() + match = re.search(r"\{[^{}]*\}", text) + if not match: + raise ValueError("no JSON object found in response") + data = json.loads(match.group()) + score = data.get("scope_violation") + # `type(score) is not int` (not isinstance) deliberately rejects bool, + # which is an int subclass, so `true`/`false` are treated as invalid. 
+ if type(score) is not int or score not in _VALID_SCORES: + raise ValueError(f"unexpected score value: {score!r}") + except Exception as e: + return FailResult( + error_message=f"LLM returned an unparseable response: {e}. Raw: {content!r}" + ) + + if score >= self.threshold: + return PassResult(value=value, metadata={"scope_score": score}) + + return FailResult( + error_message=TOPIC_OUT_OF_SCOPE_ERROR, + metadata={"scope_score": score}, + ) diff --git a/backend/app/core/validators/validators.json b/backend/app/core/validators/validators.json index c6c0fd6..9823bf4 100644 --- a/backend/app/core/validators/validators.json +++ b/backend/app/core/validators/validators.json @@ -30,6 +30,11 @@ "version": "0.1.0", "source": "local" }, + { + "type": "topic_relevance_openai", + "version": "0.1.0", + "source": "local" + }, { "type": "llamaguard_7b", "version": "0.1.0", diff --git a/backend/app/evaluation/topic_relevance/run.py b/backend/app/evaluation/topic_relevance/run.py index f1be3cc..2979d4c 100644 --- a/backend/app/evaluation/topic_relevance/run.py +++ b/backend/app/evaluation/topic_relevance/run.py @@ -5,7 +5,9 @@ import pandas as pd from guardrails.validators import FailResult +from app.core.config import settings from app.core.validators.topic_relevance import TopicRelevance +from app.core.validators.topic_relevance_openai import TopicRelevanceOpenAI from app.evaluation.common.helper import ( Profiler, build_evaluation_report, @@ -16,15 +18,9 @@ BASE_DIR = Path(__file__).resolve().parent.parent DATASETS_DIR = BASE_DIR / "datasets" / "topic_relevance" -OUT_DIR = BASE_DIR / "outputs" / "topic_relevance" +OUTPUTS_DIR = BASE_DIR / "outputs" -DEFAULT_CONFIG = { - "llm_callable": "gpt-4o-mini", - "prompt_schema_version": 1, -} - -# All evaluations defined here -EVALUATIONS = [ +DATASETS = [ { "domain": "education", "dataset": "education-topic-relevance-dataset.csv", @@ -37,30 +33,46 @@ }, ] +BACKENDS = [ + { + "name": "topic_relevance", + "out_dir": OUTPUTS_DIR / 
"topic_relevance", + "build": lambda tc: TopicRelevance( + topic_config=tc, + prompt_schema_version=1, + llm_callable=settings.DEFAULT_LLM_CALLABLE, + ), + "report_extra": { + "llm_callable": settings.DEFAULT_LLM_CALLABLE, + "prompt_schema_version": 1, + }, + }, + { + "name": "topic_relevance_openai", + "out_dir": OUTPUTS_DIR / "topic_relevance_openai", + "build": lambda tc: TopicRelevanceOpenAI( + system_prompt=tc, + llm_callable=settings.DEFAULT_LLM_CALLABLE, + threshold=settings.TOPIC_RELEVANCE_OPENAI_THRESHOLD, + ), + "report_extra": { + "llm_callable": settings.DEFAULT_LLM_CALLABLE, + "threshold": settings.TOPIC_RELEVANCE_OPENAI_THRESHOLD, + }, + }, +] -def run_evaluation(config: dict) -> None: - """ - Run the topic relevance evaluation for a single domain config. - Loads the dataset and topic config (the plain-text scope definition describing allowed topics, - distinct from DEFAULT_CONFIG which holds model and prompt version settings), runs each input - through the TopicRelevance validator, computes binary and per-category metrics, and writes - prediction CSV and metrics JSON to the output directory. 
- """ - domain = config["domain"] - dataset_path = DATASETS_DIR / config["dataset"] - topic_config_path = DATASETS_DIR / config["topic_config"] - topic_config = topic_config_path.read_text() +def run_evaluation(dataset: dict, backend: dict) -> None: + domain = dataset["domain"] + topic_config = (DATASETS_DIR / dataset["topic_config"]).read_text() + dataset_path = DATASETS_DIR / dataset["dataset"] + out_dir: Path = backend["out_dir"] - print(f"\nRunning topic relevance evaluation: {domain}") + print(f"\nRunning {backend['name']} evaluation: {domain}") df = pd.read_csv(dataset_path) - - validator = TopicRelevance( - topic_config=topic_config, - prompt_schema_version=DEFAULT_CONFIG["prompt_schema_version"], - llm_callable=DEFAULT_CONFIG["llm_callable"], - ) + validator = backend["build"](topic_config) normalized_df = pd.DataFrame( { @@ -69,7 +81,6 @@ def run_evaluation(config: dict) -> None: "in_scope": df["scope"].apply(lambda x: 1 if x == "IN_SCOPE" else 0), } ) - normalized_df["y_true"] = (1 - normalized_df["in_scope"]).astype(int) with Profiler() as p: @@ -88,7 +99,6 @@ def run_evaluation(config: dict) -> None: ) metrics = compute_binary_metrics(normalized_df["y_true"], normalized_df["y_pred"]) - metrics["category_metrics"] = { str(cat): { "num_samples": int(len(g)), @@ -97,30 +107,27 @@ def run_evaluation(config: dict) -> None: for cat, g in normalized_df.groupby("category", dropna=False) } - OUT_DIR.mkdir(parents=True, exist_ok=True) - - write_csv(normalized_df, OUT_DIR / f"{domain}-predictions.csv") - + out_dir.mkdir(parents=True, exist_ok=True) + write_csv(normalized_df, out_dir / f"{domain}-predictions.csv") write_json( build_evaluation_report( - guardrail="topic_relevance", + guardrail=backend["name"], num_samples=len(normalized_df), profiler=p, dataset=str(dataset_path), - llm_callable=DEFAULT_CONFIG["llm_callable"], - prompt_schema_version=DEFAULT_CONFIG["prompt_schema_version"], + **backend["report_extra"], metrics=metrics, ), - OUT_DIR / 
f"{domain}-metrics.json", + out_dir / f"{domain}-metrics.json", ) - print(f"Completed {domain} evaluation") + print(f"Completed {backend['name']} {domain} evaluation") def main() -> None: - """Iterate over all entries in EVALUATIONS and run each domain evaluation in sequence.""" - for config in EVALUATIONS: - run_evaluation(config) + for backend in BACKENDS: + for dataset in DATASETS: + run_evaluation(dataset, backend) if __name__ == "__main__": diff --git a/backend/app/schemas/guardrail_config.py b/backend/app/schemas/guardrail_config.py index 968c260..ab179ad 100644 --- a/backend/app/schemas/guardrail_config.py +++ b/backend/app/schemas/guardrail_config.py @@ -24,6 +24,9 @@ from app.core.validators.config.topic_relevance_safety_validator_config import ( TopicRelevanceSafetyValidatorConfig, ) +from app.core.validators.config.topic_relevance_openai_safety_validator_config import ( + TopicRelevanceOpenAISafetyValidatorConfig, +) from app.core.validators.config.llamaguard_7b_safety_validator_config import ( LlamaGuard7BSafetyValidatorConfig, ) @@ -45,6 +48,7 @@ NSFWTextSafetyValidatorConfig, ProfanityFreeSafetyValidatorConfig, TopicRelevanceSafetyValidatorConfig, + TopicRelevanceOpenAISafetyValidatorConfig, ], Field(discriminator="type"), ] diff --git a/backend/app/tests/test_llm_validators.py b/backend/app/tests/test_llm_validators.py index 5834843..58fb9f0 100644 --- a/backend/app/tests/test_llm_validators.py +++ b/backend/app/tests/test_llm_validators.py @@ -6,6 +6,9 @@ from app.core.validators.config.topic_relevance_safety_validator_config import ( TopicRelevanceSafetyValidatorConfig, ) +from app.core.validators.config.topic_relevance_openai_safety_validator_config import ( + TopicRelevanceOpenAISafetyValidatorConfig, +) from app.core.validators.config.llm_critic_safety_validator_config import ( LLMCriticSafetyValidatorConfig, ) @@ -61,6 +64,79 @@ def test_topic_relevance_blank_config_returns_fail_result(): assert "blank" in result.error_message 
+_SAMPLE_OPENAI_TOPIC_CONFIG = dict( + type="topic_relevance_openai", + configuration="Only answer about cooking.", + llm_callable="gpt-4o-mini", +) + +_TOPIC_RELEVANCE_OPENAI_SETTINGS_PATH = ( + "app.core.validators.config.topic_relevance_openai_safety_validator_config.settings" +) + + +def test_topic_relevance_openai_build_raises_when_openai_key_missing(): + config = TopicRelevanceOpenAISafetyValidatorConfig(**_SAMPLE_OPENAI_TOPIC_CONFIG) + + with patch(_TOPIC_RELEVANCE_OPENAI_SETTINGS_PATH) as mock_settings: + mock_settings.OPENAI_API_KEY = None + + with pytest.raises(ValueError) as exc: + config.build() + + assert "OPENAI_API_KEY" in str(exc.value) + assert "not configured" in str(exc.value) + + +def test_topic_relevance_openai_build_proceeds_when_openai_key_present(): + config = TopicRelevanceOpenAISafetyValidatorConfig(**_SAMPLE_OPENAI_TOPIC_CONFIG) + + with patch(_TOPIC_RELEVANCE_OPENAI_SETTINGS_PATH) as mock_settings, patch( + "app.core.validators.config.topic_relevance_openai_safety_validator_config.TopicRelevanceOpenAI" + ) as mock_validator: + mock_settings.OPENAI_API_KEY = "sk-test-key" + config.build() + + mock_validator.assert_called_once() + + +def test_topic_relevance_openai_blank_config_returns_fail_result(): + config = TopicRelevanceOpenAISafetyValidatorConfig( + **{**_SAMPLE_OPENAI_TOPIC_CONFIG, "configuration": None} + ) + + with patch(_TOPIC_RELEVANCE_OPENAI_SETTINGS_PATH) as mock_settings, patch( + "app.core.validators.llm_utils.get_supported_openai_params", + return_value=[], + ): + mock_settings.OPENAI_API_KEY = "sk-test-key" + validator = config.build() + + result = validator._validate("some input") + assert isinstance(result, FailResult) + assert "blank" in result.error_message + + +def test_topic_relevance_openai_default_threshold_is_2(): + config = TopicRelevanceOpenAISafetyValidatorConfig(**_SAMPLE_OPENAI_TOPIC_CONFIG) + assert config.threshold == 2 + + +def test_topic_relevance_openai_custom_threshold_forwarded_to_validator(): + config 
= TopicRelevanceOpenAISafetyValidatorConfig( + **{**_SAMPLE_OPENAI_TOPIC_CONFIG, "threshold": 3} + ) + + with patch(_TOPIC_RELEVANCE_OPENAI_SETTINGS_PATH) as mock_settings, patch( + "app.core.validators.config.topic_relevance_openai_safety_validator_config.TopicRelevanceOpenAI" + ) as mock_validator: + mock_settings.OPENAI_API_KEY = "sk-test-key" + config.build() + + call_kwargs = mock_validator.call_args[1] + assert call_kwargs["threshold"] == 3 + + _SAMPLE_CONFIG = dict( type="llm_critic", metrics={ diff --git a/backend/app/tests/test_validate_with_guard.py b/backend/app/tests/test_validate_with_guard.py index 2956512..9bdf4f1 100644 --- a/backend/app/tests/test_validate_with_guard.py +++ b/backend/app/tests/test_validate_with_guard.py @@ -270,6 +270,77 @@ def test_resolve_validator_configs_uses_inline_topic_relevance_without_lookup(): mock_get.assert_not_called() +def test_resolve_validator_configs_topic_relevance_openai_from_config_id(): + topic_relevance_id = str(uuid4()) + payload = GuardrailRequest( + request_id=str(uuid4()), + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + input="test", + validators=[ + { + "type": "topic_relevance_openai", + "topic_relevance_config_id": topic_relevance_id, + } + ], + ) + mock_session = MagicMock() + + with patch("app.api.routes.guardrails.topic_relevance_crud.get") as mock_get: + mock_get.return_value = MagicMock( + configuration="Healthcare topic scope text", + ) + _resolve_validator_configs(payload, mock_session) + + validator = payload.validators[0] + assert validator.configuration == "Healthcare topic scope text" + mock_get.assert_called_once_with( + session=mock_session, + id=validator.topic_relevance_config_id, + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + ) + + +def test_resolve_validator_configs_skips_topic_relevance_openai_lookup_when_no_config_id(): + payload = GuardrailRequest( + request_id=str(uuid4()), + 
organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + input="test", + validators=[{"type": "topic_relevance_openai"}], + ) + mock_session = MagicMock() + + with patch("app.api.routes.guardrails.topic_relevance_crud.get") as mock_get: + _resolve_validator_configs(payload, mock_session) + + mock_get.assert_not_called() + + +def test_resolve_validator_configs_uses_inline_topic_relevance_openai_without_lookup(): + payload = GuardrailRequest( + request_id=str(uuid4()), + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + input="test", + validators=[ + { + "type": "topic_relevance_openai", + "configuration": "inline openai config", + } + ], + ) + mock_session = MagicMock() + + with patch("app.api.routes.guardrails.topic_relevance_crud.get") as mock_get: + _resolve_validator_configs(payload, mock_session) + + validator = payload.validators[0] + assert validator.configuration == "inline openai config" + mock_get.assert_not_called() + + def _build_mock_guard_with_fail_result(validator_name: str, error_message: str): mock_log = MagicMock() mock_log.validator_name = validator_name diff --git a/backend/app/tests/validators/test_topic_relevance_openai.py b/backend/app/tests/validators/test_topic_relevance_openai.py new file mode 100644 index 0000000..54ba6d7 --- /dev/null +++ b/backend/app/tests/validators/test_topic_relevance_openai.py @@ -0,0 +1,302 @@ +from unittest.mock import MagicMock, patch + +import pytest +from guardrails.validators import FailResult, PassResult + +from app.core.validators.topic_relevance_openai import TopicRelevanceOpenAI + +TOPIC_CONFIG = "Only answer questions about cooking and recipes." 
+ + +def _make_llm_response(json_text: str) -> MagicMock: + choice = MagicMock() + choice.message.content = json_text + result = MagicMock() + result.choices = [choice] + return result + + +@pytest.fixture +def validator(): + with patch( + "app.core.validators.llm_utils.get_supported_openai_params", + return_value=["response_format"], + ): + return TopicRelevanceOpenAI(system_prompt=TOPIC_CONFIG) + + +# --------------------------------------------------------------------------- +# PassResult — score >= threshold (2) +# --------------------------------------------------------------------------- + + +def test_passes_when_score_is_3(validator): + with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: + mock_llm.return_value = _make_llm_response('{"scope_violation": 3}') + result = validator._validate("How do I make pasta?") + + assert isinstance(result, PassResult) + assert result.metadata["scope_score"] == 3 + + +def test_passes_when_score_equals_threshold(validator): + with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: + mock_llm.return_value = _make_llm_response('{"scope_violation": 2}') + result = validator._validate("What is cooking roughly about?") + + assert isinstance(result, PassResult) + assert result.metadata["scope_score"] == 2 + + +# --------------------------------------------------------------------------- +# FailResult — score < threshold (1) +# --------------------------------------------------------------------------- + + +def test_fails_when_score_is_1(validator): + with patch("app.core.validators.topic_relevance_openai.completion") as mock_llm: + mock_llm.return_value = _make_llm_response('{"scope_violation": 1}') + result = validator._validate("What is the latest cricket score?") + + assert isinstance(result, FailResult) + assert "outside the allowed topic scope" in result.error_message + assert result.metadata["scope_score"] == 1 + + +# 
# ---------------------------------------------------------------------------
# Custom threshold
# ---------------------------------------------------------------------------

# Dotted paths of the collaborators that the tests below replace with mocks.
_SUPPORTED_PARAMS_TARGET = (
    "app.core.validators.llm_utils.get_supported_openai_params"
)
_COMPLETION_TARGET = "app.core.validators.topic_relevance_openai.completion"


def _build_validator(*, stub_supported_params=None, **validator_kwargs):
    """Construct a ``TopicRelevanceOpenAI`` with the supported-params lookup mocked.

    The mock is active only while the constructor runs; it is removed again
    before the instance is handed back to the test.

    Args:
        stub_supported_params: Value the mocked ``get_supported_openai_params``
            returns. ``None`` (the default) means an empty list.
        **validator_kwargs: Forwarded unchanged to ``TopicRelevanceOpenAI``.
    """
    stubbed = [] if stub_supported_params is None else stub_supported_params
    with patch(_SUPPORTED_PARAMS_TARGET, return_value=stubbed):
        return TopicRelevanceOpenAI(**validator_kwargs)


def _validate_with_llm_reply(target, raw_reply, value):
    """Return ``target._validate(value)`` while ``completion`` yields *raw_reply*.

    ``raw_reply`` is wrapped with ``_make_llm_response`` before being installed
    as the mock's return value.
    """
    with patch(_COMPLETION_TARGET) as completion_mock:
        completion_mock.return_value = _make_llm_response(raw_reply)
        return target._validate(value)


def test_custom_threshold_of_3_fails_on_score_2():
    """threshold=3 with a mocked score of 2 yields a FailResult carrying the score."""
    strict = _build_validator(system_prompt=TOPIC_CONFIG, threshold=3)

    outcome = _validate_with_llm_reply(
        strict, '{"scope_violation": 2}', "Something vaguely food related"
    )

    assert isinstance(outcome, FailResult)
    assert outcome.metadata["scope_score"] == 2


def test_custom_threshold_of_1_passes_on_score_1():
    """threshold=1 with a mocked score of 1 yields a PassResult."""
    lenient = _build_validator(system_prompt=TOPIC_CONFIG, threshold=1)

    outcome = _validate_with_llm_reply(
        lenient, '{"scope_violation": 1}', "Cricket scores"
    )

    assert isinstance(outcome, PassResult)


# ---------------------------------------------------------------------------
# Guard inputs
# ---------------------------------------------------------------------------


def test_fails_when_value_is_empty(validator):
    """An empty input string is rejected with an "Empty message" error."""
    outcome = validator._validate("")

    assert isinstance(outcome, FailResult)
    assert "Empty message" in outcome.error_message


def test_fails_when_value_is_whitespace(validator):
    """A whitespace-only input is rejected with an "Empty message" error."""
    outcome = validator._validate(" ")

    assert isinstance(outcome, FailResult)
    assert "Empty message" in outcome.error_message


def test_fails_when_system_prompt_is_blank():
    """A validator built with an empty system prompt fails, mentioning "blank"."""
    promptless = _build_validator(system_prompt="")

    outcome = promptless._validate("Some input")

    assert isinstance(outcome, FailResult)
    assert "blank" in outcome.error_message


def test_fails_when_system_prompt_is_whitespace_only():
    """A validator built with a whitespace-only system prompt fails, mentioning "blank"."""
    promptless = _build_validator(system_prompt=" ")

    outcome = promptless._validate("Some input")

    assert isinstance(outcome, FailResult)
    assert "blank" in outcome.error_message


# ---------------------------------------------------------------------------
# LLM error handling
# ---------------------------------------------------------------------------


def test_fails_gracefully_when_llm_raises(validator):
    """An exception raised by ``completion`` becomes a FailResult quoting the error."""
    boom = Exception("network timeout")

    with patch(_COMPLETION_TARGET, side_effect=boom):
        outcome = validator._validate("How do I bake bread?")

    assert isinstance(outcome, FailResult)
    assert "LLM call failed" in outcome.error_message
    assert "network timeout" in outcome.error_message


def test_fails_gracefully_when_response_is_not_json(validator):
    """A plain-text reply with no JSON is reported as unparseable."""
    outcome = _validate_with_llm_reply(
        validator, "Sure, this looks great!", "How do I bake bread?"
    )

    assert isinstance(outcome, FailResult)
    assert "unparseable" in outcome.error_message


def test_fails_gracefully_when_score_key_is_missing(validator):
    """JSON lacking the ``scope_violation`` key is reported as unparseable."""
    outcome = _validate_with_llm_reply(
        validator, '{"result": "yes"}', "How do I bake bread?"
    )

    assert isinstance(outcome, FailResult)
    assert "unparseable" in outcome.error_message


def test_fails_gracefully_when_score_is_out_of_range(validator):
    """A ``scope_violation`` of 5 is reported as unparseable."""
    outcome = _validate_with_llm_reply(
        validator, '{"scope_violation": 5}', "How do I bake bread?"
    )

    assert isinstance(outcome, FailResult)
    assert "unparseable" in outcome.error_message


def test_fails_gracefully_when_score_is_a_string(validator):
    """A string-valued ``scope_violation`` is reported as unparseable."""
    outcome = _validate_with_llm_reply(
        validator, '{"scope_violation": "high"}', "How do I bake bread?"
    )

    assert isinstance(outcome, FailResult)
    assert "unparseable" in outcome.error_message


def test_passes_when_response_wrapped_in_markdown_fence(validator):
    """A score of 3 inside a fenced ``json`` block still passes and is recorded."""
    fenced_reply = '```json\n{"scope_violation": 3}\n```'

    outcome = _validate_with_llm_reply(
        validator, fenced_reply, "How do I make pasta?"
    )

    assert isinstance(outcome, PassResult)
    assert outcome.metadata["scope_score"] == 3


def test_passes_when_response_has_surrounding_prose(validator):
    """A score of 2 preceded by chatty prose still passes and is recorded."""
    chatty_reply = 'Sure! Here is my evaluation: {"scope_violation": 2}'

    outcome = _validate_with_llm_reply(
        validator, chatty_reply, "Something vaguely food related"
    )

    assert isinstance(outcome, PassResult)
    assert outcome.metadata["scope_score"] == 2


def test_fails_when_score_is_boolean(validator):
    """A JSON ``true`` for ``scope_violation`` is reported as unparseable."""
    outcome = _validate_with_llm_reply(
        validator, '{"scope_violation": true}', "How do I bake bread?"
    )

    assert isinstance(outcome, FailResult)
    assert "unparseable" in outcome.error_message


# ---------------------------------------------------------------------------
# response_format forwarding
# ---------------------------------------------------------------------------


def test_response_format_passed_when_supported():
    """``response_format`` is sent to ``completion`` when listed as supported."""
    topic_validator = _build_validator(
        stub_supported_params=["response_format"],
        system_prompt=TOPIC_CONFIG,
    )

    with patch(_COMPLETION_TARGET) as completion_mock:
        completion_mock.return_value = _make_llm_response('{"scope_violation": 3}')
        topic_validator._validate("How do I make pasta?")

    _positional, forwarded = completion_mock.call_args
    assert forwarded.get("response_format") == {"type": "json_object"}


def test_response_format_omitted_when_not_supported():
    """``response_format`` is not sent when the supported-params list is empty."""
    topic_validator = _build_validator(system_prompt=TOPIC_CONFIG)

    with patch(_COMPLETION_TARGET) as completion_mock:
        completion_mock.return_value = _make_llm_response('{"scope_violation": 3}')
        topic_validator._validate("How do I make pasta?")

    _positional, forwarded = completion_mock.call_args
    assert "response_format" not in forwarded


def test_response_format_omitted_when_litellm_check_fails():
    """``response_format`` is not sent when the supported-params lookup raises."""
    # The helper only models a successful lookup, so the failing lookup is
    # patched inline here.
    lookup_failure = Exception("litellm unavailable")
    with patch(_SUPPORTED_PARAMS_TARGET, side_effect=lookup_failure):
        topic_validator = TopicRelevanceOpenAI(system_prompt=TOPIC_CONFIG)

    with patch(_COMPLETION_TARGET) as completion_mock:
        completion_mock.return_value = _make_llm_response('{"scope_violation": 3}')
        topic_validator._validate("How do I make pasta?")

    _positional, forwarded = completion_mock.call_args
    assert "response_format" not in forwarded


# ---------------------------------------------------------------------------
# Prompt template content
# ---------------------------------------------------------------------------


def test_system_prompt_contains_topic_config():
    """The rendered system prompt embeds the topic configuration text."""
    topic_validator = _build_validator(system_prompt=TOPIC_CONFIG)

    assert TOPIC_CONFIG in topic_validator._system_prompt


def test_system_prompt_contains_json_instruction():
    """The rendered system prompt mentions ``scope_violation`` and ``JSON``."""
    rendered_prompt = _build_validator(system_prompt=TOPIC_CONFIG)._system_prompt

    assert "scope_violation" in rendered_prompt
    assert "JSON" in rendered_prompt