Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions pkg-py/examples/10-viz-app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from querychat import QueryChat
from querychat.data import titanic

from shiny import App, ui

# Omits "update" tool — this demo focuses on query + visualization only
qc = QueryChat(
titanic(),
"titanic",
tools=("query", "visualize_query"),
)

app_ui = ui.page_fillable(
qc.ui(),
)


def server(input, output, session):
qc.server()


app = App(app_ui, server)
22 changes: 19 additions & 3 deletions pkg-py/src/querychat/_icons.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,35 @@

from shiny import ui

ICON_NAMES = Literal["arrow-counterclockwise", "funnel-fill", "terminal-fill", "table"]
ICON_NAMES = Literal[
"arrow-counterclockwise",
"bar-chart-fill",
"chevron-down",
"download",
"funnel-fill",
"graph-up",
"terminal-fill",
"table",
]


def bs_icon(name: ICON_NAMES) -> ui.HTML:
def bs_icon(name: ICON_NAMES, cls: str = "") -> ui.HTML:
"""Get Bootstrap icon SVG by name."""
if name not in BS_ICONS:
raise ValueError(f"Unknown Bootstrap icon: {name}")
return ui.HTML(BS_ICONS[name])
svg = BS_ICONS[name]
if cls:
svg = svg.replace('class="', f'class="{cls} ', 1)
return ui.HTML(svg)


BS_ICONS = {
"arrow-counterclockwise": '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" class="bi bi-arrow-counterclockwise" style="height:1em;width:1em;fill:currentColor;vertical-align:-0.125em;" aria-hidden="true" role="img"><path fill-rule="evenodd" d="M8 3a5 5 0 1 1-4.546 2.914.5.5 0 0 0-.908-.417A6 6 0 1 0 8 2v1z"></path><path d="M8 4.466V.534a.25.25 0 0 0-.41-.192L5.23 2.308a.25.25 0 0 0 0 .384l2.36 1.966A.25.25 0 0 0 8 4.466z"></path></svg>',
"bar-chart-fill": '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" class="bi bi-bar-chart-fill" style="height:1em;width:1em;fill:currentColor;vertical-align:-0.125em;" aria-hidden="true" role="img"><path d="M1 11a1 1 0 0 1 1-1h2a1 1 0 0 1 1 1v3a1 1 0 0 1-1 1H2a1 1 0 0 1-1-1zm5-4a1 1 0 0 1 1-1h2a1 1 0 0 1 1 1v7a1 1 0 0 1-1 1H7a1 1 0 0 1-1-1zm5-5a1 1 0 0 1 1-1h2a1 1 0 0 1 1 1v12a1 1 0 0 1-1 1h-2a1 1 0 0 1-1-1z"/></svg>',
"chevron-down": '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" class="bi bi-chevron-down" style="height:1em;width:1em;fill:currentColor;vertical-align:-0.125em;" aria-hidden="true" role="img"><path fill-rule="evenodd" d="M1.646 4.646a.5.5 0 0 1 .708 0L8 10.293l5.646-5.647a.5.5 0 0 1 .708.708l-6 6a.5.5 0 0 1-.708 0l-6-6a.5.5 0 0 1 0-.708"/></svg>',
"download": '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" class="bi bi-download" style="height:1em;width:1em;fill:currentColor;vertical-align:-0.125em;" aria-hidden="true" role="img"><path d="M.5 9.9a.5.5 0 0 1 .5.5v2.5a1 1 0 0 0 1 1h12a1 1 0 0 0 1-1v-2.5a.5.5 0 0 1 1 0v2.5a2 2 0 0 1-2 2H2a2 2 0 0 1-2-2v-2.5a.5.5 0 0 1 .5-.5"/><path d="M7.646 11.854a.5.5 0 0 0 .708 0l3-3a.5.5 0 0 0-.708-.708L8.5 10.293V1.5a.5.5 0 0 0-1 0v8.793L5.354 8.146a.5.5 0 1 0-.708.708z"/></svg>',
"funnel-fill": '<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" class="bi bi-funnel-fill" viewBox="0 0 16 16"><path d="M1.5 1.5A.5.5 0 0 1 2 1h12a.5.5 0 0 1 .5.5v2a.5.5 0 0 1-.128.334L10 8.692V13.5a.5.5 0 0 1-.342.474l-3 1A.5.5 0 0 1 6 14.5V8.692L1.628 3.834A.5.5 0 0 1 1.5 3.5z"/></svg>',
"graph-up": '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" class="bi bi-graph-up" style="height:1em;width:1em;fill:currentColor;vertical-align:-0.125em;" aria-hidden="true" role="img"><path fill-rule="evenodd" d="M0 0h1v15h15v1H0zm14.817 3.113a.5.5 0 0 1 .07.704l-4.5 5.5a.5.5 0 0 1-.74.037L7.06 6.767l-3.656 5.027a.5.5 0 0 1-.808-.588l4-5.5a.5.5 0 0 1 .758-.06l2.609 2.61 4.15-5.073a.5.5 0 0 1 .704-.07"/></svg>',
"terminal-fill": '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" class="bi bi-terminal-fill " style="height:1em;width:1em;fill:currentColor;vertical-align:-0.125em;" aria-hidden="true" role="img" ><path d="M0 3a2 2 0 0 1 2-2h12a2 2 0 0 1 2 2v10a2 2 0 0 1-2 2H2a2 2 0 0 1-2-2V3zm9.5 5.5h-3a.5.5 0 0 0 0 1h3a.5.5 0 0 0 0-1zm-6.354-.354a.5.5 0 1 0 .708.708l2-2a.5.5 0 0 0 0-.708l-2-2a.5.5 0 1 0-.708.708L4.793 6.5 3.146 8.146z"></path></svg>',
"table": '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" class="bi bi-table " style="height:1em;width:1em;fill:currentColor;vertical-align:-0.125em;" aria-hidden="true" role="img" ><path d="M0 2a2 2 0 0 1 2-2h12a2 2 0 0 1 2 2v12a2 2 0 0 1-2 2H2a2 2 0 0 1-2-2V2zm15 2h-4v3h4V4zm0 4h-4v3h4V8zm0 4h-4v3h3a1 1 0 0 0 1-1v-2zm-5 3v-3H6v3h4zm-5 0v-3H1v2a1 1 0 0 0 1 1h3zm-4-4h4V8H1v3zm0-4h4V4H1v3zm5-3v3h4V4H6zm4 4H6v3h4V8z"></path></svg>',
}
40 changes: 31 additions & 9 deletions pkg-py/src/querychat/_querychat_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,24 @@
from ._shiny_module import GREETING_PROMPT
from ._system_prompt import QueryChatSystemPrompt
from ._utils import MISSING, MISSING_TYPE, is_ibis_table
from ._viz_utils import has_viz_deps, has_viz_tool
from .tools import (
UpdateDashboardData,
tool_query,
tool_reset_dashboard,
tool_update_dashboard,
tool_visualize_query,
)

if TYPE_CHECKING:
from collections.abc import Callable

from narwhals.stable.v1.typing import IntoFrame

TOOL_GROUPS = Literal["update", "query"]
from ._viz_tools import VisualizeQueryData

TOOL_GROUPS = Literal["update", "query", "visualize_query"]
DEFAULT_TOOLS: tuple[TOOL_GROUPS, ...] = ("update", "query")

class QueryChatBase(Generic[IntoFrameT]):
"""
Expand All @@ -58,7 +62,7 @@ def __init__(
*,
greeting: Optional[str | Path] = None,
client: Optional[str | chatlas.Chat] = None,
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"),
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS,
data_description: Optional[str | Path] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | Path] = None,
Expand All @@ -72,7 +76,7 @@ def __init__(
"Table name must begin with a letter and contain only letters, numbers, and underscores",
)

self.tools = normalize_tools(tools, default=("update", "query"))
self.tools = normalize_tools(tools, default=DEFAULT_TOOLS)
self.greeting = greeting.read_text() if isinstance(greeting, Path) else greeting

# Store init parameters for deferred system prompt building
Expand Down Expand Up @@ -132,18 +136,22 @@ def client(
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None | MISSING_TYPE = MISSING,
update_dashboard: Callable[[UpdateDashboardData], None] | None = None,
reset_dashboard: Callable[[], None] | None = None,
visualize_query: Callable[[VisualizeQueryData], None] | None = None,
) -> chatlas.Chat:
"""
Create a chat client with registered tools.

Parameters
----------
tools
Which tools to include: `"update"`, `"query"`, or both.
Which tools to include: `"update"`, `"query"`, `"visualize_query"`,
or a combination.
update_dashboard
Callback when update_dashboard tool succeeds.
reset_dashboard
Callback when reset_dashboard tool is invoked.
visualize_query
Callback when visualize_query tool succeeds.

Returns
-------
Expand Down Expand Up @@ -172,6 +180,10 @@ def client(
if "query" in tools:
chat.register_tool(tool_query(data_source))

if "visualize_query" in tools:
query_viz_fn = visualize_query or (lambda _: None)
chat.register_tool(tool_visualize_query(data_source, query_viz_fn))

return chat

def generate_greeting(self, *, echo: Literal["none", "output"] = "none") -> str:
Expand Down Expand Up @@ -278,14 +290,24 @@ def normalize_client(client: str | chatlas.Chat | None) -> chatlas.Chat:
def normalize_tools(
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None | MISSING_TYPE,
default: tuple[TOOL_GROUPS, ...] | None,
*,
check_deps: bool = True,
) -> tuple[TOOL_GROUPS, ...] | None:
if tools is None or tools == ():
return None
result = None
elif isinstance(tools, MISSING_TYPE):
return default
result = default
elif isinstance(tools, str):
return (tools,)
result = (tools,)
elif isinstance(tools, tuple):
return tools
result = tools
else:
return tuple(tools)
result = tuple(tools)
if not check_deps:
return result
if has_viz_tool(result) and not has_viz_deps():
raise ImportError(
"Visualization tools require ggsql, altair, and shinywidgets. "
"Install them with: pip install querychat[viz]"
)
return result
28 changes: 17 additions & 11 deletions pkg-py/src/querychat/_shiny.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui

from ._icons import bs_icon
from ._querychat_base import TOOL_GROUPS, QueryChatBase
from ._querychat_base import DEFAULT_TOOLS, TOOL_GROUPS, QueryChatBase
from ._shiny_module import ServerValues, mod_server, mod_ui
from ._utils import as_narwhals
from ._viz_utils import has_viz_tool

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -97,10 +98,11 @@ class QueryChat(QueryChatBase[IntoFrameT]):
tools
Which querychat tools to include in the chat client by default. Can be:
- A single tool string: `"update"` or `"query"`
- A tuple of tools: `("update", "query")`
- A tuple of tools: `("update", "query", "visualize_query")`
- `None` or `()` to disable all tools

Default is `("update", "query")` (both tools enabled).
Default is `("update", "query")`. The visualization tool (`"visualize_query"`)
can be opted into by including it in the tuple.

Set to `"update"` to prevent the LLM from accessing data values, only
allowing dashboard filtering without answering questions.
Expand Down Expand Up @@ -156,7 +158,7 @@ def __init__(
id: Optional[str] = None,
greeting: Optional[str | Path] = None,
client: Optional[str | chatlas.Chat] = None,
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"),
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS,
data_description: Optional[str | Path] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | Path] = None,
Expand All @@ -172,7 +174,7 @@ def __init__(
id: Optional[str] = None,
greeting: Optional[str | Path] = None,
client: Optional[str | chatlas.Chat] = None,
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"),
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS,
data_description: Optional[str | Path] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | Path] = None,
Expand All @@ -188,7 +190,7 @@ def __init__(
id: Optional[str] = None,
greeting: Optional[str | Path] = None,
client: Optional[str | chatlas.Chat] = None,
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"),
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS,
data_description: Optional[str | Path] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | Path] = None,
Expand All @@ -204,7 +206,7 @@ def __init__(
id: Optional[str] = None,
greeting: Optional[str | Path] = None,
client: Optional[str | chatlas.Chat] = None,
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"),
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS,
data_description: Optional[str | Path] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | Path] = None,
Expand All @@ -219,7 +221,7 @@ def __init__(
id: Optional[str] = None,
greeting: Optional[str | Path] = None,
client: Optional[str | chatlas.Chat] = None,
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"),
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS,
data_description: Optional[str | Path] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | Path] = None,
Expand All @@ -245,7 +247,7 @@ def app(
"""
Quickly chat with a dataset.

Creates a Shiny app with a chat sidebar and data table view -- providing a
Creates a Shiny app with a chat sidebar and data view -- providing a
quick-and-easy way to start chatting with your data.

Parameters
Expand Down Expand Up @@ -301,6 +303,7 @@ def app_server(input: Inputs, output: Outputs, session: Session):
greeting=self.greeting,
client=self._client,
enable_bookmarking=enable_bookmarking,
tools=self.tools,
)

@render.text
Expand Down Expand Up @@ -399,7 +402,7 @@ def ui(self, *, id: Optional[str] = None, **kwargs):
A UI component.

"""
return mod_ui(id or self.id, **kwargs)
return mod_ui(id or self.id, preload_viz=has_viz_tool(self.tools), **kwargs)

def server(
self,
Expand Down Expand Up @@ -493,6 +496,7 @@ def title():
greeting=self.greeting,
client=self.client,
enable_bookmarking=enable_bookmarking,
tools=self.tools,
)


Expand Down Expand Up @@ -730,6 +734,7 @@ def __init__(
greeting=self.greeting,
client=self._client,
enable_bookmarking=enable,
tools=self.tools,
)

def sidebar(
Expand Down Expand Up @@ -791,7 +796,7 @@ def ui(self, *, id: Optional[str] = None, **kwargs):
A UI component.

"""
return mod_ui(id or self.id, **kwargs)
return mod_ui(id or self.id, preload_viz=has_viz_tool(self.tools), **kwargs)

def df(self) -> IntoFrameT:
"""
Expand Down Expand Up @@ -870,3 +875,4 @@ def title(self, value: Optional[str] = None) -> str | None | bool:
return self._vals.title()
else:
return self._vals.title.set(value)

Loading
Loading