Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 71 additions & 7 deletions src/bub/builtin/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import re
import shlex
import time
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Collection, Coroutine, Iterable
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Collection, Coroutine, Iterable
from contextlib import AsyncExitStack
from dataclasses import dataclass, replace
from datetime import UTC, datetime
Expand All @@ -32,8 +32,8 @@
from bub.builtin.settings import ModelCandidate, load_settings
from bub.builtin.store import ForkTapeStore
from bub.builtin.tape import TapeService
from bub.errors import BubError, ErrorKind
from bub.framework import BubFramework
from bub.runtime import AsyncStreamEvents, BubError, ErrorKind, StreamEvent, StreamState
from bub.skills import discover_skills, render_skills_prompt
from bub.tape import InMemoryTapeStore, Tape
from bub.tools import (
Expand All @@ -42,7 +42,7 @@
ToolContext,
ToolExecutor,
)
from bub.types import State
from bub.types import Envelope, State
from bub.utils import workspace_from_state

CONTINUE_PROMPT = "Continue the task until all targets are completed."
Expand All @@ -55,6 +55,69 @@
TOOL_ARGUMENTS_ADAPTER = TypeAdapter(dict[str, Any])


@dataclass(frozen=True)
class StreamEvent:
kind: Literal["text", "reasoning", "tool_call", "tool_result", "usage", "error", "final"]
data: dict[str, Any]


@dataclass
class StreamState:
error: BubError | None = None
usage: dict[str, Any] | None = None


class AsyncStreamEvents:
def __init__(self, iterator: AsyncIterator[StreamEvent], *, state: StreamState | None = None) -> None:
self._iterator = iterator
self._state = state or StreamState()

def __aiter__(self) -> AsyncIterator[StreamEvent]:
return self._iterator

@property
def state(self) -> StreamState:
return self._state

@property
def error(self) -> BubError | None:
return self._state.error

@property
def usage(self) -> dict[str, Any] | None:
return self._state.usage


class BuiltinModelStream:
"""Builtin-owned stream envelope and binding."""

def __init__(self, events: AsyncStreamEvents) -> None:
self._events = events
self._output_parts: list[str] = []
self._stream_started = False

def stream(self) -> AsyncIterable[Envelope] | None:
if self._stream_started:
return None
self._stream_started = True

async def iterator() -> AsyncIterator[Envelope]:
async for event in self._events:
if event.kind == "text":
delta = str(event.data.get("delta", ""))
self._output_parts.append(delta)
yield {"content": delta, "source": event}
elif event.kind == "final":
yield {"end": True, "source": event}
else:
yield event

return iterator()

def output(self) -> Envelope | None:
return "".join(self._output_parts)


class Agent:
"""Agent that processes prompts using hooks, tools, tape, and any-llm-sdk."""

Expand All @@ -73,8 +136,8 @@ def tapes(self) -> TapeService:
return TapeService(bub.home / "tapes", tape_store, self.framework.build_tape_context())

@staticmethod
def _events_from_iterable(iterable: Iterable) -> AsyncStreamEvents:
async def generator() -> AsyncIterator:
def _events_from_iterable(iterable: Iterable[StreamEvent]) -> AsyncStreamEvents:
async def generator() -> AsyncIterator[StreamEvent]:
for item in iterable:
yield item

Expand All @@ -91,7 +154,7 @@ async def generator() -> AsyncIterator[StreamEvent]:
finally:
await callback()

return AsyncStreamEvents(generator(), state=events._state)
return AsyncStreamEvents(generator(), state=events.state)

async def run(
self,
Expand Down Expand Up @@ -467,7 +530,8 @@ async def iterator() -> AsyncGenerator[StreamEvent, None]:
)
yield StreamEvent("tool_result", {"tool_results": execution.tool_results})
yield StreamEvent(
"final", {"ok": True, "tool_calls": serialized_tool_calls, "tool_results": execution.tool_results}
"final",
{"ok": True, "tool_calls": serialized_tool_calls, "tool_results": execution.tool_results},
)
return

Expand Down
24 changes: 15 additions & 9 deletions src/bub/builtin/hook_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,16 @@
from loguru import logger

from bub import inquirer as bub_inquirer
from bub.builtin.agent import Agent
from bub.builtin.agent import Agent, BuiltinModelStream
from bub.builtin.context import default_tape_context
from bub.builtin.settings import DEFAULT_MODEL
from bub.channels.base import Channel
from bub.channels.message import ChannelMessage, MediaItem
from bub.envelope import content_of, field_of
from bub.framework import BubFramework
from bub.hookspecs import hookimpl
from bub.runtime import AsyncStreamEvents
from bub.tape import TapeContext, TapeStore
from bub.types import Envelope, MessageHandler, State
from bub.types import Envelope, EnvelopeBinding, MessageHandler, State

AGENTS_FILE_NAME = "AGENTS.md"
MODEL_PROVIDER_CHOICES: tuple[str, ...] = (
Expand Down Expand Up @@ -120,7 +119,7 @@ async def load_state(self, message: ChannelMessage, session_id: str) -> State:
return state

@hookimpl
async def save_state(self, session_id: str, state: State, message: ChannelMessage, model_output: str) -> None:
async def save_state(self, session_id: str, state: State, message: ChannelMessage, model_output: Envelope) -> None:
tp, value, traceback = sys.exc_info()
lifespan = field_of(message, "lifespan")
if lifespan is not None:
Expand Down Expand Up @@ -156,12 +155,19 @@ async def build_prompt(self, message: ChannelMessage, session_id: str, state: St
return text

@hookimpl
async def run_model(self, prompt: str | list[dict], session_id: str, state: State) -> str:
async def run_model(self, prompt: str | list[dict], session_id: str, state: State) -> Envelope:
return await self._get_agent().run(session_id=session_id, prompt=prompt, state=state)

@hookimpl
async def run_model_stream(self, prompt: str | list[dict], session_id: str, state: State) -> AsyncStreamEvents:
return await self._get_agent().run_stream(session_id=session_id, prompt=prompt, state=state)
async def run_model_stream(self, prompt: str | list[dict], session_id: str, state: State) -> Envelope:
stream = await self._get_agent().run_stream(session_id=session_id, prompt=prompt, state=state)
return BuiltinModelStream(stream)

@hookimpl
def bind_envelope(self, envelope: Envelope, session_id: str, state: State) -> EnvelopeBinding | None:
if isinstance(envelope, BuiltinModelStream):
return envelope
return None

@hookimpl
def register_cli_commands(self, app: typer.Typer) -> None:
Expand Down Expand Up @@ -273,13 +279,13 @@ def render_outbound(
message: Envelope,
session_id: str,
state: State,
model_output: str,
model_output: Envelope,
) -> list[ChannelMessage]:
outbound = ChannelMessage(
session_id=session_id,
channel=field_of(message, "channel", "default"),
chat_id=field_of(message, "chat_id", "default"),
content=model_output,
content=content_of(model_output),
output_channel=field_of(message, "output_channel", "default"),
kind=field_of(message, "kind", "normal"),
)
Expand Down
2 changes: 1 addition & 1 deletion src/bub/builtin/tape.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pydantic import BaseModel

from bub.builtin.store import ForkTapeStore
from bub.runtime import BubError
from bub.errors import BubError
from bub.tape import AsyncTapeStore, Tape, TapeContext, TapeEntry, TapeQuery, build_messages


Expand Down
4 changes: 2 additions & 2 deletions src/bub/channels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import ClassVar

from bub.channels.message import ChannelMessage
from bub.runtime import StreamEvent
from bub.types import Envelope


class Channel(ABC):
Expand Down Expand Up @@ -35,7 +35,7 @@ async def send(self, message: ChannelMessage) -> None:
# Do nothing by default
return

def stream_events(self, message: ChannelMessage, stream: AsyncIterable[StreamEvent]) -> AsyncIterable[StreamEvent]:
def stream_events(self, message: ChannelMessage, stream: AsyncIterable[Envelope]) -> AsyncIterable[Envelope]:
"""Optionally wrap the output stream for this channel."""
return stream

Expand Down
31 changes: 16 additions & 15 deletions src/bub/channels/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@
from bub.channels.base import Interface
from bub.channels.cli.renderer import CliRenderer
from bub.channels.message import ChannelMessage
from bub.envelope import field_of
from bub.runtime import StreamEvent
from bub.envelope import content_of, field_of
from bub.tools import REGISTRY
from bub.types import MessageHandler
from bub.types import Envelope, MessageHandler


class _StreamPrinter:
Expand All @@ -39,16 +38,20 @@ def __init__(self, *, console, print_head: Callable[[], None], expand_thinking:
self._reasoning_status: Status | None = None
self.head_printed = False

def render(self, event: StreamEvent) -> bool:
if event.kind == "reasoning":
self._record_reasoning(str(event.data.get("delta", "")))
def render(self, event: Envelope) -> bool:
kind = field_of(event, "kind")
data = field_of(event, "data", {})
data = data if isinstance(data, dict) else {}
if kind == "reasoning":
self._record_reasoning(str(data.get("delta", "")))
return True

if event.kind == "text":
return self._print_content(str(event.data.get("delta", "")))
elif event.kind == "tool_call":
content = content_of(event)
if content:
return self._print_content(content)
if kind == "tool_call":
self._print_stream_boundary()
elif event.kind == "final":
elif kind == "final" or field_of(event, "end"):
self._print_end()
return True

Expand Down Expand Up @@ -236,18 +239,16 @@ def _prompt_message(self) -> FormattedText:
symbol = ">" if self._mode == "agent" else ","
return FormattedText([("bold", f"{cwd} {symbol} ")])

async def stream_events(
self, message: ChannelMessage, stream: AsyncIterable[StreamEvent]
) -> AsyncIterable[StreamEvent]:
async def stream_events(self, message: ChannelMessage, stream: AsyncIterable[Envelope]) -> AsyncIterable[Envelope]:
console = get_console()
printer = _StreamPrinter(
console=console,
print_head=lambda: self._renderer.print_head(message.kind),
expand_thinking=self._expand_thinking,
)
async for event in stream:
if printer.render(event):
yield event
printer.render(event)
yield event

def _build_prompt(self, workspace: Path) -> PromptSession[str]:
kb = KeyBindings()
Expand Down
3 changes: 1 addition & 2 deletions src/bub/channels/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from bub.configure import Settings, ensure_config
from bub.envelope import content_of, field_of
from bub.framework import BubFramework
from bub.runtime import StreamEvent
from bub.turn_admission import AdmitDecision, SessionTurnController
from bub.types import Envelope, MessageHandler
from bub.utils import wait_until_stopped
Expand Down Expand Up @@ -105,7 +104,7 @@ async def dispatch_output(self, message: Envelope) -> bool:
await channel.send(outbound)
return True

def wrap_stream(self, message: Envelope, stream: AsyncIterable[StreamEvent]) -> AsyncIterable[StreamEvent]:
def wrap_stream(self, message: Envelope, stream: AsyncIterable[Envelope]) -> AsyncIterable[Envelope]:
channel_name = field_of(message, "output_channel", field_of(message, "channel"))
if channel_name is None:
return stream
Expand Down
2 changes: 2 additions & 0 deletions src/bub/envelope.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def field_of(message: Envelope, key: str, default: Any = None) -> Any:
def content_of(message: Envelope) -> str:
"""Get textual content from any envelope shape."""

if isinstance(message, str):
return message
return str(field_of(message, "content", ""))


Expand Down
40 changes: 40 additions & 0 deletions src/bub/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Shared Bub error types."""

from __future__ import annotations

from dataclasses import dataclass
from enum import StrEnum
from typing import Any


class ErrorKind(StrEnum):
"""Stable error kinds for Bub failures."""

INVALID_INPUT = "invalid_input"
CONFIG = "config"
PROVIDER = "provider"
TOOL = "tool"
TEMPORARY = "temporary"
NOT_FOUND = "not_found"
UNKNOWN = "unknown"


@dataclass(frozen=True)
class BubError(Exception):
"""Public error type for Bub failures."""

kind: ErrorKind
message: str
details: dict[str, Any] | None = None

def __str__(self) -> str:
return f"[{self.kind.value}] {self.message}"

def as_dict(self) -> dict[str, Any]:
payload: dict[str, Any] = {
"kind": self.kind.value,
"message": self.message,
}
if self.details:
payload["details"] = self.details
return payload
Loading
Loading