diff --git a/backend/app/alembic/versions/048_create_llm_chain_table.py b/backend/app/alembic/versions/048_create_llm_chain_table.py new file mode 100644 index 000000000..ac49eb0ec --- /dev/null +++ b/backend/app/alembic/versions/048_create_llm_chain_table.py @@ -0,0 +1,181 @@ +"""Create llm_chain table + +Revision ID: 048 +Revises: 047 +Create Date: 2026-02-20 00:00:00.000000 + +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB + +revision = "048" +down_revision = "047" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # 1. Create llm_chain table + op.create_table( + "llm_chain", + sa.Column( + "id", + sa.Uuid(), + nullable=False, + comment="Unique identifier for the LLM chain record", + ), + sa.Column( + "job_id", + sa.Uuid(), + nullable=False, + comment="Reference to the parent job (status tracked in job table)", + ), + sa.Column( + "project_id", + sa.Integer(), + nullable=False, + comment="Reference to the project this LLM call belongs to", + ), + sa.Column( + "organization_id", + sa.Integer(), + nullable=False, + comment="Reference to the organization this LLM call belongs to", + ), + sa.Column( + "status", + sa.String(), + nullable=False, + server_default="pending", + comment="Chain execution status (pending, running, failed, completed)", + ), + sa.Column( + "error", + sa.Text(), + nullable=True, + comment="Error message if the chain execution failed", + ), + sa.Column( + "block_sequences", + JSONB(), + nullable=True, + comment="Ordered list of llm_call UUIDs as blocks complete", + ), + sa.Column( + "total_blocks", + sa.Integer(), + nullable=False, + comment="Total number of blocks to execute", + ), + sa.Column( + "number_of_blocks_processed", + sa.Integer(), + nullable=False, + server_default="0", + comment="Number of blocks processed so far (used for tracking progress)", + ), + sa.Column( + "input", + sa.String(), + nullable=False, + comment="First block user's input - text string, binary data, or file path for multimodal", + ), + sa.Column( + "output", + JSONB(), + nullable=True, + comment="Last block's final output (set on chain completion)", + ), + sa.Column( + "configs", + JSONB(), + nullable=True, + comment="Ordered list of block configs as submitted in the request", + ), + sa.Column( + "total_usage", + JSONB(), + nullable=True, + comment="Aggregated token usage: {input_tokens, output_tokens, total_tokens}", + ), + sa.Column( + "metadata", + JSONB(), + nullable=True, + comment="Future-proof extensibility catch-all", + ), + sa.Column( + "started_at", + sa.DateTime(), + nullable=True, + comment="Timestamp when chain execution started", + ), + sa.Column( + "completed_at", + sa.DateTime(), + nullable=True, + comment="Timestamp when chain execution completed", + ), + sa.Column( + "created_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the chain record was created", + ), + sa.Column( + "updated_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the chain record was last updated", + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["job_id"], ["job.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["project_id"], ["project.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["organization_id"], ["organization.id"], ondelete="CASCADE" + ), + ) + + op.create_index( + "idx_llm_chain_job_id", + "llm_chain", + ["job_id"], + ) + + # 2. 
Add chain_id FK column to llm_call table + op.add_column( + "llm_call", + sa.Column( + "chain_id", + sa.Uuid(), + nullable=True, + comment="Reference to the parent chain (NULL for standalone /llm/call requests)", + ), + ) + op.create_foreign_key( + "fk_llm_call_chain_id", + "llm_call", + "llm_chain", + ["chain_id"], + ["id"], + ondelete="SET NULL", + ) + op.create_index( + "idx_llm_call_chain_id", + "llm_call", + ["chain_id"], + postgresql_where=sa.text("chain_id IS NOT NULL"), + ) + + op.execute("ALTER TYPE jobtype ADD VALUE IF NOT EXISTS 'LLM_CHAIN'") + + +def downgrade() -> None: + op.drop_index("idx_llm_call_chain_id", table_name="llm_call") + op.drop_constraint("fk_llm_call_chain_id", "llm_call", type_="foreignkey") + op.drop_column("llm_call", "chain_id") + + op.drop_index("idx_llm_chain_job_id", table_name="llm_chain") + op.drop_table("llm_chain") diff --git a/backend/app/api/docs/llm/llm_chain.md b/backend/app/api/docs/llm/llm_chain.md new file mode 100644 index 000000000..d6c17893c --- /dev/null +++ b/backend/app/api/docs/llm/llm_chain.md @@ -0,0 +1,60 @@ +Execute a chain of LLM calls sequentially, where each block's output becomes the next block's input. + +This endpoint initiates an asynchronous LLM chain job. The request is queued +for processing, and results are delivered via the callback URL when complete. + +### Key Parameters + +**`query`** (required) - Initial query input for the first block in the chain: +- `input` (required, string, min 1 char): User question/prompt/query +- `conversation` (optional, object): Conversation configuration + - `id` (optional, string): Existing conversation ID to continue + - `auto_create` (optional, boolean, default false): Create a new conversation if no ID is provided + - **Note**: Cannot specify both `id` and `auto_create=true` + + +**`blocks`** (required, array, min 1 block) - Ordered list of blocks to execute sequentially. Each block contains: + +- `config` (required) - Configuration for this block's LLM call (choose exactly one mode): + + - **Mode 1: Stored Configuration** + - `id` (UUID): Configuration ID + - `version` (integer >= 1): Version number + - **Both required together** + - **Note**: When using a stored configuration, do not include the `blob` field in the request body + + - **Mode 2: Ad-hoc Configuration** + - `blob` (object): Complete configuration object + - `completion` (required, object): Completion configuration + - `provider` (required, string): Provider type - either `"openai"` (Kaapi abstraction) or `"openai-native"` (pass-through) + - `params` (required, object): Parameter structure depends on the provider type (see the schema for the detailed structure) + - `prompt_template` (optional, object): Template for text interpolation + - `template` (required, string): Template string with a `{{input}}` placeholder, replaced with the block's input before execution + - **Note** + - When using an ad-hoc configuration, do not include the `id` and `version` fields + - When using the Kaapi abstraction, parameters that are not supported by the selected provider or model are automatically suppressed. If any parameters are ignored, a list of warnings is included in `metadata.warnings`.
+ - **Recommendation**: Use stored configs (Mode 1) for production; use ad-hoc configs only for testing/validation + - **Schema**: Check the API schema or examples below for the complete parameter structure for each provider type + +- `include_provider_raw_response` (optional, boolean, default false): + - When true, includes the unmodified raw response from the LLM provider for this block + +- `intermediate_callback` (optional, boolean, default false): + - When true, sends an intermediate callback after this block completes with the block's response, usage, and position in the chain + +**`callback_url`** (optional, HTTPS URL): +- Webhook endpoint to receive the final response and intermediate callbacks +- Must be a valid HTTPS URL +- If not provided, response is only accessible through job status + +**`request_metadata`** (optional, object): +- Custom JSON metadata +- Passed through unchanged in the response + +### Note +- Input guardrails from the first block's config are applied before chain execution starts +- Output guardrails from the last block's config are applied after all blocks complete +- If any block fails, the chain stops immediately — no subsequent blocks are executed +- `warnings` list is automatically added in response metadata when using Kaapi configs if any parameters are suppressed or adjusted (e.g., temperature on reasoning models) + +--- diff --git a/backend/app/api/main.py b/backend/app/api/main.py index ed58e57f2..5ab1cbd9e 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -10,6 +10,7 @@ login, languages, llm, + llm_chain, organization, openai_conversation, project, @@ -41,6 +42,7 @@ api_router.include_router(evaluations.router) api_router.include_router(languages.router) api_router.include_router(llm.router) +api_router.include_router(llm_chain.router) api_router.include_router(login.router) api_router.include_router(onboarding.router) api_router.include_router(openai_conversation.router) diff --git a/backend/app/api/routes/llm_chain.py b/backend/app/api/routes/llm_chain.py new file mode 100644 index 000000000..92a3cdb4d --- /dev/null +++ b/backend/app/api/routes/llm_chain.py @@ -0,0 +1,62 @@ +import logging + +from fastapi import APIRouter, Depends +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.models import LLMChainRequest, LLMChainResponse, Message +from app.services.llm.jobs import start_chain_job +from app.utils import APIResponse, validate_callback_url, load_description + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["LLM Chain"]) +llm_callback_router = APIRouter() + + +@llm_callback_router.post( + "{$callback_url}", + name="llm_chain_callback", +) +def llm_callback_notification(body: APIResponse[LLMChainResponse]): + """ + Callback endpoint specification for LLM chain completion. + + The callback will receive: + - On success: APIResponse with success=True and data containing LLMChainResponse + - On failure: APIResponse with success=False and error message + - metadata field will always be included if provided in the request + """ + ... + + +@router.post( + "/llm/chain", + description=load_description("llm/llm_chain.md"), + response_model=APIResponse[Message], + callbacks=llm_callback_router.routes, + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def llm_chain( + _current_user: AuthContextDep, _session: SessionDep, request: LLMChainRequest +): + """ + Endpoint to initiate an LLM chain as a background job. 
+ """ + project_id = _current_user.project_.id + organization_id = _current_user.organization_.id + + if request.callback_url: + validate_callback_url(str(request.callback_url)) + + start_chain_job( + db=_session, + request=request, + project_id=project_id, + organization_id=organization_id, + ) + + return APIResponse.success_response( + data=Message( + message="Your response is being generated and will be delivered via callback." + ), + ) diff --git a/backend/app/crud/llm.py b/backend/app/crud/llm.py index b5c23cd6e..32f8ca46f 100644 --- a/backend/app/crud/llm.py +++ b/backend/app/crud/llm.py @@ -53,6 +53,7 @@ def create_llm_call( *, request: LLMCallRequest, job_id: UUID, + chain_id: UUID | None = None, project_id: int, organization_id: int, resolved_config: ConfigBlob, @@ -120,6 +121,7 @@ def create_llm_call( job_id=job_id, project_id=project_id, organization_id=organization_id, + chain_id=chain_id, input=serialize_input(request.query.input), input_type=input_type, output_type=output_type, diff --git a/backend/app/crud/llm_chain.py b/backend/app/crud/llm_chain.py new file mode 100644 index 000000000..77ab70987 --- /dev/null +++ b/backend/app/crud/llm_chain.py @@ -0,0 +1,151 @@ +import logging +from typing import Any +from uuid import UUID + +from sqlmodel import Session + +from app.core.util import now +from app.models.llm.request import ChainStatus, LlmChain + +logger = logging.getLogger(__name__) + + +def create_llm_chain( + session: Session, + *, + job_id: UUID, + project_id: int, + organization_id: int, + total_blocks: int, + input: str, + configs: list[dict[str, Any]], +) -> LlmChain: + """Create a new LLM chain record. + Args: + session: Database session + job_id: Reference to the parent job + project_id: Reference to the project + organization_id: Reference to the organization + total_blocks: Total number of blocks to execute + input: Serialized input string (via serialize_input) + configs: Ordered list of block configs as submitted + + Returns: + LlmChain: The created chain record + """ + db_llm_chain = LlmChain( + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + status=ChainStatus.PENDING, + total_blocks=total_blocks, + number_of_blocks_processed=0, + input=input, + configs=configs, + block_sequences=[], + ) + + session.add(db_llm_chain) + session.commit() + session.refresh(db_llm_chain) + + logger.info( + f"[create_llm_chain] Created LLM chain id={db_llm_chain.id}, " + f"job_id={job_id}, total_blocks={total_blocks}" + ) + + return db_llm_chain + + +def update_llm_chain_status( + session: Session, + *, + chain_id: UUID, + status: ChainStatus, + output: dict[str, Any] | None = None, + total_usage: dict[str, Any] | None = None, + error: str | None = None, +) -> LlmChain: + """Update chain record status and related fields. 
+ Args: + session: Database session + chain_id: The chain record ID + status: New chain status + output: Last block's output dict (only for COMPLETED) + total_usage: Aggregated token usage across all blocks (for COMPLETED/FAILED) + error: Error message (only for FAILED) + + Returns: + LlmChain: The updated chain record + """ + db_chain = session.get(LlmChain, chain_id) + if not db_chain: + raise ValueError(f"LLM chain not found with id={chain_id}") + + db_chain.status = status + db_chain.updated_at = now() + + if status == ChainStatus.RUNNING: + db_chain.started_at = now() + + if status == ChainStatus.FAILED: + db_chain.error = error + db_chain.total_usage = total_usage + db_chain.completed_at = now() + + if status == ChainStatus.COMPLETED: + db_chain.output = output + db_chain.total_usage = total_usage + db_chain.completed_at = now() + + session.add(db_chain) + session.commit() + session.refresh(db_chain) + + logger.info( + f"[update_llm_chain_status] Chain {chain_id} → {status.value} | " + f"has_output={output is not None}, " + f"blocks={db_chain.number_of_blocks_processed}/{db_chain.total_blocks}, " + f"error={error}" + ) + return db_chain + + +def update_llm_chain_block_completed( + session: Session, + *, + chain_id: UUID, + llm_call_id: UUID, +) -> LlmChain: + """Update chain progress after a block completes. + Args: + session: Database session + chain_id: The chain record ID + llm_call_id: The llm_call record ID for the completed block + + Returns: + LlmChain: The updated chain record + """ + db_chain = session.get(LlmChain, chain_id) + if not db_chain: + raise ValueError(f"LLM chain not found with id={chain_id}") + + # Append to block_sequences + sequences = list(db_chain.block_sequences or []) + sequences.append(str(llm_call_id)) + db_chain.block_sequences = sequences + + # Increment progress + db_chain.number_of_blocks_processed = len(sequences) + db_chain.updated_at = now() + + session.add(db_chain) + session.commit() + session.refresh(db_chain) + + logger.info( + f"[update_llm_chain_block_completed] Chain {chain_id} | " + f"block={db_chain.number_of_blocks_processed}/{db_chain.total_blocks}, " + f"llm_call_id={llm_call_id}" + ) + return db_chain diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 2c28d7b4f..c76a02579 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -111,6 +111,9 @@ LLMCallRequest, LLMCallResponse, LlmCall, + LLMChainRequest, + LLMChainResponse, + LlmChain, ) from .message import Message diff --git a/backend/app/models/job.py b/backend/app/models/job.py index b6a1a5ae7..3b20249f5 100644 --- a/backend/app/models/job.py +++ b/backend/app/models/job.py @@ -17,6 +17,7 @@ class JobStatus(str, Enum): class JobType(str, Enum): RESPONSE = "RESPONSE" LLM_API = "LLM_API" + LLM_CHAIN = "LLM_CHAIN" class Job(SQLModel, table=True): diff --git a/backend/app/models/llm/__init__.py b/backend/app/models/llm/__init__.py index b183543c4..9bcf3a035 100644 --- a/backend/app/models/llm/__init__.py +++ b/backend/app/models/llm/__init__.py @@ -9,6 +9,13 @@ LlmCall, AudioContent, TextContent, + TextInput, + AudioInput, + PromptTemplate, + ChainBlock, + ChainStatus, + LLMChainRequest, + LlmChain, ) from app.models.llm.response import ( LLMCallResponse, @@ -17,4 +24,6 @@ Usage, TextOutput, AudioOutput, + LLMChainResponse, + IntermediateChainResponse, ) diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index b90fb6229..d6abd7d8d 100644 --- a/backend/app/models/llm/request.py +++ 
b/backend/app/models/llm/request.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Annotated, Any, Literal, Union from uuid import UUID, uuid4 @@ -214,11 +215,21 @@ class Validator(SQLModel): validator_config_id: UUID +class PromptTemplate(SQLModel): + template: str = Field(..., description="Template string with {{input}} placeholder") + + class ConfigBlob(SQLModel): """Raw JSON blob of config.""" completion: CompletionConfig = Field(..., description="Completion configuration") + # used for llm-chain to provide prompt interpolation + prompt_template: PromptTemplate | None = Field( + default=None, + description="Prompt template with {{input}} placeholder to wrap around the user input", + ) + input_guardrails: list[Validator] | None = Field( default=None, description="Guardrails applied to validate/sanitize the input before the LLM call", @@ -384,6 +395,16 @@ class LlmCall(SQLModel, table=True): }, ) + chain_id: UUID | None = Field( + default=None, + foreign_key="llm_chain.id", + nullable=True, + ondelete="SET NULL", + sa_column_kwargs={ + "comment": "Reference to the parent chain (NULL for standalone llm_call requests)" + }, + ) + # Request fields input: str = Field( ..., @@ -496,3 +517,213 @@ class LlmCall(SQLModel, table=True): nullable=True, sa_column_kwargs={"comment": "Timestamp when the record was soft-deleted"}, ) + + +class ChainBlock(SQLModel): + """A single block in an LLM chain execution.""" + + config: LLMCallConfig = Field( + ..., description="LLM call configuration (stored id+version OR ad-hoc blob)" + ) + + include_provider_raw_response: bool = Field( + default=False, + description="Whether to include the raw LLM provider response in the output for this block", + ) + + intermediate_callback: bool = Field( + default=False, + description="Whether to send intermediate callback after this block completes", + ) + + +class LLMChainRequest(SQLModel): + """ + API request for an LLM chain execution. + + Orchestrates multiple LLM calls sequentially where each block's output + becomes the next block's input. + """ + + query: QueryParams = Field( + ..., description="Initial query input for the first block in the chain" + ) + + blocks: list[ChainBlock] = Field( + ..., min_length=1, description="Ordered list of blocks to execute sequentially" + ) + + callback_url: HttpUrl | None = Field( + default=None, description="Webhook URL for async response delivery" + ) + + request_metadata: dict[str, Any] | None = Field( + default=None, + description=( + "Client-provided metadata passed through unchanged in the response. " + "Use this to correlate responses with requests or track request state. " + "The exact dictionary provided here will be returned in the response metadata field." + ), + ) + + +class ChainStatus(str, Enum): + """Status of an LLM chain execution.""" + + PENDING = "pending" + RUNNING = "running" + FAILED = "failed" + COMPLETED = "completed" + + +class LlmChain(SQLModel, table=True): + """ + Database model for tracking LLM chain execution + + it manages and orchestrates sequential llm_call executions. 
+ """ + + __tablename__ = "llm_chain" + __table_args__ = ( + Index( + "idx_llm_chain_job_id", + "job_id", + ), + ) + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the LLM chain record"}, + ) + + job_id: UUID = Field( + foreign_key="job.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={ + "comment": "Reference to the parent job (status tracked in job table)" + }, + ) + + project_id: int = Field( + foreign_key="project.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={ + "comment": "Reference to the project this LLM call belongs to" + }, + ) + + organization_id: int = Field( + foreign_key="organization.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={ + "comment": "Reference to the organization this LLM call belongs to" + }, + ) + + status: ChainStatus = Field( + default=ChainStatus.PENDING, + sa_column_kwargs={ + "comment": "Chain execution status (pending, running, failed, completed)" + }, + ) + + error: str | None = Field( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Error message if the chain execution failed"}, + ) + + block_sequences: list[str] | None = Field( + default_factory=list, + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Ordered list of llm_call UUIDs as blocks complete", + ), + ) + + total_blocks: int = Field( + ..., sa_column_kwargs={"comment": "Total number of blocks to execute"} + ) + + number_of_blocks_processed: int = Field( + default=0, + sa_column_kwargs={ + "comment": "Number of blocks processed so far (used for tracking progress)" + }, + ) + + # Request fields + input: str = Field( + ..., + sa_column_kwargs={ + "comment": "First block user's input - text string, binary data, or file path for multimodal" + }, + ) + + output: dict[str, Any] | None = Field( + default=None, + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Last block's final output (set on chain completion)", + ), + ) + + configs: list[dict[str, Any]] | None = Field( + default=None, + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Ordered list of block configs as submitted in the request", + ), + ) + + total_usage: dict[str, Any] | None = Field( + default=None, + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Aggregated token usage: {input_tokens, output_tokens, total_tokens}", + ), + ) + + metadata_: dict[str, Any] | None = Field( + default=None, + sa_column=sa.Column( + "metadata", + JSONB, + nullable=True, + comment="Future-proof extensibility catch-all", + ), + ) + + started_at: datetime | None = Field( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Timestamp when chain execution started"}, + ) + + completed_at: datetime | None = Field( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Timestamp when chain execution completed"}, + ) + + created_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the chain record was created"}, + ) + + updated_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={ + "comment": "Timestamp when the chain record was last updated" + }, + ) diff --git a/backend/app/models/llm/response.py b/backend/app/models/llm/response.py index 7b13e301c..1ae7619f6 100644 --- a/backend/app/models/llm/response.py +++ b/backend/app/models/llm/response.py @@ -62,3 +62,42 @@ class LLMCallResponse(SQLModel): default=None, description="Unmodified raw response from the LLM provider.", ) + + 
+class LLMChainResponse(SQLModel): + """Response schema for an LLM chain execution.""" + + response: LLMResponse = Field( + ..., description="LLM response from the final step of the chain execution." + ) + usage: Usage = Field( + ..., + description="Aggregate token usage and cost for the entire chain execution.", + ) + provider_raw_response: dict[str, object] | None = Field( + default=None, + description="Raw provider response from the last block (if requested)", + ) + + +class IntermediateChainResponse(SQLModel): + """ + Intermediate callback response sent after an intermediate block in the + LLM chain execution, if configured for that block. + + Flattened structure matching LLMCallResponse keys for consistency. + """ + + type: Literal["intermediate"] = "intermediate" + block_index: int = Field(..., description="Current block position") + total_blocks: int = Field(..., description="Total number of blocks in the chain") + response: LLMResponse = Field( + ..., description="LLM response from the current block" + ) + usage: Usage = Field( + ..., description="Token usage and cost information from the current block" + ) + provider_raw_response: dict[str, object] | None = Field( + default=None, + description="Unmodified raw response from the LLM provider for the current block", + ) diff --git a/backend/app/services/llm/chain/__init__.py b/backend/app/services/llm/chain/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/app/services/llm/chain/chain.py b/backend/app/services/llm/chain/chain.py new file mode 100644 index 000000000..390247d8d --- /dev/null +++ b/backend/app/services/llm/chain/chain.py @@ -0,0 +1,221 @@ +import logging +from dataclasses import dataclass, field +from typing import Any +from uuid import UUID + +from sqlmodel import Session + +from app.core.db import engine +from app.crud.llm_chain import update_llm_chain_block_completed +from app.models.llm.request import ( + LLMCallConfig, + QueryParams, + TextInput, + TextContent, + AudioInput, ) +from app.models.llm.response import ( + IntermediateChainResponse, + TextOutput, + AudioOutput, + Usage, ) +from app.services.llm.chain.types import BlockResult +from app.services.llm.jobs import execute_llm_call +from app.utils import APIResponse, send_callback + + +logger = logging.getLogger(__name__) + + +@dataclass +class ChainContext: + """Shared state passed to all blocks. Accumulates usage and tracks block progress.""" + + job_id: UUID + chain_id: UUID + project_id: int + organization_id: int + callback_url: str | None + total_blocks: int + + langfuse_credentials: dict[str, Any] | None = None + request_metadata: dict | None = None + intermediate_callback_flags: list[bool] = field(default_factory=list) + aggregated_usage: Usage = field( + default_factory=lambda: Usage( + input_tokens=0, + output_tokens=0, + total_tokens=0, + ) + ) + + def on_block_completed(self, block_index: int, result: BlockResult) -> None: + """Called after each block completes.
Updates chain state in DB and sends intermediate callback.""" + + if result.usage: + self.aggregated_usage.input_tokens += result.usage.input_tokens + self.aggregated_usage.output_tokens += result.usage.output_tokens + self.aggregated_usage.total_tokens += result.usage.total_tokens + + if result.success and result.llm_call_id: + with Session(engine) as session: + update_llm_chain_block_completed( + session, + chain_id=self.chain_id, + llm_call_id=result.llm_call_id, + ) + + if ( + block_index < len(self.intermediate_callback_flags) + and self.intermediate_callback_flags[block_index] + and self.callback_url + ): + self._send_intermediate_callback(block_index, result) + + def _send_intermediate_callback( + self, block_index: int, result: BlockResult + ) -> None: + """Send intermediate callback for a completed block.""" + try: + intermediate = IntermediateChainResponse( + block_index=block_index + 1, + total_blocks=self.total_blocks, + response=result.response.response, + usage=result.usage, + provider_raw_response=result.response.provider_raw_response, + ) + callback_data = APIResponse.success_response( + data=intermediate, + metadata=self.request_metadata, + ) + send_callback( + callback_url=self.callback_url, + data=callback_data.model_dump(), + ) + logger.info( + f"[ChainContext] Sent intermediate callback | " + f"block={block_index + 1}/{self.total_blocks}, job_id={self.job_id}" + ) + except Exception as e: + logger.warning( + f"[ChainContext] Failed to send intermediate callback: {e} | " + f"block={block_index + 1}/{self.total_blocks}, job_id={self.job_id}" + ) + + +def result_to_query(result: BlockResult) -> QueryParams: + """Convert a block's output into the next block's QueryParams. + + Text output → TextInput query + Audio output → AudioInput query + """ + output = result.response.response.output + + if isinstance(output, TextOutput): + return QueryParams( + input=TextInput(content=TextContent(value=output.content.value)) + ) + elif isinstance(output, AudioOutput): + return QueryParams(input=AudioInput(content=output.content)) + else: + raise ValueError(f"Cannot chain output type: {output.type}") + + +class ChainBlock: + """A single node in the linked chain. + + Wraps execute_llm_call() with linking capability. + Each block knows its next block and forwards output to it. + """ + + def __init__( + self, + *, + config: LLMCallConfig, + index: int, + context: ChainContext, + include_provider_raw_response: bool = False, + ): + self._config = config + self._index = index + self._context = context + self._include_provider_raw_response = include_provider_raw_response + self._next: ChainBlock | None = None + + def link(self, next_block: "ChainBlock") -> "ChainBlock": + """Link to the next block in the chain.""" + self._next = next_block + return next_block + + def execute(self, query: QueryParams) -> BlockResult: + """Execute this block, then flow to the next. + + No loop. Each block calls the next via the linked reference. + Data flows through the chain like a linked list traversal.
+ """ + logger.info( + f"[ChainBlock.execute] Executing block {self._index} | " + f"job_id={self._context.job_id}" + ) + + result = execute_llm_call( + config=self._config, + query=query, + job_id=self._context.job_id, + project_id=self._context.project_id, + organization_id=self._context.organization_id, + request_metadata=self._context.request_metadata, + langfuse_credentials=self._context.langfuse_credentials, + include_provider_raw_response=self._include_provider_raw_response, + chain_id=self._context.chain_id, + ) + + self._context.on_block_completed(self._index, result) + + if not result.success: + logger.error( + f"[ChainBlock.execute] Block {self._index} failed: {result.error} | " + f"job_id={self._context.job_id}" + ) + return result + + if self._next: + next_query = result_to_query(result) + return self._next.execute(next_query) + + logger.info( + f"[ChainBlock.execute] Block {self._index} is the last block | " + f"job_id={self._context.job_id}" + ) + return result + + +class LLMChain: + """Links ChainBlocks together into a sequential chain. + + Construction builds the linked structure. + Execution pushes input into the head — it flows through to the tail. + """ + + def __init__(self, blocks: list[ChainBlock]): + self._head: ChainBlock | None = None + self._tail: ChainBlock | None = None + self._link_blocks(blocks) + + def _link_blocks(self, blocks: list[ChainBlock]) -> None: + """Link all blocks in sequence.""" + if not blocks: + return + self._head = blocks[0] + self._tail = blocks[-1] + prev = blocks[0] + for curr in blocks[1:]: + prev.link(curr) + prev = curr + + def execute(self, query: QueryParams) -> BlockResult: + """Push input into the chain head. It flows through to the tail.""" + if not self._head: + return BlockResult(error="Chain has no blocks") + return self._head.execute(query) diff --git a/backend/app/services/llm/chain/executor.py b/backend/app/services/llm/chain/executor.py new file mode 100644 index 000000000..25208aad6 --- /dev/null +++ b/backend/app/services/llm/chain/executor.py @@ -0,0 +1,130 @@ +import logging + +from sqlmodel import Session + +from app.core.db import engine +from app.crud.jobs import JobCrud +from app.crud.llm_chain import update_llm_chain_status +from app.models import JobStatus, JobUpdate +from app.models.llm.request import ( + ChainStatus, + LLMChainRequest, +) +from app.models.llm.response import LLMChainResponse +from app.services.llm.chain.chain import ChainContext, LLMChain +from app.services.llm.chain.types import BlockResult +from app.utils import APIResponse, send_callback + +logger = logging.getLogger(__name__) + + +class ChainExecutor: + """Manage the lifecycle of an LLM chain execution.""" + + def __init__( + self, + *, + chain: LLMChain, + context: ChainContext, + request: LLMChainRequest, + ): + self._chain = chain + self._context = context + self._request = request + + def run(self) -> dict: + """Execute the full chain lifecycle. 
Returns serialized APIResponse.""" + try: + self._setup() + + result = self._chain.execute(self._request.query) + + return self._teardown(result) + + except Exception as e: + return self._handle_unexpected_error(e) + + def _setup(self) -> None: + with Session(engine) as session: + JobCrud(session).update( + job_id=self._context.job_id, + job_update=JobUpdate(status=JobStatus.PROCESSING), + ) + + update_llm_chain_status( + session=session, + chain_id=self._context.chain_id, + status=ChainStatus.RUNNING, + ) + + def _teardown(self, result: BlockResult) -> dict: + """Finalize chain record, send callback, and update job status.""" + + with Session(engine) as session: + if result.success: + final = LLMChainResponse( + response=result.response.response, + usage=result.usage, + provider_raw_response=result.response.provider_raw_response, + ) + callback_response = APIResponse.success_response( + data=final, metadata=self._request.request_metadata + ) + if self._request.callback_url: + send_callback( + callback_url=str(self._request.callback_url), + data=callback_response.model_dump(), + ) + JobCrud(session).update( + job_id=self._context.job_id, + job_update=JobUpdate(status=JobStatus.SUCCESS), + ) + update_llm_chain_status( + session=session, + chain_id=self._context.chain_id, + status=ChainStatus.COMPLETED, + output=result.response.response.output.model_dump(), + total_usage=self._context.aggregated_usage.model_dump(), + ) + return callback_response.model_dump() + else: + return self._handle_error(result.error) + + def _handle_error(self, error: str) -> dict: + callback_response = APIResponse.failure_response( + error=error or "Unknown error occurred", + metadata=self._request.request_metadata, + ) + logger.error( + f"[ChainExecutor] Chain execution failed | " + f"chain_id={self._context.chain_id}, job_id={self._context.job_id}, error={error}" + ) + + with Session(engine) as session: + if self._request.callback_url: + send_callback( + callback_url=str(self._request.callback_url), + data=callback_response.model_dump(), + ) + + update_llm_chain_status( + session, + chain_id=self._context.chain_id, + status=ChainStatus.FAILED, + output=None, + total_usage=self._context.aggregated_usage.model_dump(), + error=error, + ) + JobCrud(session).update( + job_id=self._context.job_id, + job_update=JobUpdate(status=JobStatus.FAILED, error_message=error), + ) + return callback_response.model_dump() + + def _handle_unexpected_error(self, e: Exception) -> dict: + logger.error( + f"[ChainExecutor.run] Unexpected error: {e} | " + f"job_id={self._context.job_id}", + exc_info=True, + ) + return self._handle_error("Unexpected error occurred") diff --git a/backend/app/services/llm/chain/types.py b/backend/app/services/llm/chain/types.py new file mode 100644 index 000000000..69ab3d02f --- /dev/null +++ b/backend/app/services/llm/chain/types.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass +from uuid import UUID + +from app.models.llm.response import LLMCallResponse, Usage + + +@dataclass +class BlockResult: + """Result of a single block/LLM call execution.""" + + response: LLMCallResponse | None = None + llm_call_id: UUID | None = None + usage: Usage | None = None + error: str | None = None + + @property + def success(self) -> bool: + return self.error is None and self.response is not None diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index c6997a084..196bdd60b 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -11,23 +11,26 @@ from 
app.crud.config import ConfigVersionCrud from app.crud.credentials import get_provider_credential from app.crud.jobs import JobCrud -from app.crud.llm import create_llm_call, update_llm_call_response -from app.models import JobStatus, JobType, JobUpdate, LLMCallRequest, Job +from app.crud.llm import create_llm_call, serialize_input, update_llm_call_response +from app.crud.llm_chain import create_llm_chain, update_llm_chain_status +from app.models import JobStatus, JobType, JobUpdate, LLMCallRequest, LLMChainRequest from app.models.llm.request import ( + ChainStatus, ConfigBlob, - LLMCallConfig, KaapiCompletionConfig, + LLMCallConfig, + QueryParams, TextInput, ) from app.models.llm.response import TextOutput +from app.services.llm.chain.types import BlockResult from app.services.llm.guardrails import ( list_validators_config, run_guardrails_validation, ) -from app.services.llm.providers.registry import get_llm_provider +from app.services.llm.input_resolver import cleanup_temp_file, resolve_input from app.services.llm.mappers import transform_kaapi_config_to_native -from app.services.llm.input_resolver import resolve_input, cleanup_temp_file - +from app.services.llm.providers.registry import get_llm_provider from app.utils import APIResponse, send_callback logger = logging.getLogger(__name__) @@ -75,6 +78,49 @@ def start_job( return job.id +def start_chain_job( + db: Session, request: LLMChainRequest, project_id: int, organization_id: int +) -> UUID: + """Create an LLM Chain job and schedule Celery task.""" + trace_id = correlation_id.get() or "N/A" + job_crud = JobCrud(session=db) + job = job_crud.create(job_type=JobType.LLM_CHAIN, trace_id=trace_id) + + # Explicitly flush to ensure job is persisted before Celery task starts + db.flush() + db.commit() + + logger.info( + f"[start_chain_job] Created job | job_id={job.id}, status={job.status}, project_id={project_id}" + ) + + try: + task_id = start_high_priority_job( + function_path="app.services.llm.jobs.execute_chain_job", + project_id=project_id, + job_id=str(job.id), + trace_id=trace_id, + request_data=request.model_dump(mode="json"), + organization_id=organization_id, + ) + except Exception as e: + logger.error( + f"[start_chain_job] Error starting Celery task: {str(e)} | job_id={job.id}, project_id={project_id}", + exc_info=True, + ) + job_update = JobUpdate(status=JobStatus.FAILED, error_message=str(e)) + job_crud.update(job_id=job.id, job_update=job_update) + raise HTTPException( + status_code=500, + detail="Internal server error while executing LLM chain job", + ) + + logger.info( + f"[start_chain_job] Job scheduled for LLM chain job | job_id={job.id}, project_id={project_id}, task_id={task_id}" + ) + return job.id + + def handle_job_error( job_id: UUID, callback_url: str | None, @@ -136,226 +182,233 @@ def resolve_config_blob( return None, "Unexpected error occurred while parsing stored configuration" -def execute_job( - request_data: dict, +def apply_input_guardrails( + *, + config_blob: ConfigBlob | None, + query: QueryParams, + job_id: UUID, project_id: int, organization_id: int, - job_id: str, - task_id: str, - task_instance, -) -> dict: - """Celery task to process an LLM request asynchronously. +) -> tuple[QueryParams, str | None]: + """Apply input guardrails from a config_blob. Shared with llm-call and llm-chain.""" + if not config_blob or not config_blob.input_guardrails: + return query, None + + if not isinstance(query.input, TextInput): + logger.info( + f"[apply_input_guardrails] Skipping for non-text input. 
" + f"job_id={job_id}, " + f"input_type={getattr(query.input, 'type', type(query.input).__name__)}" + ) + return query, None - Returns: - dict: Serialized APIResponse[LLMCallResponse] on success, APIResponse[None] on failure + input_guardrails, _ = list_validators_config( + organization_id=organization_id, + project_id=project_id, + input_validator_configs=config_blob.input_guardrails, + output_validator_configs=None, + ) + + if not input_guardrails: + return query, None + + safe = run_guardrails_validation( + query.input.content.value, + input_guardrails, + job_id, + project_id, + organization_id, + suppress_pass_logs=True, + ) + + logger.info( + f"[apply_input_guardrails] Validation result | success={safe['success']}, job_id={job_id}" + ) + + if safe.get("bypassed"): + logger.info( + f"[apply_input_guardrails] Guardrails bypassed (service unavailable) | job_id={job_id}" + ) + return query, None + + if safe["success"]: + query.input.content.value = safe["data"]["safe_text"] + return query, None + + return query, safe["error"] + + +def apply_output_guardrails( + *, + config_blob: ConfigBlob | None, + result: BlockResult, + job_id: UUID, + project_id: int, + organization_id: int, +) -> tuple[BlockResult, str | None]: + """Apply output guardrails from a config_blob. Shared by /llm/call and /llm/chain. + + Returns (modified_result, None) on success, or (result, error_string) on failure. """ + if not config_blob or not config_blob.output_guardrails: + return result, None + + if not isinstance(result.response.response.output, TextOutput): + logger.info( + f"[apply_output_guardrails] Skipping for non-text output. " + f"job_id={job_id}, " + f"output_type={getattr(result.response.response.output, 'type', type(result.response.response.output).__name__)}" + ) + return result, None - request = LLMCallRequest(**request_data) - job_id: UUID = UUID(job_id) + _, output_guardrails = list_validators_config( + organization_id=organization_id, + project_id=project_id, + input_validator_configs=None, + output_validator_configs=config_blob.output_guardrails, + ) - config = request.config - callback_response = None - config_blob: ConfigBlob | None = None - input_guardrails: list[dict] = [] - output_guardrails: list[dict] = [] - llm_call_id: UUID | None = None # Track the LLM call record + if not output_guardrails: + return result, None + + output_text = result.response.response.output.content.value + safe = run_guardrails_validation( + output_text, + output_guardrails, + job_id, + project_id, + organization_id, + suppress_pass_logs=True, + ) logger.info( - f"[execute_job] Starting LLM job execution | job_id={job_id}, task_id={task_id}, " + f"[apply_output_guardrails] Validation result | success={safe['success']}, job_id={job_id}" ) - try: - with Session(engine) as session: - # Update job status to PROCESSING - job_crud = JobCrud(session=session) - logger.info(f"[execute_job] Attempting to fetch job | job_id={job_id}") - job = session.get(Job, job_id) - if not job: - # Log all jobs to see what's in the database - from sqlmodel import select - - all_jobs = session.exec( - select(Job).order_by(Job.created_at.desc()).limit(5) - ).all() - logger.error( - f"[execute_job] Job not found! 
| job_id={job_id} | " - f"Recent jobs in DB: {[(j.id, j.status) for j in all_jobs]}" - ) - else: - logger.info( - f"[execute_job] Found job | job_id={job_id}, status={job.status}" - ) + if safe.get("bypassed"): + logger.info( + f"[apply_output_guardrails] Guardrails bypassed (service unavailable) | job_id={job_id}" + ) + return result, None - job_crud.update( - job_id=job_id, job_update=JobUpdate(status=JobStatus.PROCESSING) - ) + if safe["success"]: + result.response.response.output.content.value = safe["data"]["safe_text"] + return result, None - # if stored config, fetch blob from DB + return result, safe["error"] + + +def execute_llm_call( + *, + config: LLMCallConfig, + query: QueryParams, + job_id: UUID, + project_id: int, + organization_id: int, + request_metadata: dict | None, + langfuse_credentials: dict | None, + include_provider_raw_response: bool = False, + chain_id: UUID | None = None, +) -> BlockResult: + """Execute a single LLM call. Shared by /llm/call and /llm/chain. + + Returns BlockResult with response + usage on success, or error on failure. + """ + + config_blob: ConfigBlob | None = None + llm_call_id: UUID | None = None + + try: + with Session(engine) as session: if config.is_stored_config: config_crud = ConfigVersionCrud( session=session, project_id=project_id, config_id=config.id ) - - # blob is dynamic, need to resolve to ConfigBlob format config_blob, error = resolve_config_blob(config_crud, config) - if error: - callback_response = APIResponse.failure_response( - error=error, - metadata=request.request_metadata, - ) - return handle_job_error( - job_id, request.callback_url, callback_response - ) - + return BlockResult(error=error) else: config_blob = config.blob - if config_blob is not None: - if config_blob.input_guardrails or config_blob.output_guardrails: - input_guardrails, output_guardrails = list_validators_config( - organization_id=organization_id, - project_id=project_id, - input_validator_configs=config_blob.input_guardrails, - output_validator_configs=config_blob.output_guardrails, - ) + if config_blob.prompt_template and isinstance(query.input, TextInput): + template = config_blob.prompt_template.template + interpolated = template.replace("{{input}}", query.input.content.value) + query.input.content.value = interpolated - if input_guardrails: - if not isinstance(request.query.input, TextInput): - logger.info( - "[execute_job] Skipping input guardrails for non-text input. " - f"job_id={job_id}, input_type={getattr(request.query.input, 'type', type(request.query.input).__name__)}" - ) - else: - safe_input = run_guardrails_validation( - request.query.input.content.value, - input_guardrails, - job_id, - project_id, - organization_id, - suppress_pass_logs=True, - ) - - logger.info( - f"[execute_job] Input guardrail validation | success={safe_input['success']}." 
- ) - - if safe_input.get("bypassed"): - logger.info( - "[execute_job] Guardrails bypassed (service unavailable)" - ) - - elif safe_input["success"]: - request.query.input.content.value = safe_input["data"][ - "safe_text" - ] - else: - # Update the text value with error message - request.query.input.content.value = safe_input["error"] - - callback_response = APIResponse.failure_response( - error=safe_input["error"], - metadata=request.request_metadata, - ) - return handle_job_error( - job_id, request.callback_url, callback_response - ) - user_sent_config_provider = "" - - try: - # Transform Kaapi config to native config if needed (before getting provider) - completion_config = config_blob.completion - - original_provider = ( - config_blob.completion.provider - ) # openai, google or prefixed + query, input_error = apply_input_guardrails( + config_blob=config_blob, + query=query, + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + ) + if input_error: + return BlockResult(error=input_error) - if isinstance(completion_config, KaapiCompletionConfig): - completion_config, warnings = transform_kaapi_config_to_native( - completion_config - ) + completion_config = config_blob.completion + original_provider = completion_config.provider - if request.request_metadata is None: - request.request_metadata = {} - request.request_metadata.setdefault("warnings", []).extend(warnings) - else: - pass - except Exception as e: - callback_response = APIResponse.failure_response( - error=f"Error processing configuration: {str(e)}", - metadata=request.request_metadata, + if isinstance(completion_config, KaapiCompletionConfig): + completion_config, warnings = transform_kaapi_config_to_native( + completion_config ) - return handle_job_error(job_id, request.callback_url, callback_response) + if request_metadata is None: + request_metadata = {} + request_metadata.setdefault("warnings", []).extend(warnings) + + resolved_config_blob = ConfigBlob( + completion=completion_config, + prompt_template=config_blob.prompt_template, + input_guardrails=config_blob.input_guardrails, + output_guardrails=config_blob.output_guardrails, + ) - # Create LLM call record before execution try: - # Rebuild ConfigBlob with transformed native config - resolved_config_blob = ConfigBlob( - completion=completion_config, - input_guardrails=config_blob.input_guardrails, - output_guardrails=config_blob.output_guardrails, + temp_request = LLMCallRequest( + query=query, + config=config, + request_metadata=request_metadata, ) - llm_call = create_llm_call( session, - request=request, + request=temp_request, job_id=job_id, project_id=project_id, organization_id=organization_id, resolved_config=resolved_config_blob, original_provider=original_provider, + chain_id=chain_id, ) llm_call_id = llm_call.id logger.info( - f"[execute_job] Created LLM call record | llm_call_id={llm_call_id}, job_id={job_id}" + f"[execute_llm_call] Created LLM call record | " + f"llm_call_id={llm_call_id}, job_id={job_id}" ) except Exception as e: logger.error( - f"[execute_job] Failed to create LLM call record: {str(e)} | job_id={job_id}", + f"[execute_llm_call] Failed to create LLM call record: {e} | job_id={job_id}", exc_info=True, ) - callback_response = APIResponse.failure_response( - error=f"Failed to create LLM call record: {str(e)}", - metadata=request.request_metadata, - ) - return handle_job_error(job_id, request.callback_url, callback_response) + return BlockResult(error=f"Failed to create LLM call record: {str(e)}") try: provider_instance = 
get_llm_provider( session=session, - provider_type=completion_config.provider, # Now always native provider type i.e openai-native, google-native regardless + provider_type=completion_config.provider, project_id=project_id, organization_id=organization_id, ) except ValueError as ve: - callback_response = APIResponse.failure_response( - error=str(ve), - metadata=request.request_metadata, - ) - return handle_job_error(job_id, request.callback_url, callback_response) - - langfuse_credentials = get_provider_credential( - session=session, - org_id=organization_id, - project_id=project_id, - provider="langfuse", - ) + return BlockResult(error=str(ve), llm_call_id=llm_call_id) - # Extract conversation_id for langfuse session grouping conversation_id = None - if request.query.conversation and request.query.conversation.id: - conversation_id = request.query.conversation.id + if query.conversation and query.conversation.id: + conversation_id = query.conversation.id - # Resolve input (handles text, audio_base64, audio_url) - resolved_input, resolve_error = resolve_input(request.query.input) + resolved_input, resolve_error = resolve_input(query.input) if resolve_error: - callback_response = APIResponse.failure_response( - error=resolve_error, - metadata=request.request_metadata, - ) - return handle_job_error(job_id, request.callback_url, callback_response) + return BlockResult(error=resolve_error, llm_call_id=llm_call_id) - # Apply Langfuse observability decorator to provider execute method decorated_execute = observe_llm_execution( credentials=langfuse_credentials, session_id=conversation_id, @@ -364,80 +417,16 @@ def execute_job( try: response, error = decorated_execute( completion_config=completion_config, - query=request.query, + query=query, resolved_input=resolved_input, - include_provider_raw_response=request.include_provider_raw_response, + include_provider_raw_response=include_provider_raw_response, ) finally: - # Clean up temp files for audio inputs - if resolved_input and resolved_input != request.query.input: + if resolved_input and resolved_input != query.input: cleanup_temp_file(resolved_input) if response: - if output_guardrails: - if not isinstance(response.response.output, TextOutput): - logger.info( - "[execute_job] Skipping output guardrails for non-text output. " - f"job_id={job_id}, output_type={getattr(response.response.output, 'type', type(response.response.output).__name__)}" - ) - else: - output_text = response.response.output.content.value - safe_output = run_guardrails_validation( - output_text, - output_guardrails, - job_id, - project_id, - organization_id, - suppress_pass_logs=True, - ) - - logger.info( - f"[execute_job] Output guardrail validation | success={safe_output['success']}." 
- ) - - if safe_output.get("bypassed"): - logger.info( - "[execute_job] Guardrails bypassed (service unavailable)" - ) - - elif safe_output["success"]: - response.response.output.content.value = safe_output["data"][ - "safe_text" - ] - - if safe_output["data"]["rephrase_needed"] == True: - callback_response = APIResponse.failure_response( - error=request.query.input, - metadata=request.request_metadata, - ) - return handle_job_error( - job_id, request.callback_url, callback_response - ) - - else: - response.response.output.content.value = safe_output["error"] - - callback_response = APIResponse.failure_response( - error=safe_output["error"], - metadata=request.request_metadata, - ) - return handle_job_error( - job_id, request.callback_url, callback_response - ) - - callback_response = APIResponse.success_response( - data=response, metadata=request.request_metadata - ) - if request.callback_url: - send_callback( - callback_url=request.callback_url, - data=callback_response.model_dump(), - ) - with Session(engine) as session: - job_crud = JobCrud(session=session) - - # Update LLM call record with response data if llm_call_id: try: update_llm_call_response( @@ -448,34 +437,120 @@ def execute_job( usage=response.usage.model_dump(), conversation_id=response.response.conversation_id, ) - logger.info( - f"[execute_job] Updated LLM call record | llm_call_id={llm_call_id}" - ) except Exception as e: logger.error( - f"[execute_job] Failed to update LLM call record: {str(e)} | llm_call_id={llm_call_id}", + f"[execute_llm_call] Failed to update LLM call record: {e} | " + f"llm_call_id={llm_call_id}", exc_info=True, ) - # Don't fail the job if updating the record fails + result = BlockResult( + response=response, + llm_call_id=llm_call_id, + usage=response.usage, + ) + + result, output_error = apply_output_guardrails( + config_blob=config_blob, + result=result, + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + ) + if output_error: + return BlockResult(error=output_error, llm_call_id=llm_call_id) + + return result + + return BlockResult( + error=error or "Unknown error occurred", + llm_call_id=llm_call_id, + ) + + except Exception as e: + logger.error( + f"[execute_llm_call] Unexpected error: {e} | job_id={job_id}", + exc_info=True, + ) + return BlockResult( + error="Unexpected error occurred", + llm_call_id=llm_call_id, + ) + + +def execute_job( + request_data: dict, + project_id: int, + organization_id: int, + job_id: str, + task_id: str, + task_instance, +) -> dict: + """Celery task to process an LLM request asynchronously. 
+ + Returns: + dict: Serialized APIResponse[LLMCallResponse] on success, APIResponse[None] on failure + """ + request = LLMCallRequest(**request_data) + job_id: UUID = UUID(job_id) + + logger.info( + f"[execute_job] Starting LLM job execution | job_id={job_id}, task_id={task_id}" + ) + + try: + with Session(engine) as session: + job_crud = JobCrud(session=session) + job_crud.update( + job_id=job_id, job_update=JobUpdate(status=JobStatus.PROCESSING) + ) + + langfuse_credentials = get_provider_credential( + session=session, + org_id=organization_id, + project_id=project_id, + provider="langfuse", + ) + + result = execute_llm_call( + config=request.config, + query=request.query, + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + request_metadata=request.request_metadata, + langfuse_credentials=langfuse_credentials, + include_provider_raw_response=request.include_provider_raw_response, + ) - job_crud.update( + if result.success: + callback_response = APIResponse.success_response( + data=result.response, metadata=request.request_metadata + ) + if request.callback_url: + send_callback( + callback_url=request.callback_url, + data=callback_response.model_dump(), + ) + + with Session(engine) as session: + JobCrud(session=session).update( job_id=job_id, job_update=JobUpdate(status=JobStatus.SUCCESS) ) logger.info( f"[execute_job] Successfully completed LLM job | job_id={job_id}, " - f"provider_response_id={response.response.provider_response_id}, tokens={response.usage.total_tokens}" + f"tokens={result.usage.total_tokens}" ) return callback_response.model_dump() callback_response = APIResponse.failure_response( - error=error or "Unknown error occurred", + error=result.error or "Unknown error occurred", metadata=request.request_metadata, ) return handle_job_error(job_id, request.callback_url, callback_response) except Exception as e: callback_response = APIResponse.failure_response( - error=f"Unexpected error occurred", + error="Unexpected error occurred", metadata=request.request_metadata, ) logger.error( @@ -483,3 +558,112 @@ def execute_job( exc_info=True, ) return handle_job_error(job_id, request.callback_url, callback_response) + + +def execute_chain_job( + request_data: dict, + project_id: int, + organization_id: int, + job_id: str, + task_id: str, + task_instance, +) -> dict: + """Celery task to process an LLM Chain request asynchronously. 
+ + Returns: + dict: Serialized APIResponse[LLMChainResponse] on success, APIResponse[None] on failure + """ + # imports to avoid circular dependency: + from app.services.llm.chain.chain import ChainBlock, ChainContext, LLMChain + from app.services.llm.chain.executor import ChainExecutor + + request = LLMChainRequest(**request_data) + job_uuid = UUID(job_id) + chain_uuid = None + + logger.info( + f"[execute_chain_job] Starting chain execution | " + f"job_id={job_uuid}, total_blocks={len(request.blocks)}" + ) + + try: + with Session(engine) as session: + chain_record = create_llm_chain( + session, + job_id=job_uuid, + project_id=project_id, + organization_id=organization_id, + total_blocks=len(request.blocks), + input=serialize_input(request.query.input), + configs=[block.model_dump(mode="json") for block in request.blocks], + ) + chain_uuid = chain_record.id + + logger.info( + f"[execute_chain_job] Created chain record | " + f"chain_id={chain_uuid}, job_id={job_uuid}" + ) + + langfuse_credentials = get_provider_credential( + session=session, + org_id=organization_id, + project_id=project_id, + provider="langfuse", + ) + + context = ChainContext( + job_id=job_uuid, + chain_id=chain_uuid, + project_id=project_id, + organization_id=organization_id, + langfuse_credentials=langfuse_credentials, + request_metadata=request.request_metadata, + total_blocks=len(request.blocks), + callback_url=str(request.callback_url) if request.callback_url else None, + intermediate_callback_flags=[ + block.intermediate_callback for block in request.blocks + ], + ) + + blocks = [ + ChainBlock( + config=block.config, + index=i, + context=context, + include_provider_raw_response=block.include_provider_raw_response, + ) + for i, block in enumerate(request.blocks) + ] + + chain = LLMChain(blocks) + + executor = ChainExecutor(chain=chain, context=context, request=request) + return executor.run() + + except Exception as e: + logger.error( + f"[execute_chain_job] Failed: {e} | job_id={job_uuid}", + exc_info=True, + ) + + if chain_uuid: + try: + with Session(engine) as session: + update_llm_chain_status( + session, + chain_id=chain_uuid, + status=ChainStatus.FAILED, + error=str(e), + ) + except Exception: + logger.error( + f"[execute_chain_job] Failed to update chain status: {e} | " + f"chain_id={chain_uuid}", + exc_info=True, + ) + + callback_response = APIResponse.failure_response( + error="Unexpected error occurred", + metadata=request.request_metadata, + ) + return handle_job_error(job_uuid, request.callback_url, callback_response)
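For reviewers: a minimal client-side sketch of calling the endpoint this diff adds. The payload shape follows the `LLMChainRequest`, `ChainBlock`, and `QueryParams` schemas above, but the base URL, the API-key header, and the provider `params` contents are illustrative assumptions rather than part of this change; consult the generated OpenAPI schema for the exact field shapes.

```python
# Hypothetical usage sketch for POST /llm/chain -- not part of this diff.
# Base URL, auth header, and `params` contents are assumptions.
import requests

payload = {
    "query": {"input": "Summarize the quarterly report in three bullet points."},
    "blocks": [
        {
            # Mode 2: ad-hoc config; exact `params` structure depends on the provider
            "config": {
                "blob": {
                    "completion": {
                        "provider": "openai",
                        "params": {"model": "gpt-4o-mini"},  # illustrative only
                    }
                }
            },
            "intermediate_callback": True,  # send a callback when this block finishes
        },
        {
            # Mode 1: stored config (id and version are required together)
            "config": {"id": "00000000-0000-0000-0000-000000000000", "version": 1},
            "include_provider_raw_response": True,
        },
    ],
    "callback_url": "https://example.com/webhooks/llm-chain",
    "request_metadata": {"correlation_id": "req-123"},
}

resp = requests.post(
    "https://kaapi.example.com/api/v1/llm/chain",  # assumed base path
    json=payload,
    headers={"X-API-KEY": "<project-api-key>"},  # assumed auth scheme
    timeout=30,
)
# The immediate body is APIResponse[Message] acknowledging the queued job.
print(resp.json())
```

The immediate response only confirms that the chain job was queued; the final `LLMChainResponse` (and any intermediate block callbacks) are delivered asynchronously to `callback_url`.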