diff --git a/.env.example b/.env.example index bffbdd9..2e6d3eb 100644 --- a/.env.example +++ b/.env.example @@ -21,12 +21,10 @@ SENTRY_DSN= DOCKER_IMAGE_BACKEND=kaapi-guardrails-backend -# Callback Timeouts (in seconds) -CALLBACK_CONNECT_TIMEOUT=3 -CALLBACK_READ_TIMEOUT=10 - # require as a env if you want to use doc transformation OPENAI_API_KEY="" GUARDRAILS_HUB_API_KEY="" # SHA-256 hex digest of your bearer token (64 lowercase hex chars) AUTH_TOKEN="" +KAAPI_AUTH_URL="" +KAAPI_AUTH_TIMEOUT=5 diff --git a/.env.test.example b/.env.test.example index b275aec..bf95f1a 100644 --- a/.env.test.example +++ b/.env.test.example @@ -21,12 +21,10 @@ SENTRY_DSN= DOCKER_IMAGE_BACKEND=kaapi-guardrails-backend -# Callback Timeouts (in seconds) -CALLBACK_CONNECT_TIMEOUT=3 -CALLBACK_READ_TIMEOUT=10 - # require as a env if you want to use doc transformation OPENAI_API_KEY="" GUARDRAILS_HUB_API_KEY="" # SHA-256 hex digest of your bearer token (64 lowercase hex chars) AUTH_TOKEN="" +KAAPI_AUTH_URL="" +KAAPI_AUTH_TIMEOUT=5 diff --git a/.github/workflows/continuous_integration.yml b/.github/workflows/continuous_integration.yml index 36c126e..8738d11 100644 --- a/.github/workflows/continuous_integration.yml +++ b/.github/workflows/continuous_integration.yml @@ -46,7 +46,7 @@ jobs: - name: Install uv uses: astral-sh/setup-uv@v7 with: - version: "0.4.15" + version: "0.7.2" enable-cache: true - name: Install dependencies diff --git a/backend/app/alembic/versions/001_added_request_log.py b/backend/app/alembic/versions/001_added_request_log.py index aec078f..e3a26d7 100644 --- a/backend/app/alembic/versions/001_added_request_log.py +++ b/backend/app/alembic/versions/001_added_request_log.py @@ -1,17 +1,17 @@ """Added request log Revision ID: 001 -Revises: +Revises: Create Date: 2026-01-07 09:42:54.128852 """ + from typing import Sequence, Union from alembic import op import sqlalchemy as sa import sqlmodel - # revision identifiers, used by Alembic. revision: str = "001" down_revision: str | None = None @@ -23,6 +23,8 @@ def upgrade() -> None: op.create_table( "request_log", sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("organization_id", sa.Integer(), nullable=False), + sa.Column("project_id", sa.Integer(), nullable=False), sa.Column("request_id", sa.Uuid(), nullable=False), sa.Column("response_id", sa.Uuid(), nullable=True), sa.Column( diff --git a/backend/app/alembic/versions/002_added_validator_log.py b/backend/app/alembic/versions/002_added_validator_log.py index a0227b4..da0a72b 100644 --- a/backend/app/alembic/versions/002_added_validator_log.py +++ b/backend/app/alembic/versions/002_added_validator_log.py @@ -5,13 +5,13 @@ Create Date: 2026-01-07 09:43:48.002351 """ + from typing import Sequence, Union from alembic import op import sqlalchemy as sa import sqlmodel - # revision identifiers, used by Alembic. revision: str = "002" down_revision: str = "001" @@ -23,6 +23,8 @@ def upgrade() -> None: op.create_table( "validator_log", sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("organization_id", sa.Integer(), nullable=False), + sa.Column("project_id", sa.Integer(), nullable=False), sa.Column("request_id", sa.Uuid(), nullable=False), sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("input", sqlmodel.sql.sqltypes.AutoString(), nullable=False), diff --git a/backend/app/alembic/versions/003_added_validator_config.py b/backend/app/alembic/versions/003_added_validator_config.py index 2d52246..6d35e03 100644 --- a/backend/app/alembic/versions/003_added_validator_config.py +++ b/backend/app/alembic/versions/003_added_validator_config.py @@ -5,6 +5,7 @@ Create Date: 2026-02-05 09:42:54.128852 """ + from typing import Sequence, Union from alembic import op diff --git a/backend/app/alembic/versions/004_added_log_indexes.py b/backend/app/alembic/versions/004_added_log_indexes.py index 2cafa6f..6683fe9 100644 --- a/backend/app/alembic/versions/004_added_log_indexes.py +++ b/backend/app/alembic/versions/004_added_log_indexes.py @@ -5,11 +5,11 @@ Create Date: 2026-02-11 10:45:00.000000 """ + from typing import Sequence, Union from alembic import op - # revision identifiers, used by Alembic. revision: str = "004" down_revision: str = "003" @@ -21,11 +21,19 @@ def upgrade() -> None: op.create_index("idx_request_log_request_id", "request_log", ["request_id"]) op.create_index("idx_request_log_status", "request_log", ["status"]) op.create_index("idx_request_log_inserted_at", "request_log", ["inserted_at"]) + op.create_index( + "idx_request_log_organization_id", "request_log", ["organization_id"] + ) + op.create_index("idx_request_log_project_id", "request_log", ["project_id"]) op.create_index("idx_validator_log_request_id", "validator_log", ["request_id"]) op.create_index("idx_validator_log_inserted_at", "validator_log", ["inserted_at"]) op.create_index("idx_validator_log_outcome", "validator_log", ["outcome"]) op.create_index("idx_validator_log_name", "validator_log", ["name"]) + op.create_index( + "idx_validator_log_organization_id", "validator_log", ["organization_id"] + ) + op.create_index("idx_validator_log_project_id", "validator_log", ["project_id"]) def downgrade() -> None: @@ -33,7 +41,11 @@ def downgrade() -> None: op.drop_index("idx_validator_log_request_id", table_name="validator_log") op.drop_index("idx_validator_log_outcome", table_name="validator_log") op.drop_index("idx_validator_log_name", table_name="validator_log") + op.drop_index("idx_validator_log_project_id", table_name="validator_log") + op.drop_index("idx_validator_log_organization_id", table_name="validator_log") op.drop_index("idx_request_log_inserted_at", table_name="request_log") op.drop_index("idx_request_log_status", table_name="request_log") op.drop_index("idx_request_log_request_id", table_name="request_log") + op.drop_index("idx_request_log_organization_id", table_name="request_log") + op.drop_index("idx_request_log_project_id", table_name="request_log") diff --git a/backend/app/alembic/versions/005_added_banlist_config.py b/backend/app/alembic/versions/005_added_banlist_config.py new file mode 100644 index 0000000..ee6632c --- /dev/null +++ b/backend/app/alembic/versions/005_added_banlist_config.py @@ -0,0 +1,62 @@ +"""Added ban_list table + +Revision ID: 005 +Revises: 004 +Create Date: 2026-02-05 09:42:54.128852 + +""" + +from typing import Sequence, Union + +from alembic import op +from sqlalchemy.dialects import postgresql +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "005" +down_revision = "004" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "ban_list", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=False), + sa.Column("organization_id", sa.Integer(), nullable=False), + sa.Column("project_id", sa.Integer(), nullable=False), + sa.Column("domain", sa.String(), nullable=False), + sa.Column("is_public", sa.Boolean(), nullable=False, server_default=sa.false()), + sa.Column( + "banned_words", + postgresql.ARRAY(sa.String(length=100)), + nullable=False, + server_default="{}", + ), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint( + "name", "organization_id", "project_id", name="uq_ban_list_name_org_project" + ), + sa.CheckConstraint( + "coalesce(array_length(banned_words, 1), 0) <= 1000", + name="ck_ban_list_banned_words_max_items", + ), + ) + + op.create_index("idx_ban_list_organization", "ban_list", ["organization_id"]) + op.create_index("idx_ban_list_project", "ban_list", ["project_id"]) + op.create_index("idx_ban_list_domain", "ban_list", ["domain"]) + op.create_index( + "idx_ban_list_is_public_true", + "ban_list", + ["is_public"], + postgresql_where=sa.text("is_public = true"), + ) + + +def downgrade() -> None: + op.drop_table("ban_list") diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index ec78089..8426bee 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -1,9 +1,12 @@ from collections.abc import Generator +from dataclasses import dataclass from typing import Annotated + import hashlib import secrets +import httpx -from fastapi import Depends, HTTPException, status, Security +from fastapi import Depends, Header, HTTPException, Security, status from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from sqlmodel import Session @@ -17,6 +20,9 @@ def get_db() -> Generator[Session, None, None]: SessionDep = Annotated[Session, Depends(get_db)] + + +# Static bearer token auth for internal routes. security = HTTPBearer(auto_error=False) @@ -24,27 +30,92 @@ def _hash_token(token: str) -> str: return hashlib.sha256(token.encode("utf-8")).hexdigest() +def _unauthorized(detail: str) -> HTTPException: + return HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=detail, + ) + + def verify_bearer_token( credentials: Annotated[ HTTPAuthorizationCredentials | None, Security(security), - ] -): + ], +) -> bool: if credentials is None: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Missing Authorization header", - ) + raise _unauthorized("Missing Authorization header") if not secrets.compare_digest( - _hash_token(credentials.credentials), settings.AUTH_TOKEN + _hash_token(credentials.credentials), + settings.AUTH_TOKEN, ): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid authorization token", - ) + raise _unauthorized("Invalid authorization token") return True AuthDep = Annotated[bool, Depends(verify_bearer_token)] + + +# Multitenant auth context resolved from X-API-KEY. +@dataclass +class TenantContext: + organization_id: int + project_id: int + + +def _fetch_tenant_from_backend(token: str) -> TenantContext: + if not settings.KAAPI_AUTH_URL: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="KAAPI_AUTH_URL is not configured", + ) + + try: + response = httpx.get( + f"{settings.KAAPI_AUTH_URL}/apikeys/verify", + headers={"X-API-KEY": f"ApiKey {token}"}, + timeout=settings.KAAPI_AUTH_TIMEOUT, + ) + except httpx.RequestError: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Auth service unavailable", + ) + + if response.status_code != 200: + raise _unauthorized("Invalid API key") + + data = response.json() + if not isinstance(data, dict) or data.get("success") is not True: + raise _unauthorized("Invalid API key") + + record = data.get("data") + if not isinstance(record, dict): + raise _unauthorized("Invalid API key") + + organization_id = record.get("organization_id") + project_id = record.get("project_id") + if not isinstance(organization_id, int) or not isinstance(project_id, int): + raise _unauthorized("Invalid API key") + + return TenantContext( + organization_id=organization_id, + project_id=project_id, + ) + + +def validate_multitenant_key( + x_api_key: Annotated[str | None, Header(alias="X-API-KEY")] = None, +) -> TenantContext: + if not x_api_key or not x_api_key.strip(): + raise _unauthorized("Missing X-API-KEY header") + + return _fetch_tenant_from_backend(x_api_key.strip()) + + +MultitenantAuthDep = Annotated[ + TenantContext, + Depends(validate_multitenant_key), +] diff --git a/backend/app/api/main.py b/backend/app/api/main.py index bf78ade..858fbb2 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -1,11 +1,12 @@ from fastapi import APIRouter -from app.api.routes import utils, guardrails, validator_configs +from app.api.routes import ban_lists, guardrails, validator_configs, utils api_router = APIRouter() -api_router.include_router(utils.router) +api_router.include_router(ban_lists.router) api_router.include_router(guardrails.router) api_router.include_router(validator_configs.router) +api_router.include_router(utils.router) # if settings.ENVIRONMENT == "local": # api_router.include_router(private.router) diff --git a/backend/app/api/routes/ban_lists.py b/backend/app/api/routes/ban_lists.py new file mode 100644 index 0000000..279963f --- /dev/null +++ b/backend/app/api/routes/ban_lists.py @@ -0,0 +1,88 @@ +from typing import Annotated, Optional +from uuid import UUID + +from fastapi import APIRouter, Query + +from app.api.deps import MultitenantAuthDep, SessionDep +from app.crud.ban_list import ban_list_crud +from app.schemas.ban_list import BanListCreate, BanListUpdate, BanListResponse +from app.utils import APIResponse + +router = APIRouter(prefix="/guardrails/ban_lists", tags=["Ban Lists"]) + + +@router.post("/", response_model=APIResponse[BanListResponse]) +def create_ban_list( + payload: BanListCreate, + session: SessionDep, + auth: MultitenantAuthDep, +): + ban_list = ban_list_crud.create( + session, payload, auth.organization_id, auth.project_id + ) + return APIResponse.success_response(data=ban_list) + + +@router.get("/", response_model=APIResponse[list[BanListResponse]]) +def list_ban_lists( + session: SessionDep, + auth: MultitenantAuthDep, + domain: Optional[str] = None, + offset: Annotated[int, Query(ge=0)] = 0, + limit: Annotated[int | None, Query(ge=1, le=100)] = None, +): + ban_lists = ban_list_crud.list( + session, + auth.organization_id, + auth.project_id, + domain, + offset=offset, + limit=limit, + ) + return APIResponse.success_response(data=ban_lists) + + +@router.get("/{id}", response_model=APIResponse[BanListResponse]) +def get_ban_list( + id: UUID, + session: SessionDep, + auth: MultitenantAuthDep, +): + obj = ban_list_crud.get(session, id, auth.organization_id, auth.project_id) + return APIResponse.success_response(data=obj) + + +@router.patch("/{id}", response_model=APIResponse[BanListResponse]) +def update_ban_list( + id: UUID, + payload: BanListUpdate, + session: SessionDep, + auth: MultitenantAuthDep, +): + ban_list = ban_list_crud.update( + session, + id=id, + organization_id=auth.organization_id, + project_id=auth.project_id, + data=payload, + ) + return APIResponse.success_response(data=ban_list) + + +@router.delete("/{id}", response_model=APIResponse[dict]) +def delete_ban_list( + id: UUID, + session: SessionDep, + auth: MultitenantAuthDep, +): + obj = ban_list_crud.get( + session, + id, + auth.organization_id, + auth.project_id, + require_owner=True, + ) + ban_list_crud.delete(session, obj) + return APIResponse.success_response( + data={"message": "Ban list deleted successfully"} + ) diff --git a/backend/app/api/routes/guardrails.py b/backend/app/api/routes/guardrails.py index 9e716db..4700a64 100644 --- a/backend/app/api/routes/guardrails.py +++ b/backend/app/api/routes/guardrails.py @@ -4,12 +4,17 @@ from fastapi import APIRouter from guardrails.guard import Guard from guardrails.validators import FailResult, PassResult +from sqlmodel import Session from app.api.deps import AuthDep, SessionDep -from app.core.constants import REPHRASE_ON_FAIL_PREFIX +from app.core.constants import BAN_LIST, REPHRASE_ON_FAIL_PREFIX from app.core.config import settings from app.core.guardrail_controller import build_guard, get_validator_config_models from app.core.exception_handlers import _safe_error_message +from app.core.validators.config.ban_list_safety_validator_config import ( + BanListSafetyValidatorConfig, +) +from app.crud.ban_list import ban_list_crud from app.crud.request_log import RequestLogCrud from app.crud.validator_log import ValidatorLogCrud from app.schemas.guardrail_config import GuardrailRequest, GuardrailResponse @@ -33,14 +38,13 @@ def run_guardrails( validator_log_crud = ValidatorLogCrud(session=session) try: - request_id = UUID(payload.request_id) + request_log = request_log_crud.create(payload) except ValueError: return APIResponse.failure_response(error="Invalid request_id") - request_log = request_log_crud.create(request_id, input_text=payload.input) + _resolve_ban_list_banned_words(payload, session) return _validate_with_guard( - payload.input, - payload.validators, + payload, request_log_crud, request_log.id, validator_log_crud, @@ -78,9 +82,25 @@ def list_validators(_: AuthDep): return {"validators": validators} +def _resolve_ban_list_banned_words(payload: GuardrailRequest, session: Session) -> None: + for validator in payload.validators: + if not isinstance(validator, BanListSafetyValidatorConfig): + continue + + if validator.type != BAN_LIST or validator.banned_words is not None: + continue + + ban_list = ban_list_crud.get( + session, + id=validator.ban_list_id, + organization_id=payload.organization_id, + project_id=payload.project_id, + ) + validator.banned_words = ban_list.banned_words + + def _validate_with_guard( - data: str, - validators: list, + payload: GuardrailRequest, request_log_crud: RequestLogCrud, request_log_id: UUID, validator_log_crud: ValidatorLogCrud, @@ -94,6 +114,8 @@ def _validate_with_guard( while still safely handling unexpected runtime errors. """ response_id = uuid.uuid4() + data = payload.input + validators = payload.validators guard: Guard | None = None def _finalize( @@ -125,7 +147,7 @@ def _finalize( if guard is not None: add_validator_logs( - guard, request_log_id, validator_log_crud, suppress_pass_logs + guard, request_log_id, validator_log_crud, payload, suppress_pass_logs ) rephrase_needed = validated_output is not None and validated_output.startswith( @@ -175,6 +197,7 @@ def add_validator_logs( guard: Guard, request_log_id: UUID, validator_log_crud: ValidatorLogCrud, + payload: GuardrailRequest, suppress_pass_logs: bool = False, ): history = getattr(guard, "history", None) @@ -202,6 +225,8 @@ def add_validator_logs( validator_log = ValidatorLog( request_id=request_log_id, + organization_id=payload.organization_id, + project_id=payload.project_id, name=log.validator_name, input=str(log.value_before_validation), output=log.value_after_validation, diff --git a/backend/app/api/routes/validator_configs.py b/backend/app/api/routes/validator_configs.py index 470dc89..ed34895 100644 --- a/backend/app/api/routes/validator_configs.py +++ b/backend/app/api/routes/validator_configs.py @@ -13,7 +13,6 @@ from app.crud.validator_config import validator_config_crud from app.utils import APIResponse - router = APIRouter( prefix="/guardrails/validators/configs", tags=["validator configs"], diff --git a/backend/app/core/config.py b/backend/app/core/config.py index aa36c8e..c73ff6e 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -41,6 +41,8 @@ class Settings(BaseSettings): POSTGRES_PASSWORD: str = "" POSTGRES_DB: str = "" GUARDRAILS_HUB_API_KEY: str | None = None + KAAPI_AUTH_URL: str = "" + KAAPI_AUTH_TIMEOUT: int CORE_DIR: ClassVar[Path] = Path(__file__).resolve().parent SLUR_LIST_FILENAME: ClassVar[str] = "curated_slurlist_hi_en.csv" diff --git a/backend/app/core/exception_handlers.py b/backend/app/core/exception_handlers.py index 4549549..dd5739c 100644 --- a/backend/app/core/exception_handlers.py +++ b/backend/app/core/exception_handlers.py @@ -1,6 +1,7 @@ from fastapi import FastAPI, Request, HTTPException -from fastapi.exceptions import RequestValidationError +from fastapi.exceptions import RequestValidationError, ResponseValidationError from fastapi.responses import JSONResponse +from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.status import ( HTTP_422_UNPROCESSABLE_ENTITY, HTTP_500_INTERNAL_SERVER_ERROR, @@ -55,10 +56,33 @@ def _safe_error_message(exc: Exception) -> str: return str(exc) or "An unexpected error occurred." +def _normalize_error_detail(detail: object) -> str | list: + if isinstance(detail, (str, list)): + return detail + if isinstance(detail, dict): + message = detail.get("message") + if isinstance(message, str): + return message + return str(detail) + return str(detail) + + +def _http_error_response(exc: StarletteHTTPException) -> JSONResponse: + return JSONResponse( + status_code=exc.status_code, + content=APIResponse.failure_response( + _normalize_error_detail(exc.detail) + ).model_dump(), + headers=exc.headers, + ) + + def register_exception_handlers(app: FastAPI): @app.exception_handler(RequestValidationError) async def validation_error_handler(request: Request, exc: RequestValidationError): formatted_message = _format_validation_errors(exc.errors()) + if not formatted_message: + formatted_message = "Invalid request payload" return JSONResponse( status_code=HTTP_422_UNPROCESSABLE_ENTITY, content=APIResponse.failure_response(error=formatted_message).model_dump(), @@ -66,9 +90,21 @@ async def validation_error_handler(request: Request, exc: RequestValidationError @app.exception_handler(HTTPException) async def http_exception_handler(request: Request, exc: HTTPException): + return _http_error_response(exc) + + @app.exception_handler(StarletteHTTPException) + async def starlette_http_exception_handler( + request: Request, exc: StarletteHTTPException + ): + return _http_error_response(exc) + + @app.exception_handler(ResponseValidationError) + async def response_validation_error_handler( + request: Request, exc: ResponseValidationError + ): return JSONResponse( - status_code=exc.status_code, - content=APIResponse.failure_response(exc.detail).model_dump(), + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + content=APIResponse.failure_response(_safe_error_message(exc)).model_dump(), ) @app.exception_handler(Exception) diff --git a/backend/app/core/validators/config/ban_list_safety_validator_config.py b/backend/app/core/validators/config/ban_list_safety_validator_config.py index 9db2981..6279c3c 100644 --- a/backend/app/core/validators/config/ban_list_safety_validator_config.py +++ b/backend/app/core/validators/config/ban_list_safety_validator_config.py @@ -1,16 +1,25 @@ -from typing import List, Literal +from typing import List, Literal, Optional +from uuid import UUID from guardrails.hub import BanList +from pydantic import model_validator from app.core.validators.config.base_validator_config import BaseValidatorConfig class BanListSafetyValidatorConfig(BaseValidatorConfig): type: Literal["ban_list"] - banned_words: List[str] # list of banned words to be redacted + banned_words: Optional[List[str]] = None # list of banned words to be redacted + ban_list_id: Optional[UUID] = None + + @model_validator(mode="after") + def validate_ban_list_source(self): + if self.banned_words is None and self.ban_list_id is None: + raise ValueError("Either banned_words or ban_list_id must be provided.") + return self def build(self): return BanList( - banned_words=self.banned_words, + banned_words=self.banned_words or [], on_fail=self.resolve_on_fail(), ) diff --git a/backend/app/core/validators/config/base_validator_config.py b/backend/app/core/validators/config/base_validator_config.py index 61acd4f..c615092 100644 --- a/backend/app/core/validators/config/base_validator_config.py +++ b/backend/app/core/validators/config/base_validator_config.py @@ -6,7 +6,6 @@ from app.core.enum import GuardrailOnFail from app.core.on_fail_actions import rephrase_query_on_fail - _ON_FAIL_MAP = { GuardrailOnFail.Fix: OnFailAction.FIX, GuardrailOnFail.Exception: OnFailAction.EXCEPTION, diff --git a/backend/app/crud/ban_list.py b/backend/app/crud/ban_list.py new file mode 100644 index 0000000..819cfe7 --- /dev/null +++ b/backend/app/crud/ban_list.py @@ -0,0 +1,149 @@ +from typing import List, Optional +from uuid import UUID + +from fastapi import HTTPException +from sqlalchemy.exc import IntegrityError +from sqlmodel import Session, select + +from app.models.config.ban_list import BanList +from app.schemas.ban_list import BanListCreate, BanListUpdate +from app.utils import now + + +class BanListCrud: + def create( + self, + session: Session, + data: BanListCreate, + organization_id: int, + project_id: int, + ) -> BanList: + ban_list = BanList( + **data.model_dump(), + organization_id=organization_id, + project_id=project_id, + ) + session.add(ban_list) + + try: + session.commit() + except IntegrityError: + session.rollback() + raise HTTPException( + 400, "Ban list already exists for the given configuration" + ) + except Exception: + session.rollback() + raise + + session.refresh(ban_list) + return ban_list + + def get( + self, + session: Session, + id: UUID, + organization_id: int, + project_id: int, + require_owner: bool = False, + ) -> BanList: + ban_list = session.get(BanList, id) + + if ban_list is None: + raise HTTPException(status_code=404, detail="Ban list not found") + + if require_owner or not ban_list.is_public: + self.check_owner(ban_list, organization_id, project_id) + + return ban_list + + def list( + self, + session: Session, + organization_id: int, + project_id: int, + domain: Optional[str] = None, + offset: int = 0, + limit: int | None = None, + ) -> List[BanList]: + query = select(BanList).where( + ( + (BanList.organization_id == organization_id) + & (BanList.project_id == project_id) + ) + | (BanList.is_public == True) + ) + + if domain: + query = query.where(BanList.domain == domain) + + query = query.order_by(BanList.created_at.desc(), BanList.id.desc()) + + if offset: + query = query.offset(offset) + if limit is not None: + query = query.limit(limit) + + return list(session.exec(query)) + + def update( + self, + session: Session, + id: UUID, + organization_id: int, + project_id: int, + data: BanListUpdate, + ) -> BanList: + ban_list = self.get( + session, + id, + organization_id, + project_id, + require_owner=True, + ) + update_data = data.model_dump(exclude_unset=True) + + for field_name, field_value in update_data.items(): + setattr(ban_list, field_name, field_value) + + ban_list.updated_at = now() + + session.add(ban_list) + try: + session.commit() + except IntegrityError: + session.rollback() + raise HTTPException( + 400, "Ban list already exists for the given configuration" + ) + except Exception: + session.rollback() + raise + + session.refresh(ban_list) + return ban_list + + def delete(self, session: Session, ban_list: BanList): + session.delete(ban_list) + try: + session.commit() + except Exception: + session.rollback() + raise + + def check_owner( + self, ban_list: BanList, organization_id: int, project_id: int + ) -> None: + is_owner = ( + ban_list.organization_id == organization_id + and ban_list.project_id == project_id + ) + + if not is_owner: + raise HTTPException( + status_code=403, + detail="You do not have permission to access this resource.", + ) + + +ban_list_crud = BanListCrud() diff --git a/backend/app/crud/request_log.py b/backend/app/crud/request_log.py index 95d3a6b..2a283cb 100644 --- a/backend/app/crud/request_log.py +++ b/backend/app/crud/request_log.py @@ -1,8 +1,9 @@ -from uuid import UUID, uuid4 +from uuid import UUID from sqlmodel import Session from app.models.logging.request_log import RequestLog, RequestLogUpdate, RequestStatus +from app.schemas.guardrail_config import GuardrailRequest from app.utils import now @@ -10,10 +11,13 @@ class RequestLogCrud: def __init__(self, session: Session): self.session = session - def create(self, request_id: UUID, input_text: str) -> RequestLog: + def create(self, payload: GuardrailRequest) -> RequestLog: + request_id = UUID(payload.request_id) create_request_log = RequestLog( request_id=request_id, - request_text=input_text, + request_text=payload.input, + organization_id=payload.organization_id, + project_id=payload.project_id, ) self.session.add(create_request_log) self.session.commit() diff --git a/backend/app/crud/validator_config.py b/backend/app/crud/validator_config.py index dbdeffa..79dffb8 100644 --- a/backend/app/crud/validator_config.py +++ b/backend/app/crud/validator_config.py @@ -69,6 +69,10 @@ def list( if type: query = query.where(ValidatorConfig.type == type) + query = query.order_by( + ValidatorConfig.created_at.asc(), ValidatorConfig.id.asc() + ) + rows = session.exec(query).all() return [self.flatten(r) for r in rows] diff --git a/backend/app/models/config/ban_list.py b/backend/app/models/config/ban_list.py new file mode 100644 index 0000000..8e40ac7 --- /dev/null +++ b/backend/app/models/config/ban_list.py @@ -0,0 +1,72 @@ +from datetime import datetime +from uuid import UUID, uuid4 + +from sqlalchemy import Column, String +from sqlalchemy.dialects.postgresql import ARRAY +from sqlmodel import Field, SQLModel + +from app.utils import now + + +class BanList(SQLModel, table=True): + __tablename__ = "ban_list" + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the ban list entry"}, + ) + + name: str = Field( + nullable=False, sa_column_kwargs={"comment": "Name of the ban list entry"} + ) + + description: str = Field( + nullable=False, + sa_column_kwargs={"comment": "Description of the ban list entry"}, + ) + + banned_words: list[str] = Field( + default_factory=list, + sa_column=Column( + ARRAY(String), + nullable=False, + comment="List of banned words", + ), + description=("List of banned words"), + ) + + organization_id: int = Field( + nullable=False, + sa_column_kwargs={"comment": "Identifier for the organization"}, + ) + + project_id: int = Field( + nullable=False, + sa_column_kwargs={"comment": "Identifier for the project"}, + ) + + domain: str = Field( + nullable=False, + sa_column_kwargs={"comment": "Domain or context for the ban list entry"}, + ) + + is_public: bool = Field( + default=False, + sa_column_kwargs={"comment": "Whether the ban list entry is public or private"}, + ) + + created_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the ban list entry was created"}, + ) + + updated_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={ + "comment": "Timestamp when the ban list entry was last updated", + "onupdate": now, + }, + ) diff --git a/backend/app/models/logging/request_log.py b/backend/app/models/logging/request_log.py index 7b7a131..bda3ad7 100644 --- a/backend/app/models/logging/request_log.py +++ b/backend/app/models/logging/request_log.py @@ -24,6 +24,16 @@ class RequestLog(SQLModel, table=True): sa_column_kwargs={"comment": "Unique identifier for the request log entry"}, ) + organization_id: int = Field( + nullable=False, + sa_column_kwargs={"comment": "Identifier for the organization"}, + ) + + project_id: int = Field( + nullable=False, + sa_column_kwargs={"comment": "Identifier for the project"}, + ) + request_id: UUID = Field( nullable=False, sa_column_kwargs={"comment": "Identifier for the request"}, diff --git a/backend/app/models/logging/validator_log.py b/backend/app/models/logging/validator_log.py index a48090d..c04a4a1 100644 --- a/backend/app/models/logging/validator_log.py +++ b/backend/app/models/logging/validator_log.py @@ -21,6 +21,16 @@ class ValidatorLog(SQLModel, table=True): sa_column_kwargs={"comment": "Unique identifier for the validator log entry"}, ) + organization_id: int = Field( + nullable=False, + sa_column_kwargs={"comment": "Identifier for the organization"}, + ) + + project_id: int = Field( + nullable=False, + sa_column_kwargs={"comment": "Identifier for the project"}, + ) + request_id: UUID = Field( foreign_key="request_log.id", nullable=False, diff --git a/backend/app/schemas/ban_list.py b/backend/app/schemas/ban_list.py new file mode 100644 index 0000000..0cb2a84 --- /dev/null +++ b/backend/app/schemas/ban_list.py @@ -0,0 +1,63 @@ +from datetime import datetime +from uuid import UUID +from typing import Annotated, Optional + +from pydantic import StringConstraints +from sqlmodel import Field +from sqlmodel import SQLModel + +MAX_BANNED_WORD_LENGTH = 100 +MAX_BANNED_WORDS_ITEMS = 1000 +MAX_BAN_LIST_NAME_LENGTH = 100 +MAX_BAN_LIST_DESCRIPTION_LENGTH = 500 + +BanListName = Annotated[ + str, + StringConstraints( + strip_whitespace=True, + min_length=1, + max_length=MAX_BAN_LIST_NAME_LENGTH, + ), +] +BanListDescription = Annotated[ + str, + StringConstraints( + strip_whitespace=True, + min_length=1, + max_length=MAX_BAN_LIST_DESCRIPTION_LENGTH, + ), +] + +BannedWord = Annotated[ + str, + StringConstraints( + strip_whitespace=True, min_length=1, max_length=MAX_BANNED_WORD_LENGTH + ), +] +BannedWords = Annotated[list[BannedWord], Field(max_length=MAX_BANNED_WORDS_ITEMS)] + + +class BanListBase(SQLModel): + name: BanListName + description: BanListDescription + banned_words: BannedWords + domain: str + is_public: bool = False + + +class BanListCreate(BanListBase): + pass + + +class BanListUpdate(SQLModel): + name: Optional[BanListName] = None + description: Optional[BanListDescription] = None + banned_words: Optional[BannedWords] = None + domain: Optional[str] = None + is_public: Optional[bool] = None + + +class BanListResponse(BanListBase): + id: UUID + created_at: datetime + updated_at: datetime diff --git a/backend/app/schemas/guardrail_config.py b/backend/app/schemas/guardrail_config.py index b9d04d8..53c8557 100644 --- a/backend/app/schemas/guardrail_config.py +++ b/backend/app/schemas/guardrail_config.py @@ -34,6 +34,8 @@ class GuardrailRequest(SQLModel): model_config = ConfigDict(extra="forbid") request_id: str + organization_id: int + project_id: int input: str validators: List[ValidatorConfigItem] diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index c1ea5ea..807d28e 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -4,12 +4,30 @@ os.environ["ENVIRONMENT"] = "testing" import pytest +from fastapi import Header from fastapi.testclient import TestClient from sqlmodel import Session, create_engine, SQLModel from app.main import app -from app.api.deps import SessionDep, verify_bearer_token +from app.api.deps import ( + SessionDep, + TenantContext, + validate_multitenant_key, + verify_bearer_token, +) from app.core.config import settings +from app.core.enum import GuardrailOnFail, Stage, ValidatorType +from app.models.config.ban_list import BanList +from app.models.config.validator_config import ValidatorConfig +from app.tests.seed_data import ( + BAN_LIST_INTEGRATION_ORGANIZATION_ID, + BAN_LIST_INTEGRATION_PROJECT_ID, + BAN_LIST_PAYLOADS, + VALIDATOR_INTEGRATION_ORGANIZATION_ID, + VALIDATOR_INTEGRATION_PROJECT_ID, + VALIDATOR_PAYLOADS, +) +from app.utils import split_validator_payload test_engine = create_engine( str(settings.SQLALCHEMY_DATABASE_URI), @@ -23,6 +41,31 @@ def override_session(): yield session +def seed_test_data(session: Session) -> None: + for payload in BAN_LIST_PAYLOADS.values(): + session.add( + BanList( + **payload, + organization_id=BAN_LIST_INTEGRATION_ORGANIZATION_ID, + project_id=BAN_LIST_INTEGRATION_PROJECT_ID, + ) + ) + + for payload in VALIDATOR_PAYLOADS.values(): + model_fields, config_fields = split_validator_payload(payload) + session.add( + ValidatorConfig( + organization_id=VALIDATOR_INTEGRATION_ORGANIZATION_ID, + project_id=VALIDATOR_INTEGRATION_PROJECT_ID, + type=ValidatorType(model_fields["type"]), + stage=Stage(model_fields["stage"]), + on_fail_action=GuardrailOnFail(model_fields["on_fail_action"]), + is_enabled=model_fields.get("is_enabled", True), + config=config_fields, + ) + ) + + @pytest.fixture(scope="session", autouse=True) def setup_test_db(): SQLModel.metadata.create_all(test_engine) @@ -41,6 +84,30 @@ def clean_db(): @pytest.fixture(scope="function", autouse=True) def override_dependencies(): app.dependency_overrides[verify_bearer_token] = lambda: True + default_scope = TenantContext( + organization_id=BAN_LIST_INTEGRATION_ORGANIZATION_ID, + project_id=BAN_LIST_INTEGRATION_PROJECT_ID, + ) + + def override_multitenant_key( + x_api_key: str | None = Header(default=None, alias="X-API-KEY"), + ): + if not x_api_key: + return default_scope + + token = x_api_key.strip() + if token.lower().startswith("apikey "): + token = token.split(" ", 1)[1].strip() + + if token == "org999_project999": + return TenantContext(organization_id=999, project_id=999) + + if token == "org2_project2": + return TenantContext(organization_id=2, project_id=2) + + return default_scope + + app.dependency_overrides[validate_multitenant_key] = override_multitenant_key app.dependency_overrides[SessionDep] = override_session @@ -49,6 +116,20 @@ def override_dependencies(): app.dependency_overrides.clear() +@pytest.fixture(scope="function") +def seed_db(): + with Session(test_engine) as session: + seed_test_data(session) + session.commit() + yield + + +@pytest.fixture +def clear_database(): + """Compatibility fixture; database cleanup is handled by clean_db.""" + yield + + @pytest.fixture(scope="function") def client(): with TestClient(app) as c: diff --git a/backend/app/tests/seed_data.json b/backend/app/tests/seed_data.json new file mode 100644 index 0000000..009b1ce --- /dev/null +++ b/backend/app/tests/seed_data.json @@ -0,0 +1,90 @@ +{ + "ban_list": { + "unit": { + "test_id": "11111111-1111-1111-1111-111111111111", + "organization_id": 1, + "project_id": 10, + "sample": { + "name": "test", + "description": "desc", + "banned_words": ["bad"], + "domain": "health", + "is_public": false + } + }, + "integration": { + "organization_id": 1, + "project_id": 1, + "payloads": { + "minimal": { + "name": "default", + "description": "basic list", + "banned_words": ["bad"], + "domain": "general" + }, + "health": { + "name": "health-list", + "description": "healthcare words", + "banned_words": ["gender detection", "sonography"], + "domain": "health" + }, + "edu": { + "name": "edu-list", + "description": "education words", + "banned_words": ["cheating"], + "domain": "edu" + }, + "public": { + "name": "public-list", + "description": "shared", + "banned_words": ["shared"], + "is_public": true, + "domain": "general" + } + } + } + }, + "validator": { + "unit": { + "validator_id": "22222222-2222-2222-2222-222222222222", + "organization_id": 1, + "project_id": 1, + "type": "LexicalSlur", + "stage": "Input", + "on_fail_action": "Fix", + "is_enabled": true, + "config": { + "severity": "all", + "languages": ["en", "hi"] + } + }, + "integration": { + "organization_id": 1, + "project_id": 1, + "payloads": { + "lexical_slur": { + "type": "uli_slur_match", + "stage": "input", + "on_fail_action": "fix", + "severity": "all", + "languages": ["en", "hi"] + }, + "pii_remover_input": { + "type": "pii_remover", + "stage": "input", + "on_fail_action": "fix" + }, + "pii_remover_output": { + "type": "pii_remover", + "stage": "output", + "on_fail_action": "fix" + }, + "minimal": { + "type": "gender_assumption_bias", + "stage": "input", + "on_fail_action": "fix" + } + } + } + } +} diff --git a/backend/app/tests/seed_data.py b/backend/app/tests/seed_data.py new file mode 100644 index 0000000..a6bcaa6 --- /dev/null +++ b/backend/app/tests/seed_data.py @@ -0,0 +1,75 @@ +import json +from pathlib import Path +import uuid +from unittest.mock import MagicMock + +from app.core.enum import GuardrailOnFail, Stage, ValidatorType +from app.models.config.validator_config import ValidatorConfig +from app.schemas.ban_list import BanListCreate + +SEED_DATA_PATH = Path(__file__).with_name("seed_data.json") + + +def _load_seed_data() -> dict: + with SEED_DATA_PATH.open("r", encoding="utf-8") as f: + return json.load(f) + + +DATA = _load_seed_data() + +BAN_LIST_UNIT = DATA["ban_list"]["unit"] +BAN_LIST_INTEGRATION = DATA["ban_list"]["integration"] + +VALIDATOR_UNIT = DATA["validator"]["unit"] +VALIDATOR_INTEGRATION = DATA["validator"]["integration"] + +BAN_LIST_TEST_ID = uuid.UUID(BAN_LIST_UNIT["test_id"]) +BAN_LIST_TEST_ORGANIZATION_ID = BAN_LIST_UNIT["organization_id"] +BAN_LIST_TEST_PROJECT_ID = BAN_LIST_UNIT["project_id"] + +BAN_LIST_INTEGRATION_ORGANIZATION_ID = BAN_LIST_INTEGRATION["organization_id"] +BAN_LIST_INTEGRATION_PROJECT_ID = BAN_LIST_INTEGRATION["project_id"] +BAN_LIST_PAYLOADS = BAN_LIST_INTEGRATION["payloads"] + +VALIDATOR_TEST_ID = uuid.UUID(VALIDATOR_UNIT["validator_id"]) +VALIDATOR_TEST_ORGANIZATION_ID = VALIDATOR_UNIT["organization_id"] +VALIDATOR_TEST_PROJECT_ID = VALIDATOR_UNIT["project_id"] +VALIDATOR_TEST_TYPE = ValidatorType[VALIDATOR_UNIT["type"]] +VALIDATOR_TEST_STAGE = Stage[VALIDATOR_UNIT["stage"]] +VALIDATOR_TEST_ON_FAIL = GuardrailOnFail[VALIDATOR_UNIT["on_fail_action"]] +VALIDATOR_TEST_CONFIG = VALIDATOR_UNIT["config"] +VALIDATOR_TEST_IS_ENABLED = VALIDATOR_UNIT["is_enabled"] + +VALIDATOR_INTEGRATION_ORGANIZATION_ID = VALIDATOR_INTEGRATION["organization_id"] +VALIDATOR_INTEGRATION_PROJECT_ID = VALIDATOR_INTEGRATION["project_id"] +VALIDATOR_PAYLOADS = VALIDATOR_INTEGRATION["payloads"] + + +def build_ban_list_create_payload() -> BanListCreate: + return BanListCreate(**BAN_LIST_UNIT["sample"]) + + +def build_sample_ban_list_mock() -> MagicMock: + obj = MagicMock() + obj.id = BAN_LIST_TEST_ID + obj.name = BAN_LIST_UNIT["sample"]["name"] + obj.description = BAN_LIST_UNIT["sample"]["description"] + obj.banned_words = BAN_LIST_UNIT["sample"]["banned_words"] + obj.organization_id = BAN_LIST_TEST_ORGANIZATION_ID + obj.project_id = BAN_LIST_TEST_PROJECT_ID + obj.domain = BAN_LIST_UNIT["sample"]["domain"] + obj.is_public = BAN_LIST_UNIT["sample"].get("is_public", False) + return obj + + +def build_sample_validator_config() -> ValidatorConfig: + return ValidatorConfig( + id=VALIDATOR_TEST_ID, + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + type=VALIDATOR_TEST_TYPE, + stage=VALIDATOR_TEST_STAGE, + on_fail_action=VALIDATOR_TEST_ON_FAIL, + is_enabled=VALIDATOR_TEST_IS_ENABLED, + config=VALIDATOR_TEST_CONFIG, + ) diff --git a/backend/app/tests/test_banlists_api.py b/backend/app/tests/test_banlists_api.py new file mode 100644 index 0000000..224e542 --- /dev/null +++ b/backend/app/tests/test_banlists_api.py @@ -0,0 +1,132 @@ +import uuid +from unittest.mock import MagicMock, patch + +import pytest +from sqlmodel import Session + +from app.api.deps import TenantContext +from app.api.routes.ban_lists import ( + create_ban_list, + list_ban_lists, + get_ban_list, + update_ban_list, + delete_ban_list, +) +from app.schemas.ban_list import BanListUpdate +from app.tests.seed_data import ( + BAN_LIST_TEST_ID, + BAN_LIST_TEST_ORGANIZATION_ID, + BAN_LIST_TEST_PROJECT_ID, + build_ban_list_create_payload, + build_sample_ban_list_mock, +) + + +@pytest.fixture +def mock_session(): + return MagicMock(spec=Session) + + +@pytest.fixture +def sample_ban_list(): + return build_sample_ban_list_mock() + + +@pytest.fixture +def create_payload(): + return build_ban_list_create_payload() + + +@pytest.fixture +def auth_context(): + return TenantContext( + organization_id=BAN_LIST_TEST_ORGANIZATION_ID, + project_id=BAN_LIST_TEST_PROJECT_ID, + ) + + +def test_create_calls_crud(mock_session, create_payload, sample_ban_list, auth_context): + with patch("app.api.routes.ban_lists.ban_list_crud") as crud: + crud.create.return_value = sample_ban_list + + result = create_ban_list( + payload=create_payload, + session=mock_session, + auth=auth_context, + ) + + assert result.data == sample_ban_list + + +def test_list_returns_data(mock_session, sample_ban_list, auth_context): + with patch("app.api.routes.ban_lists.ban_list_crud") as crud: + crud.list.return_value = [sample_ban_list] + + result = list_ban_lists( + session=mock_session, + auth=auth_context, + ) + + crud.list.assert_called_once_with( + mock_session, + BAN_LIST_TEST_ORGANIZATION_ID, + BAN_LIST_TEST_PROJECT_ID, + None, + offset=0, + limit=None, + ) + assert len(result.data) == 1 + + +def test_get_success(mock_session, sample_ban_list, auth_context): + with patch("app.api.routes.ban_lists.ban_list_crud") as crud: + crud.get.return_value = sample_ban_list + + result = get_ban_list( + id=BAN_LIST_TEST_ID, + session=mock_session, + auth=auth_context, + ) + + assert result.data == sample_ban_list + + +def test_update_success(mock_session, sample_ban_list, auth_context): + with patch("app.api.routes.ban_lists.ban_list_crud") as crud: + crud.update.return_value = sample_ban_list + + result = update_ban_list( + id=BAN_LIST_TEST_ID, + payload=BanListUpdate(name="new"), + session=mock_session, + auth=auth_context, + ) + + crud.update.assert_called_once() + _, kwargs = crud.update.call_args + assert kwargs["id"] == BAN_LIST_TEST_ID + assert kwargs["organization_id"] == BAN_LIST_TEST_ORGANIZATION_ID + assert kwargs["project_id"] == BAN_LIST_TEST_PROJECT_ID + assert kwargs["data"].name == "new" + assert result.data == sample_ban_list + + +def test_delete_success(mock_session, sample_ban_list, auth_context): + with patch("app.api.routes.ban_lists.ban_list_crud") as crud: + crud.get.return_value = sample_ban_list + + result = delete_ban_list( + id=BAN_LIST_TEST_ID, + session=mock_session, + auth=auth_context, + ) + + crud.get.assert_called_once_with( + mock_session, + BAN_LIST_TEST_ID, + BAN_LIST_TEST_ORGANIZATION_ID, + BAN_LIST_TEST_PROJECT_ID, + require_owner=True, + ) + crud.delete.assert_called_once_with(mock_session, sample_ban_list) + assert result.success is True diff --git a/backend/app/tests/test_banlists_api_integration.py b/backend/app/tests/test_banlists_api_integration.py new file mode 100644 index 0000000..64f2221 --- /dev/null +++ b/backend/app/tests/test_banlists_api_integration.py @@ -0,0 +1,297 @@ +import uuid +import pytest +from app.schemas.ban_list import ( + MAX_BANNED_WORD_LENGTH, + MAX_BANNED_WORDS_ITEMS, + MAX_BAN_LIST_DESCRIPTION_LENGTH, + MAX_BAN_LIST_NAME_LENGTH, +) +from app.tests.seed_data import BAN_LIST_PAYLOADS + +pytestmark = pytest.mark.integration + + +BASE_URL = "/api/v1/guardrails/ban_lists/" +DEFAULT_API_KEY = "org1_project1" +ALT_API_KEY_999 = "org999_project999" +ALT_API_KEY_2 = "org2_project2" + + +class BaseBanListTest: + def _headers(self, api_key=DEFAULT_API_KEY): + return {"X-API-Key": api_key} + + def create(self, client, payload_key="minimal", api_key=DEFAULT_API_KEY, **kwargs): + payload = {**BAN_LIST_PAYLOADS[payload_key], **kwargs} + return client.post(BASE_URL, json=payload, headers=self._headers(api_key)) + + def list(self, client, api_key=DEFAULT_API_KEY, **filters): + return client.get(BASE_URL, params=filters, headers=self._headers(api_key)) + + def get(self, client, id, api_key=DEFAULT_API_KEY): + return client.get(f"{BASE_URL}{id}/", headers=self._headers(api_key)) + + def update(self, client, id, payload, api_key=DEFAULT_API_KEY): + return client.patch( + f"{BASE_URL}{id}/", + json=payload, + headers=self._headers(api_key), + ) + + def delete(self, client, id, api_key=DEFAULT_API_KEY): + return client.delete(f"{BASE_URL}{id}/", headers=self._headers(api_key)) + + +class TestCreateBanList(BaseBanListTest): + def test_create_success(self, integration_client, clear_database): + response = self.create(integration_client, "minimal") + + assert response.status_code == 200 + data = response.json()["data"] + + assert data["name"] == "default" + assert data["banned_words"] == ["bad"] + + def test_create_validation_error(self, integration_client, clear_database): + response = integration_client.post( + BASE_URL, + json={"name": "missing words"}, + headers=self._headers(), + ) + + assert response.status_code == 422 + + def test_create_validation_error_banned_word_too_long( + self, integration_client, clear_database + ): + response = self.create( + integration_client, + "minimal", + banned_words=["a" * (MAX_BANNED_WORD_LENGTH + 1)], + ) + + assert response.status_code == 422 + + def test_create_validation_error_too_many_banned_words( + self, integration_client, clear_database + ): + response = self.create( + integration_client, + "minimal", + banned_words=["x"] * (MAX_BANNED_WORDS_ITEMS + 1), + ) + + assert response.status_code == 422 + + def test_create_validation_error_name_too_long( + self, integration_client, clear_database + ): + response = self.create( + integration_client, + "minimal", + name="n" * (MAX_BAN_LIST_NAME_LENGTH + 1), + ) + + assert response.status_code == 422 + + def test_create_validation_error_description_too_long( + self, integration_client, clear_database + ): + response = self.create( + integration_client, + "minimal", + description="d" * (MAX_BAN_LIST_DESCRIPTION_LENGTH + 1), + ) + + assert response.status_code == 422 + + +class TestListBanLists(BaseBanListTest): + def test_list_success(self, integration_client, seed_db): + response = self.list(integration_client) + + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 4 + + def test_filter_by_domain(self, integration_client, seed_db): + response = self.list(integration_client, domain="health") + + data = response.json()["data"] + + assert len(data) == 1 + assert data[0]["domain"] == "health" + + def test_list_empty(self, integration_client, clear_database): + response = self.list(integration_client) + + assert response.json()["data"] == [] + + def test_list_pagination_with_limit(self, integration_client, seed_db): + response = self.list(integration_client, limit=2) + + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 2 + + def test_list_pagination_with_offset_and_limit(self, integration_client, seed_db): + full_response = self.list(integration_client) + full_data = full_response.json()["data"] + + response = self.list(integration_client, offset=2, limit=2) + + assert response.status_code == 200 + paged_data = response.json()["data"] + assert len(paged_data) == 2 + assert [item["id"] for item in paged_data] == [ + item["id"] for item in full_data[2:4] + ] + + +class TestPublicAccess(BaseBanListTest): + def test_public_visible_to_other_org(self, integration_client, clear_database): + create_resp = self.create(integration_client, "public") + ban_id = create_resp.json()["data"]["id"] + + response = self.get(integration_client, ban_id, api_key=ALT_API_KEY_999) + + assert response.status_code == 200 + + +class TestGetBanList(BaseBanListTest): + def test_get_success(self, integration_client, seed_db): + list_resp = self.list(integration_client) + ban_id = list_resp.json()["data"][0]["id"] + + response = self.get(integration_client, ban_id) + + assert response.status_code == 200 + + def test_get_not_found(self, integration_client, clear_database): + fake = uuid.uuid4() + response = self.get(integration_client, fake) + body = response.json() + + assert response.status_code == 404 + assert body["success"] is False + assert body["metadata"] is None + assert "Ban list not found" in body["error"] + + def test_get_wrong_owner_private(self, integration_client, seed_db): + list_resp = self.list(integration_client) + private_ban_list = next( + item for item in list_resp.json()["data"] if not item["is_public"] + ) + ban_id = private_ban_list["id"] + + response = self.get(integration_client, ban_id, api_key=ALT_API_KEY_2) + body = response.json() + + assert response.status_code == 403 + assert body["success"] is False + assert "permission" in body["error"].lower() + + +class TestUpdateBanList(BaseBanListTest): + def test_update_success(self, integration_client, seed_db): + list_resp = self.list(integration_client) + ban_id = list_resp.json()["data"][0]["id"] + + response = self.update( + integration_client, + ban_id, + {"banned_words": ["bad", "worse"]}, + ) + + assert response.status_code == 200 + + data = response.json()["data"] + assert data["banned_words"] == ["bad", "worse"] + + def test_partial_update(self, integration_client, seed_db): + list_resp = self.list(integration_client) + ban_id = list_resp.json()["data"][0]["id"] + + response = self.update(integration_client, ban_id, {"name": "updated"}) + + assert response.json()["data"]["name"] == "updated" + + def test_update_not_found(self, integration_client, clear_database): + fake = uuid.uuid4() + + response = self.update(integration_client, fake, {"name": "x"}) + body = response.json() + + assert response.status_code == 404 + assert body["success"] is False + assert "Ban list not found" in body["error"] + + def test_update_public_wrong_owner_fails(self, integration_client, clear_database): + create_resp = self.create(integration_client, "public") + ban_id = create_resp.json()["data"]["id"] + + response = self.update( + integration_client, + ban_id, + {"name": "updated-by-other-org"}, + api_key=ALT_API_KEY_999, + ) + body = response.json() + + assert response.status_code == 403 + assert body["success"] is False + assert "permission" in body["error"].lower() + + +class TestDeleteBanList(BaseBanListTest): + def test_delete_success(self, integration_client, seed_db): + list_resp = self.list(integration_client) + ban_id = list_resp.json()["data"][0]["id"] + + response = self.delete(integration_client, ban_id) + + assert response.status_code == 200 + assert response.json()["success"] is True + + def test_delete_not_found(self, integration_client, clear_database): + fake = uuid.uuid4() + + response = self.delete(integration_client, fake) + body = response.json() + + assert response.status_code == 404 + assert body["success"] is False + assert "Ban list not found" in body["error"] + + def test_delete_wrong_owner(self, integration_client, seed_db): + list_resp = self.list(integration_client) + private_ban_list = next( + item for item in list_resp.json()["data"] if not item["is_public"] + ) + ban_id = private_ban_list["id"] + + response = self.delete( + integration_client, + ban_id, + api_key=ALT_API_KEY_999, + ) + body = response.json() + + assert response.status_code == 403 + assert body["success"] is False + assert "permission" in body["error"].lower() + + def test_delete_public_wrong_owner_fails(self, integration_client, clear_database): + create_resp = self.create(integration_client, "public") + ban_id = create_resp.json()["data"]["id"] + + response = self.delete( + integration_client, + ban_id, + api_key=ALT_API_KEY_999, + ) + body = response.json() + + assert response.status_code == 403 + assert body["success"] is False + assert "permission" in body["error"].lower() diff --git a/backend/app/tests/test_deps_multitenant.py b/backend/app/tests/test_deps_multitenant.py new file mode 100644 index 0000000..2c76880 --- /dev/null +++ b/backend/app/tests/test_deps_multitenant.py @@ -0,0 +1,172 @@ +from unittest.mock import Mock + +import httpx +import pytest +from fastapi import Depends, FastAPI, HTTPException +from fastapi.testclient import TestClient + +from app.api.deps import TenantContext, validate_multitenant_key +from app.core.config import settings +from app.core.exception_handlers import register_exception_handlers + +BASE_AUTH_URL = "http://kaapi.local/api/v1" +VERIFY_URL = f"{BASE_AUTH_URL}/apikeys/verify" + + +def test_validate_multitenant_key_parses_credentials_shape(monkeypatch): + monkeypatch.setattr( + settings, + "KAAPI_AUTH_URL", + BASE_AUTH_URL, + ) + + response = Mock() + response.status_code = 200 + response.json.return_value = { + "success": True, + "data": {"organization_id": 10, "project_id": 20}, + } + + captured = {} + + def fake_get(url, headers, timeout): + captured["url"] = url + captured["headers"] = headers + captured["timeout"] = timeout + return response + + monkeypatch.setattr(httpx, "get", fake_get) + + context = validate_multitenant_key("abc123") + + assert isinstance(context, TenantContext) + assert (context.organization_id, context.project_id) == (10, 20) + assert captured["url"] == VERIFY_URL + assert captured["headers"]["X-API-KEY"] == "ApiKey abc123" + assert captured["timeout"] == 5 + + +def test_validate_multitenant_key_invalid_status_returns_401(monkeypatch): + monkeypatch.setattr( + settings, + "KAAPI_AUTH_URL", + BASE_AUTH_URL, + ) + + response = Mock() + response.status_code = 401 + response.json.return_value = {"success": False, "data": None} + + monkeypatch.setattr(httpx, "get", lambda *args, **kwargs: response) + + with pytest.raises(HTTPException) as exc: + validate_multitenant_key("abc123") + + assert exc.value.status_code == 401 + + +def test_validate_multitenant_key_network_error_returns_503(monkeypatch): + monkeypatch.setattr( + settings, + "KAAPI_AUTH_URL", + BASE_AUTH_URL, + ) + + def fake_get(*args, **kwargs): + raise httpx.RequestError("boom", request=Mock()) + + monkeypatch.setattr(httpx, "get", fake_get) + + with pytest.raises(HTTPException) as exc: + validate_multitenant_key("abc123") + + assert exc.value.status_code == 503 + + +def test_validate_multitenant_key_invalid_payload_returns_401(monkeypatch): + monkeypatch.setattr( + settings, + "KAAPI_AUTH_URL", + BASE_AUTH_URL, + ) + + response = Mock() + response.status_code = 200 + response.json.return_value = {"success": True, "data": {"foo": 1}} + + monkeypatch.setattr(httpx, "get", lambda *args, **kwargs: response) + + with pytest.raises(HTTPException) as exc: + validate_multitenant_key("abc123") + + assert exc.value.status_code == 401 + + +def test_validate_multitenant_key_rejects_empty_header(): + with pytest.raises(HTTPException) as exc: + validate_multitenant_key(" ") + + assert exc.value.status_code == 401 + + +def test_validate_multitenant_key_accepts_raw_header_value(monkeypatch): + monkeypatch.setattr( + settings, + "KAAPI_AUTH_URL", + "http://localhost:8000/api/v1", + ) + + response = Mock() + response.status_code = 200 + response.json.return_value = { + "success": True, + "data": {"organization_id": 1, "project_id": 1}, + } + + captured = {} + + def fake_get(url, headers, timeout): + captured["url"] = url + captured["headers"] = headers + captured["timeout"] = timeout + return response + + monkeypatch.setattr(httpx, "get", fake_get) + + context = validate_multitenant_key("No3x47A5") + + assert isinstance(context, TenantContext) + assert context.organization_id == 1 + assert context.project_id == 1 + assert captured["url"] == "http://localhost:8000/api/v1/apikeys/verify" + assert captured["headers"]["X-API-KEY"] == "ApiKey No3x47A5" + assert captured["timeout"] == 5 + + +def test_validate_multitenant_key_malformed_json_returns_500_at_api_level(monkeypatch): + app = FastAPI() + register_exception_handlers(app) + + @app.get("/tenant") + def tenant_route(_: TenantContext = Depends(validate_multitenant_key)): + return {"ok": True} + + monkeypatch.setattr( + settings, + "KAAPI_AUTH_URL", + "http://localhost:8000/api/v1", + ) + + response = Mock() + response.status_code = 200 + response.json.side_effect = ValueError("Malformed JSON") + + monkeypatch.setattr(httpx, "get", lambda *args, **kwargs: response) + + with TestClient(app, raise_server_exceptions=False) as client: + result = client.get("/tenant", headers={"X-API-KEY": "abc123"}) + + assert result.status_code == 500 + body = result.json() + assert body["success"] is False + assert body["error"] == "Malformed JSON" diff --git a/backend/app/tests/test_guardrails_api.py b/backend/app/tests/test_guardrails_api.py index c945023..86035ae 100644 --- a/backend/app/tests/test_guardrails_api.py +++ b/backend/app/tests/test_guardrails_api.py @@ -3,12 +3,18 @@ import pytest from app.tests.guardrails_mocks import MockResult +from app.tests.seed_data import ( + VALIDATOR_TEST_ORGANIZATION_ID, + VALIDATOR_TEST_PROJECT_ID, +) from app.tests.utils.constants import SAFE_TEXT_FIELD, VALIDATE_API_PATH build_guard_path = "app.api.routes.guardrails.build_guard" crud_path = "app.api.routes.guardrails.RequestLogCrud" request_id = "123e4567-e89b-12d3-a456-426614174000" +organization_id = VALIDATOR_TEST_ORGANIZATION_ID +project_id = VALIDATOR_TEST_PROJECT_ID @pytest.fixture @@ -34,6 +40,8 @@ def validate(self, data): VALIDATE_API_PATH, json={ "request_id": request_id, + "organization_id": organization_id, + "project_id": project_id, "input": "hello world", "validators": [], }, @@ -57,6 +65,8 @@ def validate(self, data): VALIDATE_API_PATH, json={ "request_id": request_id, + "organization_id": organization_id, + "project_id": project_id, "input": "my phone is 999999", "validators": [], }, @@ -76,6 +86,8 @@ def test_guardrails_internal_error(client, mock_crud): VALIDATE_API_PATH, json={ "request_id": request_id, + "organization_id": organization_id, + "project_id": project_id, "input": "text", "validators": [], }, diff --git a/backend/app/tests/test_guardrails_api_integration.py b/backend/app/tests/test_guardrails_api_integration.py index f9013cf..23ecb47 100644 --- a/backend/app/tests/test_guardrails_api_integration.py +++ b/backend/app/tests/test_guardrails_api_integration.py @@ -1,17 +1,25 @@ import pytest +from app.tests.seed_data import ( + VALIDATOR_INTEGRATION_ORGANIZATION_ID, + VALIDATOR_INTEGRATION_PROJECT_ID, +) from app.tests.utils.constants import SAFE_TEXT_FIELD, VALIDATE_API_PATH pytestmark = pytest.mark.integration request_id = "123e4567-e89b-12d3-a456-426614174000" +organization_id = VALIDATOR_INTEGRATION_ORGANIZATION_ID +project_id = VALIDATOR_INTEGRATION_PROJECT_ID -def test_input_guardrails_with_real_banlist(integration_client): +def test_input_guardrails_with_real_ban_list(integration_client): response = integration_client.post( VALIDATE_API_PATH, json={ "request_id": request_id, + "organization_id": organization_id, + "project_id": project_id, "input": "this contains badword", "validators": [ { @@ -34,6 +42,8 @@ def test_input_guardrails_passes_clean_text(integration_client): VALIDATE_API_PATH, json={ "request_id": request_id, + "organization_id": organization_id, + "project_id": project_id, "input": "this is clean", "validators": [ { @@ -56,6 +66,8 @@ def test_input_guardrails_with_lexical_slur(integration_client): VALIDATE_API_PATH, json={ "request_id": request_id, + "organization_id": organization_id, + "project_id": project_id, "input": "This sentence contains chakki.", "validators": [ { @@ -78,6 +90,8 @@ def test_input_guardrails_with_lexical_slur_clean_text(integration_client): VALIDATE_API_PATH, json={ "request_id": request_id, + "organization_id": organization_id, + "project_id": project_id, "input": "This is a completely safe sentence", "validators": [ { @@ -100,6 +114,8 @@ def test_input_guardrails_with_multiple_validators(integration_client): VALIDATE_API_PATH, json={ "request_id": request_id, + "organization_id": organization_id, + "project_id": project_id, "input": ( "This sentence contains chakki cause I want a " "sonography done to kill the female foetus." @@ -132,6 +148,8 @@ def test_input_guardrails_with_incorrect_validator_config(integration_client): VALIDATE_API_PATH, json={ "request_id": request_id, + "organization_id": organization_id, + "project_id": project_id, "input": "This sentence contains chakki.", "validators": [ { @@ -155,6 +173,8 @@ def test_input_guardrails_with_validator_actions_exception(integration_client): VALIDATE_API_PATH, json={ "request_id": request_id, + "organization_id": organization_id, + "project_id": project_id, "input": "This sentence contains chakki.", "validators": [ { @@ -179,6 +199,8 @@ def test_input_guardrails_with_validator_actions_rephrase(integration_client): VALIDATE_API_PATH, json={ "request_id": request_id, + "organization_id": organization_id, + "project_id": project_id, "input": "This sentence contains chakki.", "validators": [ { diff --git a/backend/app/tests/test_validate_with_guard.py b/backend/app/tests/test_validate_with_guard.py index 010ef24..1bcd70c 100644 --- a/backend/app/tests/test_validate_with_guard.py +++ b/backend/app/tests/test_validate_with_guard.py @@ -3,16 +3,33 @@ import pytest -from app.api.routes.guardrails import _validate_with_guard +from app.api.routes.guardrails import ( + _resolve_ban_list_banned_words, + _validate_with_guard, +) +from app.schemas.guardrail_config import GuardrailRequest from app.tests.guardrails_mocks import MockResult +from app.tests.seed_data import ( + VALIDATOR_TEST_ORGANIZATION_ID, + VALIDATOR_TEST_PROJECT_ID, +) from app.utils import APIResponse - mock_request_log_crud = MagicMock() mock_validator_log_crud = MagicMock() mock_request_log_id = uuid4() +def _build_payload(input_text: str) -> GuardrailRequest: + return GuardrailRequest( + request_id=str(uuid4()), + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + input=input_text, + validators=[], + ) + + def test_validate_with_guard_success(): class MockGuard: def validate(self, data): @@ -23,8 +40,7 @@ def validate(self, data): return_value=MockGuard(), ): response = _validate_with_guard( - data="hello", - validators=[], + payload=_build_payload("hello"), request_log_crud=mock_request_log_crud, request_log_id=mock_request_log_id, validator_log_crud=mock_validator_log_crud, @@ -46,8 +62,7 @@ def validate(self, data): return_value=MockGuard(), ): response = _validate_with_guard( - data="bad text", - validators=[], + payload=_build_payload("bad text"), request_log_crud=mock_request_log_crud, request_log_id=mock_request_log_id, validator_log_crud=mock_validator_log_crud, @@ -65,8 +80,7 @@ def test_validate_with_guard_exception(): side_effect=Exception("Invalid config"), ): response = _validate_with_guard( - data="text", - validators=[], + payload=_build_payload("text"), request_log_crud=mock_request_log_crud, request_log_id=mock_request_log_id, validator_log_crud=mock_validator_log_crud, @@ -76,3 +90,45 @@ def test_validate_with_guard_exception(): assert response.success is False assert response.data.safe_text is None assert response.error == "Invalid config" + + +def test_resolve_ban_list_banned_words_from_ban_list_id(): + ban_list_id = str(uuid4()) + payload = GuardrailRequest( + request_id=str(uuid4()), + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + input="test", + validators=[{"type": "ban_list", "ban_list_id": ban_list_id}], + ) + mock_session = MagicMock() + + with patch("app.api.routes.guardrails.ban_list_crud.get") as mock_get: + mock_get.return_value = MagicMock(banned_words=["foo", "bar"]) + _resolve_ban_list_banned_words(payload, mock_session) + + assert payload.validators[0].banned_words == ["foo", "bar"] + mock_get.assert_called_once_with( + mock_session, + id=payload.validators[0].ban_list_id, + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + ) + + +def test_resolve_ban_list_banned_words_skips_lookup_when_banned_words_provided(): + payload = GuardrailRequest( + request_id=str(uuid4()), + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + input="test", + validators=[ + {"type": "ban_list", "ban_list_id": str(uuid4()), "banned_words": ["foo"]} + ], + ) + mock_session = MagicMock() + + with patch("app.api.routes.guardrails.ban_list_crud.get") as mock_get: + _resolve_ban_list_banned_words(payload, mock_session) + + mock_get.assert_not_called() diff --git a/backend/app/tests/test_validator_configs.py b/backend/app/tests/test_validator_configs.py index 4697219..2ba0b94 100644 --- a/backend/app/tests/test_validator_configs.py +++ b/backend/app/tests/test_validator_configs.py @@ -1,20 +1,21 @@ -import uuid from unittest.mock import MagicMock import pytest from sqlmodel import Session from app.crud.validator_config import validator_config_crud -from app.core.enum import GuardrailOnFail, Stage, ValidatorType +from app.core.enum import GuardrailOnFail, ValidatorType from app.models.config.validator_config import ValidatorConfig - -# Test data constants -TEST_ORGANIZATION_ID = 1 -TEST_PROJECT_ID = 1 -TEST_VALIDATOR_ID = uuid.uuid4() -TEST_TYPE = ValidatorType.LexicalSlur -TEST_STAGE = Stage.Input -TEST_ON_FAIL = GuardrailOnFail.Fix +from app.tests.seed_data import ( + VALIDATOR_TEST_CONFIG, + VALIDATOR_TEST_ID, + VALIDATOR_TEST_ON_FAIL, + VALIDATOR_TEST_ORGANIZATION_ID, + VALIDATOR_TEST_PROJECT_ID, + VALIDATOR_TEST_STAGE, + VALIDATOR_TEST_TYPE, + build_sample_validator_config, +) @pytest.fixture @@ -26,16 +27,7 @@ def mock_session(): @pytest.fixture def sample_validator(): """Create a sample validator config for testing.""" - return ValidatorConfig( - id=TEST_VALIDATOR_ID, - organization_id=TEST_ORGANIZATION_ID, - project_id=TEST_PROJECT_ID, - type=TEST_TYPE, - stage=TEST_STAGE, - on_fail_action=TEST_ON_FAIL, - is_enabled=True, - config={"severity": "all", "languages": ["en", "hi"]}, - ) + return build_sample_validator_config() class TestFlatten: @@ -44,16 +36,16 @@ def test_flatten_includes_config_fields(self, sample_validator): assert result["severity"] == "all" assert result["languages"] == ["en", "hi"] - assert result["id"] == TEST_VALIDATOR_ID + assert result["id"] == VALIDATOR_TEST_ID def test_flatten_empty_config(self): validator = ValidatorConfig( - id=TEST_VALIDATOR_ID, - organization_id=TEST_ORGANIZATION_ID, - project_id=TEST_PROJECT_ID, - type=TEST_TYPE, - stage=TEST_STAGE, - on_fail_action=TEST_ON_FAIL, + id=VALIDATOR_TEST_ID, + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + type=VALIDATOR_TEST_TYPE, + stage=VALIDATOR_TEST_STAGE, + on_fail_action=VALIDATOR_TEST_ON_FAIL, is_enabled=True, config={}, ) @@ -69,9 +61,9 @@ def test_success(self, sample_validator, mock_session): result = validator_config_crud.get( mock_session, - TEST_VALIDATOR_ID, - TEST_ORGANIZATION_ID, - TEST_PROJECT_ID, + VALIDATOR_TEST_ID, + VALIDATOR_TEST_ORGANIZATION_ID, + VALIDATOR_TEST_PROJECT_ID, ) assert result == sample_validator @@ -83,9 +75,9 @@ def test_not_found(self, mock_session): with pytest.raises(Exception) as exc: validator_config_crud.get( mock_session, - TEST_VALIDATOR_ID, - TEST_ORGANIZATION_ID, - TEST_PROJECT_ID, + VALIDATOR_TEST_ID, + VALIDATOR_TEST_ORGANIZATION_ID, + VALIDATOR_TEST_PROJECT_ID, ) assert "Validator not found" in str(exc.value) @@ -121,7 +113,7 @@ def test_update_extra_fields(self, sample_validator, mock_session): assert result["severity"] == "high" assert result["new_field"] == "new_value" - assert result["languages"] == ["en", "hi"] + assert result["languages"] == VALIDATOR_TEST_CONFIG["languages"] def test_merge_config(self, sample_validator, mock_session): sample_validator.config = {"severity": "all", "languages": ["en"]} diff --git a/backend/app/tests/test_validator_configs_integration.py b/backend/app/tests/test_validator_configs_integration.py index ccbdb62..73cf4ba 100644 --- a/backend/app/tests/test_validator_configs_integration.py +++ b/backend/app/tests/test_validator_configs_integration.py @@ -1,60 +1,20 @@ import uuid import pytest -from sqlmodel import Session, delete - -from app.core.db import engine -from app.models.config.validator_config import ValidatorConfig +from app.tests.seed_data import ( + VALIDATOR_INTEGRATION_ORGANIZATION_ID, + VALIDATOR_INTEGRATION_PROJECT_ID, + VALIDATOR_PAYLOADS, +) pytestmark = pytest.mark.integration -# Test data constants -TEST_ORGANIZATION_ID = 1 -TEST_PROJECT_ID = 1 BASE_URL = "/api/v1/guardrails/validators/configs/" DEFAULT_QUERY_PARAMS = ( - f"?organization_id={TEST_ORGANIZATION_ID}&project_id={TEST_PROJECT_ID}" + f"?organization_id={VALIDATOR_INTEGRATION_ORGANIZATION_ID}" + f"&project_id={VALIDATOR_INTEGRATION_PROJECT_ID}" ) -VALIDATOR_PAYLOADS = { - "lexical_slur": { - "type": "uli_slur_match", - "stage": "input", - "on_fail_action": "fix", - "severity": "all", - "languages": ["en", "hi"], - }, - "pii_remover_input": { - "type": "pii_remover", - "stage": "input", - "on_fail_action": "fix", - }, - "pii_remover_output": { - "type": "pii_remover", - "stage": "output", - "on_fail_action": "fix", - }, - "minimal": { - "type": "uli_slur_match", - "stage": "input", - "on_fail_action": "fix", - }, -} - - -@pytest.fixture -def clear_database(): - """Clear ValidatorConfig table before and after each test.""" - with Session(engine) as session: - session.exec(delete(ValidatorConfig)) - session.commit() - - yield - - with Session(engine) as session: - session.exec(delete(ValidatorConfig)) - session.commit() - class BaseValidatorTest: """Base class with helper methods for validator tests.""" @@ -71,7 +31,8 @@ def get_validator(self, client, validator_id): def list_validators(self, client, **query_params): """Helper to list validators with optional filters.""" params_str = ( - f"?organization_id={TEST_ORGANIZATION_ID}&project_id={TEST_PROJECT_ID}" + f"?organization_id={VALIDATOR_INTEGRATION_ORGANIZATION_ID}" + f"&project_id={VALIDATOR_INTEGRATION_PROJECT_ID}" ) if query_params: params_str += "&" + "&".join(f"{k}={v}" for k, v in query_params.items()) @@ -80,7 +41,7 @@ def list_validators(self, client, **query_params): def update_validator(self, client, validator_id, payload): """Helper to update a validator.""" return client.patch( - f"{BASE_URL}{validator_id}{DEFAULT_QUERY_PARAMS}", json=payload + f"{BASE_URL}{validator_id}/{DEFAULT_QUERY_PARAMS}", json=payload ) def delete_validator(self, client, validator_id): @@ -130,40 +91,33 @@ def test_create_validator_missing_required_fields( class TestListValidators(BaseValidatorTest): """Tests for GET /guardrails/validators/configs endpoint.""" - def test_list_validators_success(self, integration_client, clear_database): + def test_list_validators_success(self, integration_client, seed_db): """Test successful validator listing.""" - # Create validators first - self.create_validator(integration_client, "lexical_slur") - self.create_validator(integration_client, "pii_remover_input") response = self.list_validators(integration_client) assert response.status_code == 200 data = response.json()["data"] - assert len(data) == 2 + assert len(data) == 4 - def test_list_validators_filter_by_stage(self, integration_client, clear_database): + def test_list_validators_filter_by_stage(self, integration_client, seed_db): """Test filtering validators by stage.""" - self.create_validator(integration_client, "lexical_slur") - self.create_validator(integration_client, "pii_remover_output") response = self.list_validators(integration_client, stage="input") assert response.status_code == 200 data = response.json()["data"] - assert len(data) == 1 + assert len(data) == 3 assert data[0]["stage"] == "input" - def test_list_validators_filter_by_type(self, integration_client, clear_database): + def test_list_validators_filter_by_type(self, integration_client, seed_db): """Test filtering validators by type.""" - self.create_validator(integration_client, "lexical_slur") - self.create_validator(integration_client, "pii_remover_input") response = self.list_validators(integration_client, type="pii_remover") assert response.status_code == 200 data = response.json()["data"] - assert len(data) == 1 + assert len(data) == 2 assert data[0]["type"] == "pii_remover" def test_list_validators_filter_by_ids(self, integration_client, clear_database): @@ -226,13 +180,10 @@ def test_list_validators_empty(self, integration_client, clear_database): class TestGetValidator(BaseValidatorTest): """Tests for GET /guardrails/validators/configs/{id} endpoint.""" - def test_get_validator_success(self, integration_client, clear_database): + def test_get_validator_success(self, integration_client, seed_db): """Test successful validator retrieval.""" - # Create a validator - create_response = self.create_validator( - integration_client, "lexical_slur", severity="all" - ) - validator_id = create_response.json()["data"]["id"] + list_response = self.list_validators(integration_client) + validator_id = list_response.json()["data"][0]["id"] # Retrieve it response = self.get_validator(integration_client, validator_id) @@ -261,7 +212,6 @@ def test_get_validator_invalid_id_returns_422( def test_get_validator_wrong_org(self, integration_client, clear_database): """Test that accessing validator from different org returns 404.""" - # Create a validator for org 1 create_response = self.create_validator(integration_client, "minimal") validator_id = create_response.json()["data"]["id"] @@ -276,13 +226,10 @@ def test_get_validator_wrong_org(self, integration_client, clear_database): class TestUpdateValidator(BaseValidatorTest): """Tests for PATCH /guardrails/validators/configs/{id} endpoint.""" - def test_update_validator_success(self, integration_client, clear_database): + def test_update_validator_success(self, integration_client, seed_db): """Test successful validator update.""" - # Create a validator - create_response = self.create_validator( - integration_client, "lexical_slur", severity="all" - ) - validator_id = create_response.json()["data"]["id"] + list_response = self.list_validators(integration_client) + validator_id = list_response.json()["data"][0]["id"] # Update it update_payload = {"on_fail_action": "exception", "is_enabled": False} @@ -295,16 +242,10 @@ def test_update_validator_success(self, integration_client, clear_database): assert data["on_fail_action"] == "exception" assert data["is_enabled"] is False - def test_update_validator_partial(self, integration_client, clear_database): + def test_update_validator_partial(self, integration_client, seed_db): """Test partial update preserves original fields.""" - # Create a validator - create_response = self.create_validator( - integration_client, - "lexical_slur", - severity="all", - languages=["en", "hi"], - ) - validator_id = create_response.json()["data"]["id"] + list_response = self.list_validators(integration_client) + validator_id = list_response.json()["data"][0]["id"] # Update only one field update_payload = {"is_enabled": False} @@ -330,11 +271,10 @@ def test_update_validator_not_found(self, integration_client, clear_database): class TestDeleteValidator(BaseValidatorTest): """Tests for DELETE /guardrails/validators/configs/{id} endpoint.""" - def test_delete_validator_success(self, integration_client, clear_database): + def test_delete_validator_success(self, integration_client, seed_db): """Test successful validator deletion.""" - # Create a validator - create_response = self.create_validator(integration_client, "minimal") - validator_id = create_response.json()["data"]["id"] + list_response = self.list_validators(integration_client) + validator_id = list_response.json()["data"][0]["id"] # Delete it response = self.delete_validator(integration_client, validator_id) @@ -353,11 +293,10 @@ def test_delete_validator_not_found(self, integration_client, clear_database): assert response.status_code == 404 - def test_delete_validator_wrong_org(self, integration_client, clear_database): + def test_delete_validator_wrong_org(self, integration_client, seed_db): """Test that deleting validator from different org returns 404.""" - # Create a validator for org 1 - create_response = self.create_validator(integration_client, "minimal") - validator_id = create_response.json()["data"]["id"] + list_response = self.list_validators(integration_client) + validator_id = list_response.json()["data"][0]["id"] # Try to delete it as different org response = integration_client.delete( diff --git a/backend/app/tests/validators/test_pii_remover.py b/backend/app/tests/validators/test_pii_remover.py index 46e59ca..0f919a7 100644 --- a/backend/app/tests/validators/test_pii_remover.py +++ b/backend/app/tests/validators/test_pii_remover.py @@ -4,7 +4,6 @@ from app.core.validators.pii_remover import ALL_ENTITY_TYPES, PIIRemover - # ------------------------------- # Fixtures # ------------------------------- diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 5df006c..b335986 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -33,8 +33,8 @@ dependencies = [ "scikit-learn>=1.6.0,<2.0.0", ] -[tool.uv] -dev-dependencies = [ +[dependency-groups] +dev = [ "pytest<8.0.0,>=7.4.3", "mypy<2.0.0,>=1.8.0", "ruff<1.0.0,>=0.2.2",