Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions backend/app/api/routes/guardrails.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from uuid import UUID
import uuid
from uuid import UUID

from fastapi import APIRouter
from guardrails.guard import Guard
Expand All @@ -14,21 +14,24 @@
REPHRASE_ON_FAIL_PREFIX,
)
from app.core.enum import ValidatorType
from app.core.guardrail_controller import build_guard, get_validator_config_models
from app.core.exception_handlers import _safe_error_message
from app.core.guardrail_controller import build_guard, get_validator_config_models
from app.core.validators.config.ban_list_safety_validator_config import (
BanListSafetyValidatorConfig,
)
from app.crud.ban_list import ban_list_crud
from app.crud.topic_relevance import topic_relevance_crud
from app.crud.request_log import RequestLogCrud
from app.crud.validator_log import ValidatorLogCrud
from app.core.validators.config.topic_relevance_openai_safety_validator_config import (
TopicRelevanceOpenAISafetyValidatorConfig,
)
from app.core.validators.config.topic_relevance_safety_validator_config import (
TopicRelevanceSafetyValidatorConfig,
)
from app.schemas.guardrail_config import GuardrailRequest, GuardrailResponse
from app.crud.ban_list import ban_list_crud
from app.crud.request_log import RequestLogCrud
from app.crud.topic_relevance import topic_relevance_crud
from app.crud.validator_log import ValidatorLogCrud
from app.models.logging.request_log import RequestLogUpdate, RequestStatus
from app.models.logging.validator_log import ValidatorLog, ValidatorOutcome
from app.schemas.guardrail_config import GuardrailRequest, GuardrailResponse
from app.utils import APIResponse, load_description

router = APIRouter(prefix="/guardrails", tags=["guardrails"])
Expand Down Expand Up @@ -103,6 +106,7 @@ def _resolve_validator_configs(payload: GuardrailRequest, session: Session) -> N
Resolves config-backed references for all validators in-place before guard execution:
- BanList: fetches banned_words from the stored BanList when not provided inline.
- TopicRelevance: fetches configuration and prompt_schema_version from stored config.
- TopicRelevanceOpenAI: fetches configuration from stored config.
"""
for validator in payload.validators:
if isinstance(validator, BanListSafetyValidatorConfig):
Expand All @@ -115,7 +119,13 @@ def _resolve_validator_configs(payload: GuardrailRequest, session: Session) -> N
)
validator.banned_words = ban_list.banned_words

elif isinstance(validator, TopicRelevanceSafetyValidatorConfig):
elif isinstance(
validator,
(
TopicRelevanceSafetyValidatorConfig,
TopicRelevanceOpenAISafetyValidatorConfig,
),
):
if validator.topic_relevance_config_id is not None:
config = topic_relevance_crud.get(
session=session,
Expand All @@ -124,7 +134,9 @@ def _resolve_validator_configs(payload: GuardrailRequest, session: Session) -> N
project_id=payload.project_id,
)
validator.configuration = config.configuration
validator.prompt_schema_version = config.prompt_schema_version
# Only the LLMCritic-backed variant carries a prompt schema version.
if isinstance(validator, TopicRelevanceSafetyValidatorConfig):
validator.prompt_schema_version = config.prompt_schema_version


def _validate_with_guard(
Expand Down
2 changes: 2 additions & 0 deletions backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class Settings(BaseSettings):
KAAPI_AUTH_TIMEOUT: int
CORE_DIR: ClassVar[Path] = Path(__file__).resolve().parent
OPENAI_API_KEY: str | None = None
DEFAULT_LLM_CALLABLE: str = "gpt-4o-mini"
TOPIC_RELEVANCE_OPENAI_THRESHOLD: int = 2

SLUR_LIST_FILENAME: ClassVar[str] = "curated_slurlist_hi_en.csv"

Expand Down
4 changes: 4 additions & 0 deletions backend/app/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
f"{LLM_CRITIC_ERROR_MESSAGE} Please rephrase without unsafe content."
)

# Topic relevance validators (shared by the LLMCritic- and litellm-backed variants)
EMPTY_MESSAGE_ERROR = "Empty message."
TOPIC_OUT_OF_SCOPE_ERROR = "Input is outside the allowed topic scope."

VALIDATOR_CONFIG_SYSTEM_FIELDS = {
"organization_id",
"project_id",
Expand Down
1 change: 1 addition & 0 deletions backend/app/core/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class ValidatorType(Enum):
GenderAssumptionBias = "gender_assumption_bias"
BanList = "ban_list"
TopicRelevance = "topic_relevance"
TopicRelevanceOpenAI = "topic_relevance_openai"
LLMCritic = "llm_critic"
LlamaGuard7B = "llamaguard_7b"
ProfanityFree = "profanity_free"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Literal, Optional
from uuid import UUID

from pydantic import Field

from app.core.config import settings
from app.core.validators.config.base_validator_config import BaseValidatorConfig
from app.core.validators.topic_relevance_openai import TopicRelevanceOpenAI


class TopicRelevanceOpenAISafetyValidatorConfig(BaseValidatorConfig):
type: Literal["topic_relevance_openai"]
configuration: Optional[str] = None
llm_callable: str = settings.DEFAULT_LLM_CALLABLE
threshold: int = Field(
default=settings.TOPIC_RELEVANCE_OPENAI_THRESHOLD, ge=1, le=3
)
topic_relevance_config_id: Optional[UUID] = None

def build(self):
if not settings.OPENAI_API_KEY:
raise ValueError(
"OPENAI_API_KEY is not configured. "
"Topic relevance (OpenAI) validation requires an OpenAI API key."
)
return TopicRelevanceOpenAI(
system_prompt=self.configuration or "",
llm_callable=self.llm_callable,
threshold=self.threshold,
on_fail=self.resolve_on_fail(),
)
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from typing import Literal, Optional
from uuid import UUID

from pydantic import model_validator

from app.core.config import settings
from app.core.validators.topic_relevance import TopicRelevance
from app.core.validators.config.base_validator_config import BaseValidatorConfig
from app.core.validators.topic_relevance import TopicRelevance


class TopicRelevanceSafetyValidatorConfig(BaseValidatorConfig):
type: Literal["topic_relevance"]
configuration: Optional[str] = None
prompt_schema_version: Optional[int] = None
llm_callable: str = "gpt-4o-mini"
llm_callable: str = settings.DEFAULT_LLM_CALLABLE
topic_relevance_config_id: Optional[UUID] = None

def build(self):
Expand Down
15 changes: 15 additions & 0 deletions backend/app/core/validators/llm_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from litellm import get_supported_openai_params

# Passed to litellm/OpenAI to force a strict JSON object response.
JSON_OBJECT_RESPONSE_FORMAT = {"type": "json_object"}


def supports_response_format(model: str) -> bool:
"""Return True if the given model supports the OpenAI ``response_format`` param.

Falls back to False if litellm cannot resolve the model's capabilities.
"""
try:
return "response_format" in (get_supported_openai_params(model=model) or [])
except Exception:
return False
36 changes: 18 additions & 18 deletions backend/app/core/validators/topic_relevance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,22 @@
from pathlib import Path
from typing import Callable, Optional

from guardrails.hub import LLMCritic
from guardrails import OnFailAction
from guardrails.hub import LLMCritic
from guardrails.validators import (
FailResult,
PassResult,
ValidationResult,
Validator,
register_validator,
ValidationResult,
)
from guardrails.validators import FailResult, PassResult

from app.core.config import settings
from app.core.constants import EMPTY_MESSAGE_ERROR, TOPIC_OUT_OF_SCOPE_ERROR
from app.core.validators.llm_utils import (
JSON_OBJECT_RESPONSE_FORMAT,
supports_response_format,
)

# This should be present in all prompt templates to indicate where the topic configuration will be inserted
_PROMPT_PLACEHOLDER = "{{TOPIC_CONFIGURATION}}"
Expand Down Expand Up @@ -62,7 +69,7 @@ def __init__(
self,
topic_config: str,
prompt_schema_version: int = 1,
llm_callable: str = "gpt-4o-mini",
llm_callable: str = settings.DEFAULT_LLM_CALLABLE,
on_fail: Optional[Callable] = OnFailAction.NOOP,
):
"""Build the LLMCritic with a scope_violation metric from the topic configuration."""
Expand All @@ -78,15 +85,6 @@ def __init__(
self._critic = None
return

try:
from litellm import get_supported_openai_params

supports_response_format = "response_format" in (
get_supported_openai_params(model=llm_callable) or []
)
except Exception:
supports_response_format = False

self._critic = LLMCritic(
metrics={
"scope_violation": {
Expand All @@ -101,19 +99,21 @@ def __init__(
llm_callable=llm_callable,
on_fail=on_fail,
**(
{"llm_kwargs": {"response_format": {"type": "json_object"}}}
if supports_response_format
{"llm_kwargs": {"response_format": JSON_OBJECT_RESPONSE_FORMAT}}
if supports_response_format(llm_callable)
else {}
),
)

def _validate(self, value: str, metadata: dict = None) -> ValidationResult:
def _validate(
self, value: str, metadata: Optional[dict] = None
) -> ValidationResult:
"""Run the LLMCritic and return a PassResult or FailResult with the scope score."""
if self._invalid_config_reason:
return FailResult(error_message=self._invalid_config_reason)

if not value or not value.strip():
return FailResult(error_message="Empty message.")
return FailResult(error_message=EMPTY_MESSAGE_ERROR)

try:
result = self._critic.validate(value, metadata)
Expand All @@ -127,7 +127,7 @@ def _validate(self, value: str, metadata: dict = None) -> ValidationResult:

if isinstance(result, FailResult):
return FailResult(
error_message="Input is outside the allowed topic scope.",
error_message=TOPIC_OUT_OF_SCOPE_ERROR,
metadata={"scope_score": score},
)

Expand Down
123 changes: 123 additions & 0 deletions backend/app/core/validators/topic_relevance_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from __future__ import annotations

import json
import re
from typing import Callable, Optional

from guardrails import OnFailAction
from guardrails.validators import (
FailResult,
PassResult,
ValidationResult,
Validator,
register_validator,
)
from litellm import completion

from app.core.config import settings
from app.core.constants import EMPTY_MESSAGE_ERROR, TOPIC_OUT_OF_SCOPE_ERROR
from app.core.validators.llm_utils import (
JSON_OBJECT_RESPONSE_FORMAT,
supports_response_format,
)

# Valid scope scores returned by the model; the highest means "clearly in scope".
_VALID_SCORES = (1, 2, 3)
# Cap the response: a single ``{"scope_violation": <score>}`` object is tiny.
_MAX_TOKENS = 50

_SCORING_INSTRUCTIONS = (
"\n\nScore using:\n"
f"{_VALID_SCORES[2]} = clearly within scope (directly matches a topic description)\n"
f"{_VALID_SCORES[1]} = partially related (tangentially related or implicitly within scope)\n"
f"{_VALID_SCORES[0]} = clearly outside scope (no relation to any listed topic)\n"
"\nRespond ONLY with a JSON object in this exact format: "
'{"scope_violation": <score>} where <score> is the integer '
f"{_VALID_SCORES[0]}, {_VALID_SCORES[1]}, or {_VALID_SCORES[2]}."
)


@register_validator(name="topic-relevance-openai", data_type="string")
class TopicRelevanceOpenAI(Validator):
"""
Validates whether a user message is within the defined topic scope
using a direct OpenAI/litellm call.

The caller supplies the full system prompt. The validator appends
hardcoded scoring and response-format instructions.

Scores 1–3 where 3 = clearly in scope, 2 = partially related,
1 = outside scope. Passes when score >= threshold (default 2).
"""

def __init__(
self,
system_prompt: str,
llm_callable: str = settings.DEFAULT_LLM_CALLABLE,
threshold: int = settings.TOPIC_RELEVANCE_OPENAI_THRESHOLD,
on_fail: Optional[Callable] = OnFailAction.NOOP,
):
super().__init__(on_fail=on_fail)

self.llm_callable = llm_callable
self.threshold = threshold
self._invalid_config_reason: Optional[str] = None
self._system_prompt: Optional[str] = None
self._supports_response_format: bool = False

if not system_prompt or not system_prompt.strip():
self._invalid_config_reason = "system_prompt is blank or missing"
return

self._system_prompt = system_prompt.strip() + _SCORING_INSTRUCTIONS
self._supports_response_format = supports_response_format(llm_callable)

def _validate(
self, value: str, metadata: Optional[dict] = None
) -> ValidationResult:
if self._invalid_config_reason:
return FailResult(error_message=self._invalid_config_reason)

if not value or not value.strip():
return FailResult(error_message=EMPTY_MESSAGE_ERROR)

try:
kwargs = {
"model": self.llm_callable,
"messages": [
{"role": "system", "content": self._system_prompt},
{"role": "user", "content": value},
],
"max_tokens": _MAX_TOKENS,
}
if self._supports_response_format:
kwargs["response_format"] = JSON_OBJECT_RESPONSE_FORMAT

response = completion(**kwargs)
content = response.choices[0].message.content.strip()
except Exception as e:
return FailResult(error_message=f"LLM call failed: {e}")

try:
text = re.sub(r"```(?:json)?\s*|\s*```", "", content).strip()
match = re.search(r"\{[^{}]*\}", text)
if not match:
raise ValueError("no JSON object found in response")
data = json.loads(match.group())
score = data.get("scope_violation")
# `type(score) is not int` (not isinstance) deliberately rejects bool,
# which is an int subclass, so `true`/`false` are treated as invalid.
if type(score) is not int or score not in _VALID_SCORES:
raise ValueError(f"unexpected score value: {score!r}")
except Exception as e:
return FailResult(
error_message=f"LLM returned an unparseable response: {e}. Raw: {content!r}"
)
Comment thread
rkritika1508 marked this conversation as resolved.

if score >= self.threshold:
return PassResult(value=value, metadata={"scope_score": score})

return FailResult(
error_message=TOPIC_OUT_OF_SCOPE_ERROR,
metadata={"scope_score": score},
)
5 changes: 5 additions & 0 deletions backend/app/core/validators/validators.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@
"version": "0.1.0",
"source": "local"
},
{
"type": "topic_relevance_openai",
"version": "0.1.0",
"source": "local"
},
{
"type": "llamaguard_7b",
"version": "0.1.0",
Expand Down
Loading