@@ -98,7 +98,12 @@ def __init__(
9898 * ,
9999 greeting : Optional [str | PathType ] = None ,
100100 client : Optional [str | chatlas .Chat ] = None ,
101- tools : TOOL_GROUPS | tuple [TOOL_GROUPS , ...] | None = ("update" , "query" , "visualize_dashboard" , "visualize_query" ),
101+ tools : TOOL_GROUPS | tuple [TOOL_GROUPS , ...] | None = (
102+ "update" ,
103+ "query" ,
104+ "visualize_dashboard" ,
105+ "visualize_query" ,
106+ ),
102107 data_description : Optional [str | PathType ] = None ,
103108 categorical_threshold : int = 20 ,
104109 extra_instructions : Optional [str | PathType ] = None ,
@@ -114,7 +119,12 @@ def __init__(
114119 * ,
115120 greeting : Optional [str | PathType ] = None ,
116121 client : Optional [str | chatlas .Chat ] = None ,
117- tools : TOOL_GROUPS | tuple [TOOL_GROUPS , ...] | None = ("update" , "query" , "visualize_dashboard" , "visualize_query" ),
122+ tools : TOOL_GROUPS | tuple [TOOL_GROUPS , ...] | None = (
123+ "update" ,
124+ "query" ,
125+ "visualize_dashboard" ,
126+ "visualize_query" ,
127+ ),
118128 data_description : Optional [str | PathType ] = None ,
119129 categorical_threshold : int = 20 ,
120130 extra_instructions : Optional [str | PathType ] = None ,
@@ -130,7 +140,12 @@ def __init__(
130140 * ,
131141 greeting : Optional [str | PathType ] = None ,
132142 client : Optional [str | chatlas .Chat ] = None ,
133- tools : TOOL_GROUPS | tuple [TOOL_GROUPS , ...] | None = ("update" , "query" , "visualize_dashboard" , "visualize_query" ),
143+ tools : TOOL_GROUPS | tuple [TOOL_GROUPS , ...] | None = (
144+ "update" ,
145+ "query" ,
146+ "visualize_dashboard" ,
147+ "visualize_query" ,
148+ ),
134149 data_description : Optional [str | PathType ] = None ,
135150 categorical_threshold : int = 20 ,
136151 extra_instructions : Optional [str | PathType ] = None ,
@@ -146,7 +161,12 @@ def __init__(
146161 * ,
147162 greeting : Optional [str | PathType ] = None ,
148163 client : Optional [str | chatlas .Chat ] = None ,
149- tools : TOOL_GROUPS | tuple [TOOL_GROUPS , ...] | None = ("update" , "query" , "visualize_dashboard" , "visualize_query" ),
164+ tools : TOOL_GROUPS | tuple [TOOL_GROUPS , ...] | None = (
165+ "update" ,
166+ "query" ,
167+ "visualize_dashboard" ,
168+ "visualize_query" ,
169+ ),
150170 data_description : Optional [str | PathType ] = None ,
151171 categorical_threshold : int = 20 ,
152172 extra_instructions : Optional [str | PathType ] = None ,
@@ -161,7 +181,12 @@ def __init__(
161181 * ,
162182 greeting : Optional [str | PathType ] = None ,
163183 client : Optional [str | chatlas .Chat ] = None ,
164- tools : TOOL_GROUPS | tuple [TOOL_GROUPS , ...] | None = ("update" , "query" , "visualize_dashboard" , "visualize_query" ),
184+ tools : TOOL_GROUPS | tuple [TOOL_GROUPS , ...] | None = (
185+ "update" ,
186+ "query" ,
187+ "visualize_dashboard" ,
188+ "visualize_query" ,
189+ ),
165190 data_description : Optional [str | PathType ] = None ,
166191 categorical_threshold : int = 20 ,
167192 extra_instructions : Optional [str | PathType ] = None ,
@@ -213,16 +238,23 @@ def ggvis(
213238 An Altair Chart object, or None if no visualization exists.
214239
215240 """
216- from . _ggsql import vegalite_to_altair
241+ import ggsql
217242
218243 if source == "filter" :
219- chart_dict = state .get ("filter_viz_chart" )
244+ spec = state .get ("filter_viz_spec" )
245+ if spec is None :
246+ return None
247+ # Render against current filtered data
248+ df = as_narwhals (self .df (state ))
249+ return ggsql .render_altair (df , spec )
220250 else :
221- chart_dict = state .get ("query_viz_chart" )
222-
223- if chart_dict is None :
224- return None
225- return vegalite_to_altair (chart_dict )
251+ ggsql_query = state .get ("query_viz_ggsql" )
252+ if ggsql_query is None :
253+ return None
254+ # Re-execute SQL and render
255+ sql , viz_spec = ggsql .split_query (ggsql_query )
256+ df = as_narwhals (self ._data_source .execute_query (sql ))
257+ return ggsql .render_altair (df , viz_spec )
226258
227259 def ggsql (
228260 self , state : AppStateDict , source : Literal ["filter" , "query" ] = "filter"
@@ -467,9 +499,7 @@ def app_layout(ids: IDs, table_name: str, chat_ui):
467499 srcDoc = "<p>No filter visualization yet. Ask the assistant to create one.</p>" ,
468500 style = {"width" : "100%" , "height" : "400px" , "border" : "none" },
469501 ),
470- dcc .Markdown (
471- id = ids .filter_ggsql , className = "querychat-ggsql-display mt-2"
472- ),
502+ dcc .Markdown (id = ids .filter_ggsql , className = "querychat-ggsql-display mt-2" ),
473503 ],
474504 title_id = ids .filter_plot_title ,
475505 class_name = "h-100" ,
@@ -483,9 +513,7 @@ def app_layout(ids: IDs, table_name: str, chat_ui):
483513 srcDoc = "<p>No query visualization yet. Ask the assistant to create one.</p>" ,
484514 style = {"width" : "100%" , "height" : "400px" , "border" : "none" },
485515 ),
486- dcc .Markdown (
487- id = ids .query_ggsql , className = "querychat-ggsql-display mt-2"
488- ),
516+ dcc .Markdown (id = ids .query_ggsql , className = "querychat-ggsql-display mt-2" ),
489517 ],
490518 title_id = ids .query_plot_title ,
491519 class_name = "h-100" ,
@@ -545,12 +573,13 @@ def register_app_callbacks(
545573 deserialize_state : Callable [[AppStateDict ], AppState ],
546574) -> None :
547575 """Register callbacks for SQL display, data table, visualizations, and export."""
576+ import ggsql
548577 from dash .dcc .express import send_data_frame
549578
550579 import dash
551580 from dash import Input , Output , State
552581
553- from ._ggsql import vegalite_to_altair
582+ from ._ggsql import vegalite_to_html
554583
555584 @app .callback (
556585 [
@@ -585,7 +614,8 @@ def update_display(state_data: AppStateDict, reset_clicks):
585614 sql_title = state .title or "SQL Query"
586615 sql_code = f"```sql\n { state .get_display_sql ()} \n ```"
587616
588- nw_df = as_narwhals (state .get_current_data ())
617+ current_data = state .get_current_data ()
618+ nw_df = as_narwhals (current_data )
589619 nrow , ncol = nw_df .shape
590620
591621 display_df = nw_df .to_pandas ()
@@ -598,35 +628,37 @@ def update_display(state_data: AppStateDict, reset_clicks):
598628 data_info_parts .append (f"Data has { nrow } rows and { ncol } columns." )
599629 data_info = " " .join (data_info_parts )
600630
601- # Filter visualization
631+ # Filter visualization - render on demand
602632 filter_title = state .filter_viz_title or "Filter Plot"
603- filter_chart_dict = state .filter_viz_chart
604633 filter_spec = state .filter_viz_spec
605634
606- if filter_chart_dict :
607- chart = vegalite_to_altair (filter_chart_dict )
608- filter_html = chart .to_html ()
635+ if filter_spec :
636+ # Render against current filtered data
637+ chart = ggsql .render_altair (nw_df , filter_spec )
638+ filter_html = vegalite_to_html (chart .to_dict ())
609639 else :
610640 filter_html = (
611641 "<p>No filter visualization yet. Ask the assistant to create one.</p>"
612642 )
613643
614644 filter_ggsql_md = f"```sql\n { filter_spec } \n ```" if filter_spec else ""
615645
616- # Query visualization
646+ # Query visualization - render on demand
617647 query_title = state .query_viz_title or "Query Plot"
618- query_chart_dict = state .query_viz_chart
619- query_spec = state .query_viz_ggsql
620-
621- if query_chart_dict :
622- chart = vegalite_to_altair (query_chart_dict )
623- query_html = chart .to_html ()
648+ query_ggsql_str = state .query_viz_ggsql
649+
650+ if query_ggsql_str :
651+ # Re-execute SQL and render
652+ sql_part , viz_spec = ggsql .split_query (query_ggsql_str )
653+ query_df = as_narwhals (state .data_source .execute_query (sql_part ))
654+ chart = ggsql .render_altair (query_df , viz_spec )
655+ query_html = vegalite_to_html (chart .to_dict ())
624656 else :
625657 query_html = (
626658 "<p>No query visualization yet. Ask the assistant to create one.</p>"
627659 )
628660
629- query_ggsql_md = f"```sql\n { query_spec } \n ```" if query_spec else ""
661+ query_ggsql_md = f"```sql\n { query_ggsql_str } \n ```" if query_ggsql_str else ""
630662
631663 return (
632664 sql_title ,
0 commit comments