diff --git a/src/flyquery/core/agents/column_name_proposer_agent.py b/src/flyquery/core/agents/column_name_proposer_agent.py index 3eed345..800e259 100644 --- a/src/flyquery/core/agents/column_name_proposer_agent.py +++ b/src/flyquery/core/agents/column_name_proposer_agent.py @@ -116,6 +116,9 @@ def build_column_name_proposer_agent(settings): output_type=ProposedColumnNames, instructions=prompt.instructions, settings=settings, + # Deterministic naming: identical re-ingests must yield identical + # column names, otherwise reconcile sees phantom schema churn. + extra_settings={"temperature": 0.0}, ) diff --git a/src/flyquery/core/agents/describe_agent.py b/src/flyquery/core/agents/describe_agent.py index ac6870e..0adf064 100644 --- a/src/flyquery/core/agents/describe_agent.py +++ b/src/flyquery/core/agents/describe_agent.py @@ -140,4 +140,7 @@ def build_describe_agent(settings): output_type=DescribedObjects, instructions=prompt.instructions, settings=settings, + # Deterministic descriptions/semantic types: identical columns on + # re-ingest must produce identical metadata, no schema-change churn. + extra_settings={"temperature": 0.0}, ) diff --git a/src/flyquery/core/agents/rename_detection_agent.py b/src/flyquery/core/agents/rename_detection_agent.py index 2c1c4a9..b17791d 100644 --- a/src/flyquery/core/agents/rename_detection_agent.py +++ b/src/flyquery/core/agents/rename_detection_agent.py @@ -62,4 +62,7 @@ def build_rename_detection_agent(settings): settings=settings, # Rename detection is a short task; cap output tokens tightly max_output_tokens=2048, + # Deterministic: the same removed/candidate pair must always + # resolve the same way so re-ingests don't flip-flop renames. + extra_settings={"temperature": 0.0}, ) diff --git a/src/flyquery/core/services/examples/auto_learner.py b/src/flyquery/core/services/examples/auto_learner.py index 48101da..1219058 100644 --- a/src/flyquery/core/services/examples/auto_learner.py +++ b/src/flyquery/core/services/examples/auto_learner.py @@ -28,6 +28,7 @@ class AutoLearner: Skips when: - ``retries > 0`` (query required critic refinement) - PII findings were detected in the result + - the query returned no rows (``row_count`` is 0, when provided) Called by QueryService (Phase D) after a successful execution. """ @@ -46,6 +47,7 @@ async def maybe_propose( retries: int, pii_findings: list[Any], query_id: uuid.UUID, + row_count: int | None = None, ) -> None: """Insert a flyquery_examples row when all criteria pass. @@ -57,11 +59,17 @@ async def maybe_propose( :param retries: number of critic refinement loops (must be 0 to propose) :param pii_findings: any PII signals detected (must be empty to propose) :param query_id: UUID of the parent query record + :param row_count: number of rows the query returned; when provided it + must be > 0 to propose (a valid-but-wrong query returning 0 rows + would otherwise poison grounding). When ``None`` the row gate is + skipped to preserve behaviour for callers that do not pass it. """ if retries > 0: return if pii_findings: return + if row_count is not None and row_count <= 0: + return await self._service.create( tenant_id, workspace_id, diff --git a/src/flyquery/core/services/execution/ast_classifier.py b/src/flyquery/core/services/execution/ast_classifier.py index a5f9cd7..5402588 100644 --- a/src/flyquery/core/services/execution/ast_classifier.py +++ b/src/flyquery/core/services/execution/ast_classifier.py @@ -76,8 +76,23 @@ def classify(self, sql: str) -> AstClassification: # pyright does not unify with the public ``Expression`` base below. kind = self._kind(stmt) # pyright: ignore[reportArgumentType] - # Collect table refs — skip anonymous subquery aliases - tables = tuple(sorted({t.name for t in stmt.find_all(sqlglot.expressions.Table) if t.name})) + # Collect table refs — skip anonymous subquery aliases AND + # CTE-defined names. sqlglot represents a reference to a CTE + # (``FROM base`` where ``WITH base AS (...)``) as an ``exp.Table`` + # node, so without this filter the CTE alias leaks into + # ``table_refs``; the downstream bad-tables guard then flags it + # as a non-existent table and the (otherwise valid) query is + # rejected — see QueryService bad-tables set-difference. + cte_names = {cte.alias_or_name for cte in stmt.find_all(sqlglot.expressions.CTE) if cte.alias_or_name} + tables = tuple( + sorted( + { + t.name + for t in stmt.find_all(sqlglot.expressions.Table) + if t.name and t.name not in cte_names + } + ) + ) columns = tuple(sorted({c.name for c in stmt.find_all(sqlglot.expressions.Column) if c.name})) has_subquery = bool(list(stmt.find_all(sqlglot.expressions.Subquery))) diff --git a/src/flyquery/core/services/execution/table_resolver.py b/src/flyquery/core/services/execution/table_resolver.py index 3bcd6b5..8650ba5 100644 --- a/src/flyquery/core/services/execution/table_resolver.py +++ b/src/flyquery/core/services/execution/table_resolver.py @@ -32,6 +32,7 @@ from __future__ import annotations +import json import uuid import sqlalchemy as sa @@ -54,6 +55,7 @@ async def resolve( dataset_id: uuid.UUID, table_names: list[str], object_store_base: str | None = None, + pins: dict[str, str] | None = None, ) -> dict[str, str]: """Return a mapping of table name → absolute parquet path. @@ -65,6 +67,11 @@ async def resolve( :param dataset_id: dataset to scope the lookup :param table_names: unqualified table names from the AST :param object_store_base: override for ``settings.object_store_base`` + :param pins: optional ``{table_name: snapshot_id}`` — a follow-up + drill-down turn pins each table to the snapshot it resolved to + on the first turn, so a mid-conversation re-ingest does not + silently switch the answer to a newer schema. Unpinned tables + fall back to ``current_snapshot_id``. :return: ``{name: path}`` dict for all resolvable tables """ if not table_names: @@ -77,12 +84,16 @@ async def resolve( SELECT t.name, ss.parquet_object_key FROM flyquery_tables t JOIN flyquery_schema_snapshots ss - ON ss.id = t.current_snapshot_id + ON ss.table_id = t.id + AND ss.id = COALESCE( + (CAST(:pins AS jsonb) ->> t.name)::uuid, + t.current_snapshot_id + ) WHERE t.dataset_id = :ds AND t.name = ANY(:names) AND t.is_active = true """), - {"ds": dataset_id, "names": list(table_names)}, + {"ds": dataset_id, "names": list(table_names), "pins": json.dumps(pins or {})}, ) out: dict[str, str] = {} @@ -90,3 +101,39 @@ async def resolve( key: str = r["parquet_object_key"] out[r["name"]] = f"{base}/{key}" return out + + async def table_kinds_by_name(self, dataset_id: uuid.UUID, table_names: list[str]) -> dict[str, str]: + """Return ``{name: kind}`` for the active tables in the dataset. + + Used by the firewall/bad-tables guard. Lives here (service layer) + rather than in a controller so the raw SQL stays out of the web tier. + """ + if not table_names: + return {} + rows = await self._session.execute( + sa.text(""" + SELECT name, kind FROM flyquery_tables + WHERE dataset_id = :ds AND name = ANY(:names) AND is_active = true + """), + {"ds": dataset_id, "names": list(table_names)}, + ) + return {r["name"]: r["kind"] for r in rows.mappings()} + + async def current_snapshots(self, dataset_id: uuid.UUID, table_names: list[str]) -> dict[str, str]: + """Return ``{table_name: current_snapshot_id}`` for the given tables. + + Used to record THIS turn's snapshot pins so a later drill-down turn + can reproduce the exact schema version it answered against. + """ + if not table_names: + return {} + rows = await self._session.execute( + sa.text(""" + SELECT name, current_snapshot_id + FROM flyquery_tables + WHERE dataset_id = :ds AND name = ANY(:names) AND is_active = true + AND current_snapshot_id IS NOT NULL + """), + {"ds": dataset_id, "names": list(table_names)}, + ) + return {r["name"]: str(r["current_snapshot_id"]) for r in rows.mappings()} diff --git a/src/flyquery/core/services/ingestion/readers/excel_reader.py b/src/flyquery/core/services/ingestion/readers/excel_reader.py index 2423222..d2fca8f 100644 --- a/src/flyquery/core/services/ingestion/readers/excel_reader.py +++ b/src/flyquery/core/services/ingestion/readers/excel_reader.py @@ -55,6 +55,7 @@ from __future__ import annotations import asyncio +import datetime import re import tempfile from pathlib import Path @@ -68,11 +69,126 @@ ) _NAME_SAFE_RE = re.compile(r"[^A-Za-z0-9_]+") -_SECTION_RE = re.compile(r"^(?P.+)#section\[(?P\d+):(?P\d+)\]$") +# ``#section[
:]`` (header row + data on the next rows, contiguous) +# OR ``#section[
::]`` when the header row is NOT +# contiguous with the data (an inherited period header -- see below). +_SECTION_RE = re.compile(r"^(?P.+)#section\[(?P\d+):(?P\d+)(?::(?P\d+))?\]$") +# A period *value* must be the WHOLE cell (a bare year, an ISO date with an +# optional time, or an FY/H/Q-prefixed year) -- NOT merely a string that +# happens to contain a 4-digit year (which would match invoice/reference codes +# like ``INV-2024-0007`` or ``Form 2020`` and wrongly flag a row as a header). +_PERIOD_RE = re.compile( + r"^(?:fy|h[12]|q[1-4])?[\s/-]?(?:19|20)\d{2}(?:-\d{2}-\d{2})?(?:[ t]\d{2}:\d{2}(?::\d{2})?)?$", + re.IGNORECASE, +) # Heuristic constants _MIN_HEADER_NON_EMPTY = 2 # a row needs >= 2 non-empty cells to be a header _SECTION_BREAK_EMPTY_ROWS = 2 # >= 2 consecutive empty rows close a section +# When the first header candidate looks like a numeric/spacer row, look this +# many rows ahead for a clearly-better string-label row before settling. +_HEADER_LOOKAHEAD = 3 +# Max rows a data-first section may look BACK to inherit a period/date header +# (financial reports repeat a date header above each block of sub-sections). +_PERIOD_HEADER_MAX_DISTANCE = 60 + + +def _normalise_leading_blanks(rows: list[list[Any]]) -> list[list[Any]]: + """Drop leading rows that are *entirely* empty. + + calamine's ``to_python(skip_empty_area=True)`` trims the used-range + bounding box, but the exact number of leading blank rows it keeps can + differ between near-identical re-ingests (HDR-UNSTABLE) -- which flips + every section's absolute row index and causes column-name churn. We + normalise here by consistently removing fully-empty leading rows so the + same logical sheet yields the same header index. This MUST be applied + identically in ``_enumerate_sync`` (where indices are computed) and + ``_materialise_sync`` (where they are sliced) for the indices to line up. + """ + start = 0 + n = len(rows) + while start < n and all(c in ("", None) for c in rows[start]): + start += 1 + # Avoid a needless copy when nothing was trimmed. + return rows if start == 0 else rows[start:] + + +def _looks_like_label_row(row: list[Any]) -> bool: + """True when a row's populated cells are predominantly non-empty STRINGS. + + Header rows hold labels (text); spacer/index rows hold bare numbers + (1, 2, 3, ...). We require a strict string majority so genuinely numeric, + period, or date headers are not misclassified as data. + """ + non_empty = [c for c in row if c not in ("", None)] + if not non_empty: + return False + str_cells = sum(1 for c in non_empty if isinstance(c, str) and c.strip() != "") + return str_cells * 2 > len(non_empty) + + +def _is_numeric_spacer_row(row: list[Any]) -> bool: + """True when a header candidate looks like a numeric spacer, not labels. + + A row is a spacer when every populated cell is numeric (int/float, or a + numeric-looking string) and none is a real text label. This includes the + classic contiguous ``1..k`` column-numbering run Excel exports sometimes + inject above the real header band. A row with any genuine string label is + never a spacer. + """ + non_empty = [c for c in row if c not in ("", None)] + if len(non_empty) < _MIN_HEADER_NON_EMPTY: + return False + + def _as_number(cell: Any) -> float | None: + if isinstance(cell, bool): + return None + if isinstance(cell, (int, float)): + return float(cell) + if isinstance(cell, str): + try: + return float(cell.strip()) + except ValueError: + return None + return None + + numbers = [_as_number(c) for c in non_empty] + # A spacer row has EVERY populated cell numeric and no text label. This + # covers both the all-numeric case and, as a strict subset, the contiguous + # ``1..k`` column-numbering run -- both are spacers, never label rows. Any + # non-numeric (None) populated cell means it is not a spacer. + return all(num is not None for num in numbers) + + +def _is_period_value(cell: Any) -> bool: + """True when a cell IS an accounting period / date / year label. + + Full-match (not substring) so reference/invoice codes that merely embed a + year (``INV-2024-0007``, ``Form 2020``) are not misread as period headers. + """ + if isinstance(cell, (datetime.date, datetime.datetime)): + return True + if isinstance(cell, str): + return bool(_PERIOD_RE.match(cell.strip())) + return False + + +def _is_period_header_row(row: list[Any]) -> bool: + """True when a row is a period/date header (the bulk of its cells are years/dates). + + Financial-report exports place a single date header (``2024-12-31 ...``) + above a block of sub-sections; we use this to let a following data-first + section inherit those column labels instead of guessing names. + """ + non_empty = [c for c in row if c not in ("", None)] + if len(non_empty) < _MIN_HEADER_NON_EMPTY: + return False + period = sum(1 for c in non_empty if _is_period_value(c)) + return period >= 2 and period * 5 >= len(non_empty) * 3 # >= 60% period-like + + +def _populated_cols(row: list[Any]) -> set[int]: + return {c for c, v in enumerate(row) if v not in ("", None)} class ExcelReader: @@ -130,6 +246,7 @@ def _extract_sections(rows: list[list[Any]]) -> list[dict[str, Any]]: sections: list[dict[str, Any]] = [] i = 0 pending_label: str | None = None + last_period_header_idx: int | None = None n = len(rows) while i < n: row = rows[i] @@ -146,9 +263,70 @@ def _extract_sections(rows: list[list[Any]]) -> list[dict[str, Any]]: i += 1 continue - # Multi-cell row = section header - header_idx = i - j = i + 1 + # Multi-cell row = section header candidate. + # + # A naive reader takes the FIRST >=2-non-empty row as the header. + # But dashboard exports sometimes inject a numeric spacer / column- + # numbering row (e.g. ``1 2 3 4``) just above the real label band + # (HDR-MULTIROW). If we treat that spacer as the header, the real + # labels become data and columns get opaque positional names. So + # when this candidate looks like a numeric/contiguous spacer, peek + # a small window ahead and prefer the first following row that is + # clearly a string-label row. We only skip the candidate when such + # a better row exists -- genuinely numeric / period / date headers + # (no string-label row just below) are left untouched, as are + # single-row sheets. + own_header_idx = i + if _is_numeric_spacer_row(row) and not _looks_like_label_row(row): + look_end = min(i + 1 + _HEADER_LOOKAHEAD, n) + for la in range(i + 1, look_end): + la_ne = [c for c in rows[la] if c not in ("", None)] + if len(la_ne) == 0: + # A blank row before any label row means the spacer is + # really the last populated row -- stop looking ahead. + break + if len(la_ne) >= _MIN_HEADER_NON_EMPTY and _looks_like_label_row(rows[la]): + # Found a better string-label header just below; treat + # the skipped numeric/title rows as pre-header. + own_header_idx = la + break + + # Period-header inheritance. Financial-report exports (Orbis/BvD, + # etc.) place ONE date header (``2024-12-31 2023-12-31 ...``) above + # a block of sub-sections (P&L, ratios, ...), each introduced by its + # own title. Section-splitting starts each sub-section at its first + # DATA row, orphaning that shared header above the title -- so the + # value columns get opaque/guessed names instead of the years. When + # a section starts directly with data (its own header row is neither + # label-like nor a period header) and a recent period header covers + # its value columns, adopt that period header as this section's + # column header and treat the section's own first row as data. + own_is_period = _is_period_header_row(rows[own_header_idx]) + own_is_label = _looks_like_label_row(rows[own_header_idx]) + header_row_idx = own_header_idx + data_start = own_header_idx + 1 + inherited = False + if ( + last_period_header_idx is not None + and not own_is_label + and not own_is_period + and own_header_idx - last_period_header_idx <= _PERIOD_HEADER_MAX_DISTANCE + ): + sec_cols = _populated_cols(rows[own_header_idx]) + ph_cols = _populated_cols(rows[last_period_header_idx]) + extra = sec_cols - ph_cols + # The period header must cover the section's value columns. The + # ONLY column it may legitimately not cover is a row-label column + # to the LEFT of the period columns -- never a trailing value + # column (that would shift the inherited year labels by one). + if len(ph_cols & sec_cols) >= _MIN_HEADER_NON_EMPTY and ( + not extra or (len(extra) == 1 and min(extra) < min(ph_cols)) + ): + header_row_idx = last_period_header_idx + data_start = own_header_idx + inherited = True + + j = data_start consecutive_empty = 0 data_end = j # exclusive while j < n: @@ -167,22 +345,30 @@ def _extract_sections(rows: list[list[Any]]) -> list[dict[str, Any]]: data_end = j + 1 j += 1 - n_data_rows = data_end - (header_idx + 1) + # Maintain the active period header. A genuine period/date header + # opens (or renews) a band that the following data-first sub-sections + # inherit; a real label-headed table CLOSES the band so a stale date + # header can't bleed into an unrelated (positionally-overlapping) + # table further down. + if own_is_period: + last_period_header_idx = own_header_idx + elif own_is_label: + last_period_header_idx = None + + n_data_rows = data_end - data_start if n_data_rows >= 1: - # Compute the union of populated column indices across - # the entire section. This is the real column count after - # compaction (we drop empty/merged-padding columns at - # materialise time). - populated_cols: set[int] = set() - for k in range(header_idx, data_end): - for col_idx, cell in enumerate(rows[k]): - if cell not in ("", None): - populated_cols.add(col_idx) + # Populated columns. For a contiguous section the header row is + # part of the table, so include it. For an INHERITED header we + # count only DATA columns, so a period the sub-section does not + # report does not become an all-NULL column. + populated_cols: set[int] = set() if inherited else set(_populated_cols(rows[header_row_idx])) + for k in range(data_start, data_end): + populated_cols |= _populated_cols(rows[k]) sections.append( { "label": pending_label or f"section_{len(sections):02d}", - "header_row_idx": header_idx, - "data_start_idx": header_idx + 1, + "header_row_idx": header_row_idx, + "data_start_idx": data_start, "data_end_idx": data_end, "n_cols": len(populated_cols), "n_data_rows": n_data_rows, @@ -208,7 +394,10 @@ def _enumerate_sync(source_path: str, rules: TableExtractionRules) -> list[Propo if allow is not None and sheet_name not in allow: continue sheet = wb.get_sheet_by_name(sheet_name) - rows = sheet.to_python(skip_empty_area=True) + # Normalise leading fully-empty rows so section indices are stable + # across re-ingests (HDR-UNSTABLE). The SAME normalisation runs in + # ``_materialise_sync`` so the stored indices slice the same rows. + rows = _normalise_leading_blanks(sheet.to_python(skip_empty_area=True)) if not rows: continue @@ -230,9 +419,7 @@ def _enumerate_sync(source_path: str, rules: TableExtractionRules) -> list[Propo out.append( ProposedTable( name=ExcelReader._sanitise(sheet_name), - sheet_or_json_path=( - f"{sheet_name}#section[{s['header_row_idx']}:{s['data_end_idx']}]" - ), + sheet_or_json_path=ExcelReader._section_path(sheet_name, s), n_columns=s["n_cols"], n_rows_estimate=s["n_data_rows"], ) @@ -252,9 +439,7 @@ def _enumerate_sync(source_path: str, rules: TableExtractionRules) -> list[Propo out.append( ProposedTable( name=final_name, - sheet_or_json_path=( - f"{sheet_name}#section[{s['header_row_idx']}:{s['data_end_idx']}]" - ), + sheet_or_json_path=ExcelReader._section_path(sheet_name, s), n_columns=s["n_cols"], n_rows_estimate=s["n_data_rows"], ) @@ -262,16 +447,30 @@ def _enumerate_sync(source_path: str, rules: TableExtractionRules) -> list[Propo return out @staticmethod - def _parse_section_path(path: str) -> tuple[str, int | None, int | None]: - """Split ``#section[:]`` -> (sheet, start, end). + def _section_path(sheet_name: str, s: dict[str, Any]) -> str: + """Encode a section span. ``[h:e]`` when header+data are contiguous, + ``[h:ds:e]`` when the header row is inherited (not adjacent to data).""" + h, ds, e = s["header_row_idx"], s["data_start_idx"], s["data_end_idx"] + return f"{sheet_name}#section[{h}:{e}]" if ds == h + 1 else f"{sheet_name}#section[{h}:{ds}:{e}]" - For backward compat, a plain sheet name (no ``#section[...]``) - returns ``(sheet, None, None)``. + @staticmethod + def _parse_section_path(path: str) -> tuple[str, int | None, int | None, int | None]: + """Split a section path -> ``(sheet, header_idx, data_start, data_end)``. + + Accepts both ``#section[
:]`` (contiguous; data starts + at ``header+1``) and ``#section[
::]`` (an + inherited period header that is NOT adjacent to its data). For backward + compat, a plain sheet name (no ``#section[...]``) returns all ``None``. """ m = _SECTION_RE.match(path or "") if not m: - return path or "", None, None - return m.group("sheet"), int(m.group("start")), int(m.group("end")) + return path or "", None, None, None + header_idx = int(m.group("a")) + if m.group("c") is not None: + data_start, data_end = int(m.group("b")), int(m.group("c")) + else: + data_start, data_end = header_idx + 1, int(m.group("b")) + return m.group("sheet"), header_idx, data_start, data_end @staticmethod def _materialise_sync( @@ -290,16 +489,36 @@ def _materialise_sync( from python_calamine import CalamineWorkbook # pyright: ignore[reportMissingImports] Path(target_parquet_key).parent.mkdir(parents=True, exist_ok=True) - sheet_name, sec_start, sec_end = ExcelReader._parse_section_path( + sheet_name, header_idx, data_start, data_end = ExcelReader._parse_section_path( table.sheet_or_json_path or table.name ) wb = CalamineWorkbook.from_path(source_path) sheet = wb.get_sheet_by_name(sheet_name) - rows = sheet.to_python(skip_empty_area=True) - - if sec_start is not None and sec_end is not None: - # Section-encoded path -- slice precisely. - section_rows = rows[sec_start:sec_end] + # Apply the SAME leading-blank normalisation used at enumerate time so + # the stored section indices slice the intended rows (HDR-UNSTABLE). + rows = _normalise_leading_blanks(sheet.to_python(skip_empty_area=True)) + + inherited_header = False + if header_idx is not None and data_end is not None: + # Section-encoded path. The header row + the data rows, which may be + # NON-contiguous when the header was inherited from a period header + # above the section's title (financial-report layout). For the + # common contiguous case (data_start == header_idx + 1) this is + # exactly ``rows[header_idx:data_end]``. + # + # ``rows[header_idx]`` is a scalar index, so guard it: if the sheet + # changed between enumerate and materialise (re-ingest drift -- the + # raison d'être of HDR-UNSTABLE) the index can fall past the end. The + # old slice degraded silently; we raise a contextual ValueError + # instead of an opaque IndexError. + if header_idx >= len(rows): + raise ValueError( + f"sheet {sheet_name!r}: header row {header_idx} out of range " + f"(rows={len(rows)}) for table {table.name!r} " + f"(path={table.sheet_or_json_path!r}); sheet changed since enumerate?" + ) + inherited_header = data_start != header_idx + 1 + section_rows = [rows[header_idx]] + rows[data_start:data_end] else: # Legacy / no-section path -- apply the merged-cell title heuristic. body_start = 0 @@ -329,8 +548,12 @@ def _materialise_sync( # Compacting to ONLY the populated indices yields rows that # match the visual "5 yearly columns + label" view a human # sees in Excel. + # For an inherited (non-contiguous) header, compact over the DATA rows + # only -- a period the sub-section doesn't report must not survive as an + # all-NULL year column just because the shared header names it. + compact_rows = section_rows[1:] if inherited_header and len(section_rows) > 1 else section_rows populated: set[int] = set() - for r in section_rows: + for r in compact_rows: for col_idx, cell in enumerate(r): if cell not in ("", None): populated.add(col_idx) diff --git a/src/flyquery/core/services/ingestion/stages/embed.py b/src/flyquery/core/services/ingestion/stages/embed.py index cd6e65d..d3d9211 100644 --- a/src/flyquery/core/services/ingestion/stages/embed.py +++ b/src/flyquery/core/services/ingestion/stages/embed.py @@ -94,8 +94,14 @@ async def run_embed( async with session_factory() as s: result = await s.execute( sa.text( + # profile_json/sample_values_json are pulled in so the embed + # text (and thus content_tsv) covers the column's actual + # VALUES, making value-bearing columns retrievable by + # BM25/vector -- e.g. a question for "Total Revenue" finds + # the column whose distinct values include it. """ - SELECT id, qualified_name, data_type, description, synonyms_json + SELECT id, qualified_name, data_type, description, synonyms_json, + profile_json, sample_values_json FROM flyquery_schema_objects WHERE snapshot_id = :sid AND tenant_id = :tenant ORDER BY kind, qualified_name @@ -152,9 +158,61 @@ def _build_embed_text(row: dict) -> str: flat = list(synonyms.values()) if flat: parts.append("Synonyms: " + ", ".join(str(s) for s in flat)) + values = _render_values(row) + if values: + parts.append(values) return "\n".join(p for p in parts if p) +def _render_values(row: dict, *, max_chars: int = 300) -> str: + """Compact rendering of a column's actual VALUES for the embed corpus. + + Indexing the values (not just name + description) is what lets a + question like "Total Revenue" retrieve the column whose distinct set + contains that literal. + + PII safety: prefer the PII-gated ``sample_values_json`` -- the pii_tag + stage wipes it to ``[]`` when a redact/reject policy fires, so an empty + list here means "do not surface raw samples". When no gated samples are + present we fall back to ``profile_json.top_values``, the stored distinct + set for low-cardinality columns (aggregate / low-cardinality, so lower + PII risk). The result is capped to ``max_chars`` either way. + """ + seen: set[str] = set() + uniq: list[str] = [] + + # Preferred source: PII-gated samples (empty list = intentionally wiped). + samples = row.get("sample_values_json") + if isinstance(samples, list) and samples: + for v in samples: + if v is None: + continue + s = str(v) + if s not in seen: + seen.add(s) + uniq.append(s) + else: + # Fallback: stored distinct set (aggregate, low-cardinality). + prof = row.get("profile_json") + top_values = (prof or {}).get("top_values") if isinstance(prof, dict) else None + for tv in top_values or []: + v = tv.get("value") if isinstance(tv, dict) else tv + if v is None: + continue + s = str(v) + if s not in seen: + seen.add(s) + uniq.append(s) + + if not uniq: + return "" + body = " | ".join(uniq) + if len(body) > max_chars: + # Trim on a value boundary so we never emit a half-truncated literal. + body = body[:max_chars].rsplit("|", 1)[0].strip() + " …" + return f"Values: {body}" + + async def _update_object( *, object_id: uuid.UUID, diff --git a/src/flyquery/core/services/ingestion/stages/parse.py b/src/flyquery/core/services/ingestion/stages/parse.py index d54a8c2..fea6eba 100644 --- a/src/flyquery/core/services/ingestion/stages/parse.py +++ b/src/flyquery/core/services/ingestion/stages/parse.py @@ -182,6 +182,17 @@ async def _propose_meaningful_column_names( "stage=parse rename_skipped reason=no_api_key fallback_prefix=%s", section_prefix, ) + if fallback == current_names: + return mat_result + # Rewrite the physical Parquet so its column names match the + # section-prefixed fallback we record in the MaterialiseResult -- + # otherwise the persisted schema_objects names diverge from the + # Parquet header and downstream DuckDB stages hit Binder Errors. + await _rename_parquet_columns( + parquet_path=parquet_path, + current_columns=current_names, + proposed_columns=fallback, + ) return _rebuild_mat_result(mat_result, fallback) try: diff --git a/src/flyquery/core/services/ingestion/stages/profile.py b/src/flyquery/core/services/ingestion/stages/profile.py index 1c6be8b..4938650 100644 --- a/src/flyquery/core/services/ingestion/stages/profile.py +++ b/src/flyquery/core/services/ingestion/stages/profile.py @@ -18,7 +18,9 @@ - null_fraction - approx_count_distinct (via approx_count_distinct()) - min / max (numeric + temporal columns only) - - top 5 values (low-cardinality only: distinct_estimate ≤ 100) + - full distinct value set, capped at 100 (low-cardinality only: + distinct_estimate ≤ 100) -- surfaced into the NL→SQL prompts so + filter/CASE literals are copied verbatim from real values Skips the whole column if the snapshot's n_rows_actual exceeds FLYQUERY_PROFILE_ROW_THRESHOLD (default 10M rows). @@ -31,6 +33,7 @@ import asyncio import json import logging +import re import uuid from typing import Any @@ -39,6 +42,49 @@ logger = logging.getLogger(__name__) +# Generic, language-agnostic markers for pre-aggregated subtotal / rollup +# values that can coexist with detail rows in a categorical dimension. +# This is a CURATED set of common total markers (English + Spanish), NOT +# tied to any specific dataset/column name. A value matching this is flagged +# as a likely subtotal so the query prompt can avoid mixing subtotal + detail +# rows (summing across all rows double-counts; filtering to it drops detail). +# Token boundary is a string edge or a non-alphanumeric separator (space, +# underscore, hyphen, etc.) so labels like "Total_Department" are caught while +# words that merely embed a marker (e.g. "allocation", "North America") are not. +_SUBTOTAL_MARKERS = ( + r"sub[ _-]?total|grand[ _-]?total|gran[ _-]?total|totals?|totales|" + r"all|todos|todas|suma|consolidad[oa]s?|overall" +) +# A value is flagged ONLY when it is essentially a total *marker by itself* +# (the whole value is the marker), OR a machine-generated pivot label where the +# marker is joined to another token by ``_``/``-`` (e.g. ``Total_Department``, +# ``Department-Total``). This deliberately does NOT match space-separated +# natural-language line items like ``Total Revenue`` / ``Total Nexium`` -- those +# are legitimate measure values the agent must keep, not pre-aggregated rows. +_SUBTOTAL_EXACT = re.compile(rf"^(?:{_SUBTOTAL_MARKERS})$", re.IGNORECASE) +_SUBTOTAL_JOINED = re.compile( + rf"(?:^|[_-])(?:{_SUBTOTAL_MARKERS})(?=[_-])|(?<=[_-])(?:{_SUBTOTAL_MARKERS})$", + re.IGNORECASE, +) +_BLANK_PLACEHOLDERS = {"(blank)", "(empty)", "(null)", "(en blanco)", "(vacío)", "(vacio)"} + + +def _looks_like_subtotal(value: str) -> bool: + """Heuristic: True only for unambiguous total/rollup pivot labels. + + Conservative + side-effect free: matches a standalone total marker, a + separator-joined pivot label (``Total_Department``), or a blank-ish + placeholder -- but NOT space-separated natural-language values such as + ``Total Revenue``. Consumed downstream as a prompt hint only. + """ + if value is None: + return False + stripped = value.strip() + if stripped == "" or stripped.lower() in _BLANK_PLACEHOLDERS: + return True + return bool(_SUBTOTAL_EXACT.match(stripped) or _SUBTOTAL_JOINED.search(stripped)) + + # Data-type compatibility groups _NUMERIC_TYPES = frozenset( { @@ -100,6 +146,9 @@ async def run_profile( columns = await _load_columns(tenant_id, snapshot_id, session_factory) profiled = 0 + # Collect profiles first so a second pass can detect self-referencing + # hierarchy columns (manager->report) before persisting. + computed: dict[str, tuple[uuid.UUID, str, dict[str, Any]]] = {} for col in columns: col_id: uuid.UUID = col["id"] col_name: str = col["qualified_name"].rsplit(".", 1)[-1] @@ -109,8 +158,23 @@ async def run_profile( _profile_column_sync, parquet_key, col_name, data_type, n_rows_actual ) if profile is not None: - await _persist_profile(col_id, profile, tenant_id, session_factory) - profiled += 1 + computed[col_name] = (col_id, data_type, profile) + + # Self-reference detection: annotate columns whose values are mostly + # contained in a higher-cardinality "entity" column of the same table + # (e.g. a manager/owner column whose values are people from the employee + # column). Purely structural (value containment) -- no name/vocabulary + # heuristics -- so it generalises to any self-referencing hierarchy. + try: + refs = await asyncio.to_thread(_detect_self_references, parquet_key, computed, n_rows_actual) + for ref_col, entity_col in refs.items(): + computed[ref_col][2]["references_column"] = entity_col + except Exception as exc: # noqa: BLE001 -- detection is best-effort + logger.warning("stage=profile self-reference detection failed snapshot=%s err=%s", snapshot_id, exc) + + for col_id, _dt, profile in computed.values(): + await _persist_profile(col_id, profile, tenant_id, session_factory) + profiled += 1 logger.info( "stage=profile snapshot_id=%s columns_profiled=%d", @@ -120,6 +184,79 @@ async def run_profile( return {"snapshot_id": str(snapshot_id), "columns_profiled": profiled} +def _detect_self_references( + parquet_key: str, + computed: dict[str, tuple[uuid.UUID, str, dict[str, Any]]], + n_rows: int, +) -> dict[str, str]: + """Find columns that reference a higher-cardinality entity column. + + A reference (foreign-key-like, including a self-referencing org + hierarchy) is detected purely structurally: a text column ``R`` whose + distinct values are mostly a SUBSET of another, higher-cardinality text + column ``E`` in the same table. ``R`` must be on the many-to-one side + (meaningfully fewer distinct values than ``E``) so two near-duplicate + name columns are not flagged as a hierarchy. + + Returns ``{ref_col: entity_col}``. No names/keywords are inspected -- + this generalises to any dataset's manager/owner/parent columns. + """ + import duckdb + + text_cols = { + name: prof["distinct_estimate"] + for name, (_cid, dt, prof) in computed.items() + if not _is_numeric(dt) and not _is_temporal(dt) and prof.get("distinct_estimate") + } + if len(text_cols) < 2: + return {} + # Entity columns: high-cardinality name/id columns (the "one" side). + entity_cols = [c for c, d in text_cols.items() if d >= max(8, 0.2 * n_rows)] + if not entity_cols: + return {} + + refs: dict[str, str] = {} + con = duckdb.connect() + try: + for ref_col, d_ref in text_cols.items(): + if d_ref < 3: + continue + for ent_col in entity_cols: + if ent_col == ref_col or d_ref > 0.7 * text_cols[ent_col]: + continue # ref must be the many-to-one (smaller) side + s_ref = '"' + ref_col.replace('"', '""') + '"' + s_ent = '"' + ent_col.replace('"', '""') + '"' + try: + row = con.execute( + f"SELECT count(DISTINCT {s_ref}), " + f"count(DISTINCT CASE WHEN {s_ref} IN " + f"(SELECT {s_ent} FROM read_parquet(?)) THEN {s_ref} END) " + f"FROM read_parquet(?) WHERE {s_ref} IS NOT NULL AND {s_ref} <> ''", + [parquet_key, parquet_key], + ).fetchone() + except Exception: # noqa: BLE001 -- skip uncomparable columns + continue + if row and row[0] and (row[1] / row[0]) >= 0.6: + # Discriminate a genuine cross-reference (the value is a + # DIFFERENT entity than the row's own -- a manager, owner, + # parent) from a row-wise DUPLICATE of the entity column (the + # SAME entity copied, e.g. a second name column). On a + # duplicate, ref == entity on most rows; on a real reference + # they differ. Skip duplicates. + eq = con.execute( + f"SELECT avg(CASE WHEN {s_ref} = {s_ent} THEN 1.0 ELSE 0.0 END) " + f"FROM read_parquet(?) WHERE {s_ref} IS NOT NULL AND {s_ent} IS NOT NULL", + [parquet_key], + ).fetchone() + if eq and eq[0] is not None and eq[0] > 0.5: + continue # row-wise copy -> not a hierarchy reference + refs[ref_col] = ent_col + break + finally: + con.close() + return refs + + # --------------------------------------------------------------------------- # DuckDB profiling (runs in thread) # --------------------------------------------------------------------------- @@ -172,7 +309,13 @@ def _profile_column_sync( profile["min"] = str(col_min) if col_min is not None else None profile["max"] = str(col_max) if col_max is not None else None - # Top values for low-cardinality columns (distinct ≤ 100) + # Top values for low-cardinality columns (distinct ≤ 100). + # We store the FULL distinct set (capped at 100) rather than + # just the top 5: the NL→SQL grounding/generation agents copy + # WHERE/CASE literals verbatim from these values, so a + # truncated list silently breaks any filter on a value that + # fell outside the top 5 (e.g. the P&L-line members of an + # operating-profit formula, or a brand outside the 5 biggest). if distinct_estimate <= 100 and distinct_estimate > 0: top_rows = conn.execute( f"SELECT {safe_col}, count(*) AS cnt " @@ -180,11 +323,24 @@ def _profile_column_sync( f"WHERE {safe_col} IS NOT NULL " f"GROUP BY {safe_col} " f"ORDER BY cnt DESC " - f"LIMIT 5", + f"LIMIT 100", [parquet_key], ).fetchall() profile["top_values"] = [{"value": str(r[0]), "count": r[1]} for r in top_rows] + # Subtotal detection (heuristic hint, cheap + side-effect free). + # For a low-cardinality TEXT dimension, flag values whose text + # matches a generic total/aggregate marker (en/es) or a blank + # placeholder. Such values are likely pre-aggregated rollup rows + # coexisting with detail rows; the query prompt consumes this to + # avoid mixing subtotal + detail (summing double-counts; filtering + # to the subtotal drops detail). No measure join is available + # here, so this stays purely structural/textual. + if not _is_numeric(data_type) and not _is_temporal(data_type): + subtotal_values = [str(r[0]) for r in top_rows if _looks_like_subtotal(str(r[0]))] + if subtotal_values: + profile["subtotal_values"] = subtotal_values + return profile finally: diff --git a/src/flyquery/core/services/query/query_service.py b/src/flyquery/core/services/query/query_service.py index b11ecf6..ffbd8b4 100644 --- a/src/flyquery/core/services/query/query_service.py +++ b/src/flyquery/core/services/query/query_service.py @@ -128,6 +128,33 @@ def _render_grounding_prompt( out.append(f"- `{qn}` :: {text}") out.append("") + # 2b. The "Column value catalogue" lists EVERY column with its real + # values (distinct set for low-cardinality columns; numeric + # range otherwise). This is the ground truth the agent must copy + # filter/CASE literals from -- it prevents guessing wrong + # literals (`Year IN (2023)` when the values are `FY23`), maps a + # question entity to the column whose values contain it + # (`DAPA` lives in a column's values, not a column name), and + # reveals tall/EAV layouts (P&L line items are VALUES of a single + # column) and scaled-duplicate measures (`FY` vs `FY (Real)`). + inv_columns = [h for h in inventory if (getattr(h, "metadata", {}) or {}).get("kind") == "COLUMN"] + cols_with_values = [h for h in inv_columns if (getattr(h, "metadata", {}) or {}).get("values")] + if cols_with_values: + out.append(f"# Column value catalogue ({len(cols_with_values)} columns)") + out.append( + "Real values per column. When the question names an entity (a brand, " + "year, market, category, P&L line, team…) that is NOT a column name, " + "find the column whose values contain it and filter THAT column. Copy " + "filter/CASE literals VERBATIM from these values (values may be encoded, " + "e.g. a year shown as `FY23`). If several columns share members, prefer " + "the one whose values match the question most precisely." + ) + for h in cols_with_values: + md = getattr(h, "metadata", None) or {} + qn = md.get("qualified_name") or "?" + out.append(f"- `{qn}` :: {md.get('values')}") + out.append("") + examples = bundle.get("examples", []) or [] if examples: out.append(f"# Approved Q→SQL examples ({len(examples)})") @@ -214,6 +241,13 @@ def _render_generation_prompt( inv = schema_inventory or [] inv_tables = [h for h in inv if (getattr(h, "metadata", {}) or {}).get("kind") == "TABLE"] + # Map qualified_name -> value fingerprint so the generator copies + # filter/CASE literals verbatim from real values rather than guessing. + value_index: dict[str, str] = {} + for h in inv: + md = getattr(h, "metadata", None) or {} + if md.get("kind") == "COLUMN" and md.get("values"): + value_index[md.get("qualified_name")] = md.get("values") if inv_tables: out.append(f"# Complete dataset catalogue ({len(inv_tables)} tables)") out.append( @@ -243,9 +277,11 @@ def _render_generation_prompt( out.append(f"- `{getattr(t, 'table_qualified_name', t)}`") out.append("") if g_columns: - out.append("## Columns in scope") + out.append("## Columns in scope (with real values — copy literals verbatim)") for c in g_columns: - out.append(f"- `{getattr(c, 'column_qualified_name', c)}`") + cqn = getattr(c, "column_qualified_name", c) + vals = value_index.get(cqn) + out.append(f"- `{cqn}`" + (f" :: {vals}" if vals else "")) out.append("") if g_joins: out.append("## Approved joins") @@ -257,6 +293,20 @@ def _render_generation_prompt( ) out.append("") + # Full column-value catalogue -- the grounding agent may under-select + # columns, so expose every column's real values here too. This is the + # source of truth for WHERE / CASE literals. + if value_index: + out.append(f"# Column value catalogue ({len(value_index)} columns)") + out.append( + "Real values per column. Copy filter/CASE literals VERBATIM from these. " + "If the question names an entity that is not a column name, filter the " + "column whose values contain it." + ) + for qn, vals in value_index.items(): + out.append(f"- `{qn}` :: {vals}") + out.append("") + out.append("# Task") out.append( "Generate up to N candidate DuckDB SQL queries that answer the question, " @@ -267,6 +317,36 @@ def _render_generation_prompt( "rather than `FROM orbis_companies.IVI_MALAGA_SL__Activos`.\n" "- Quote any column name that isn't a plain identifier (e.g. date-shaped " 'names like `2024-12-31` must be `"2024-12-31"`).\n' + "- Copy every WHERE / CASE / IN literal VERBATIM from the column value " + "catalogue above -- never invent or reformat a value (a year is `FY23`, " + "not `2023`; a market may be `Brazil` or `44000BR Brazil` -- use exactly " + "what is listed).\n" + "- When the metric the user names is not a column but appears among a " + "column's listed values, filter that column (tall/EAV layout): e.g. P&L " + "line items like `Total Revenue` / `Manpower` are VALUES of a single " + "category column, selected with `CASE WHEN \"\" = 'Total Revenue' …`.\n" + "- When two numeric columns are near-duplicates whose ranges differ by a " + "constant factor (~10^k), they are the same measure at different scales -- " + "prefer the larger-magnitude one for monetary sums.\n" + "- Do NOT add a WHERE filter on a dimension the question did not ask to " + "slice by -- aggregate across ALL of its values, and do NOT drop a row " + "just because a category value's name contains 'total'/'all' (those are " + "usually legitimate, often 'unallocated', buckets). Exclude a value only " + "if you can confirm it is literally the sum of the other rows.\n" + "- Return what is ASKED FOR: if the question asks for names, a list, " + "'who', 'which', or 'dame los nombres/quiénes', SELECT the identifying " + "column(s) (e.g. the name) and return the matching ROWS -- do NOT collapse " + "to a COUNT. Use COUNT/aggregates only when a count or total is requested. " + "If BOTH a count and the names are asked, return the names (the count is " + "derivable from the row count).\n" + "- HIERARCHY questions: when the question asks about a person's TEAM, " + "direct reports, the people 'at their charge' / under them, their org, or " + "movements in THEIR structure, it is a self-referencing hierarchy. A " + "column marked 'HIERARCHY: holds entities/people from column X' holds each " + "row's manager/owner. The person's team = the ROWS where such a column " + "equals that person (filter it with case-insensitive LIKE '%name%'), NOT " + "the person's own row. Try EVERY hierarchy column (a person may appear in " + "more than one), and match names tolerantly (accents/spacing).\n" "- Be a SINGLE statement (no multi-statement; no DDL).\n" "- Be a SELECT (DuckDB-flavored)." ) @@ -542,7 +622,11 @@ async def answer( # into the persisted query record for reproducibility. metric_name = grounded.metrics[0].metric_name compiled, metric_version = await self._compiled_metric_sql( - metric_name, dataset_id, tenant_id=tenant_id, workspace_id=workspace_id + metric_name, + dataset_id, + tenant_id=tenant_id, + workspace_id=workspace_id, + extra_filter=getattr(grounded.metrics[0], "extra_filter", None), ) if compiled: chosen_sql = compiled @@ -577,7 +661,17 @@ async def answer( gen_run = await self._generation_agent.run(gen_prompt) gen_out = getattr(gen_run, "output", gen_run) candidates_json = [c.model_dump() for c in gen_out.candidates] - chosen_sql = gen_out.candidates[0].sql + # Don't blindly take the highest-confidence candidate: probe them + # (DuckDB only, no extra LLM) and prefer one that passes the firewall + # and returns non-empty, non-degenerate rows. Empty candidate list + # degrades to "" (FAILED) instead of raising an IndexError. + chosen_sql = await self._select_best_candidate( + [c.sql for c in gen_out.candidates if c.sql], + dataset_id=dataset_id, + scopes=scopes, + dataset_allowlist=dataset_allowlist, + pins=prior_snapshot_pins, + ) # ------------------------------------------------------------------ # 5. AST classify + scope guard @@ -666,6 +760,7 @@ async def answer( attached = await self._table_resolver.resolve( dataset_id, list(ast.table_refs), + pins=prior_snapshot_pins, ) result = await self._executor.execute(chosen_sql, attached) retries = 0 @@ -715,7 +810,9 @@ async def answer( ) retries += 1 continue - attached = await self._table_resolver.resolve(dataset_id, list(ast.table_refs)) + attached = await self._table_resolver.resolve( + dataset_id, list(ast.table_refs), pins=prior_snapshot_pins + ) result = await self._executor.execute(chosen_sql, attached) retries += 1 @@ -747,6 +844,24 @@ async def answer( # 9. Clarification frame (emitted alongside answer when confidence is low) # ------------------------------------------------------------------ clarification = self._clarification(grounded) + # A syntactically-valid query that returns 0 rows (or a single + # all-NULL/zero aggregate) is suspicious: the usual cause is a + # filter literal that doesn't match how the data is encoded. Rather + # than report it as a confident empty answer, surface a clarification + # and downgrade confidence so the caller knows to verify. + suspicious_empty = isinstance(result, ExecutionResult) and self._is_suspicious_empty(result) + if clarification is None and suspicious_empty: + from flyquery.interfaces.query import ClarificationFrame + + clarification = ClarificationFrame( + questions=[ + "The query executed successfully but returned no matching data " + "(0 rows / empty result). The filter values may not match how the " + "data is encoded -- please verify the exact column values (e.g. " + "category labels or period format) or rephrase the question." + ], + reasons=[], + ) clarification_emitted = clarification is not None # ------------------------------------------------------------------ @@ -796,7 +911,12 @@ async def answer( # ------------------------------------------------------------------ # 12. Auto-learn (only on first-shot OK + no PII + no clarification) # ------------------------------------------------------------------ - if execution_status == "OK" and isinstance(result, ExecutionResult) and not clarification_emitted: + if ( + execution_status == "OK" + and isinstance(result, ExecutionResult) + and result.row_count > 0 + and not clarification_emitted + ): await self._auto_learner.maybe_propose( tenant_id=tenant_id, workspace_id=workspace_id, @@ -811,6 +931,16 @@ async def answer( # ------------------------------------------------------------------ # 13. Persist conversation turn (Phase E drill-down) # ------------------------------------------------------------------ + # THIS turn's snapshot pins: the snapshot each resolved table was + # answered against. Tables already pinned by an earlier turn keep + # their pin (prior wins); newly-referenced tables pin to current. + # Persisting THIS turn's pins (not the prior turn's) is what makes + # drill-down reproducible across a mid-conversation re-ingest. + this_turn_pins: dict[str, str] = { + **(await self._table_resolver.current_snapshots(dataset_id, list(ast.table_refs))), + **prior_snapshot_pins, + } + if ( conversation_id is not None and self._conversation_service is not None @@ -824,7 +954,7 @@ async def answer( executed_sql=chosen_sql, summary=explanation_obj.summary if explanation_obj else None, table_qnames_json=list(ast.table_refs), - snapshot_pins_json=prior_snapshot_pins, + snapshot_pins_json=this_turn_pins, elapsed_ms=elapsed, ) @@ -839,13 +969,85 @@ async def answer( chart_hint=explanation_obj.chart_hint if explanation_obj else None, explanation=explanation_obj.summary if explanation_obj else None, clarification=clarification, - grounded_summary=self._grounded_summary(grounded), + grounded_summary=self._grounded_summary( + grounded, confidence_cap=0.4 if suspicious_empty else None + ), + snapshot_pins=this_turn_pins, ) # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ + @staticmethod + def _is_suspicious_empty(result) -> bool: + """True when an executed result is empty/degenerate enough to doubt. + + Catches the canonical wrong-literal symptom: a valid query that + matched nothing (0 rows), or a single-row single-column aggregate + whose only value is NULL / 0 / 0.0 (e.g. a SUM/CASE pivot where + every branch missed). + """ + if result.row_count == 0: + return True + rows = getattr(result, "rows", None) or [] + if result.row_count == 1 and len(rows) == 1 and isinstance(rows[0], dict) and len(rows[0]) == 1: + (only_value,) = rows[0].values() + return only_value is None or only_value == 0 + return False + + async def _select_best_candidate( + self, + candidate_sqls: list[str], + *, + dataset_id: uuid.UUID, + scopes: set[str], + dataset_allowlist: set[uuid.UUID] | None, + pins: dict[str, str], + ) -> str: + """Pick the candidate SQL that best answers the question. + + Generation emits N candidates ranked by self-reported confidence, but + the top one sometimes filters on the wrong column (or under-searches a + set of hierarchy columns) and returns 0 rows while a lower-ranked + candidate is correct. So we probe the candidates and prefer the first + that (a) passes the firewall, (b) executes, and (c) returns non-empty, + non-degenerate rows. This is DuckDB-only -- NO extra LLM calls -- and + general: it just prefers a candidate that actually returns data. + + Falls back to the first candidate that executed at all, else the first + candidate (so the existing scope/critic handling downstream is + unchanged when nothing is clearly better). + """ + if len(candidate_sqls) <= 1: + return candidate_sqls[0] if candidate_sqls else "" + + first_executed: str | None = None + for sql in candidate_sqls: + ast = self._ast_classifier.classify(sql) + table_kinds = await self._table_kinds_by_name(list(ast.table_refs), dataset_id) + dataset_of_table = await self._dataset_of_tables(list(ast.table_refs), dataset_id) + try: + self._scope_guard.check( + classification=ast, + scopes=scopes, + table_kinds_by_name=table_kinds, + dataset_allowlist=dataset_allowlist, + dataset_of_table=dataset_of_table, + ) + except ScopeGuardError: + continue # unsafe candidate -- skip + if sorted({t for t in ast.table_refs if t} - set(table_kinds.keys())): + continue # references a table not in the dataset -- skip + attached = await self._table_resolver.resolve(dataset_id, list(ast.table_refs), pins=pins) + result = await self._executor.execute(sql, attached) + if isinstance(result, ExecutionResult): + if not self._is_suspicious_empty(result): + return sql # passes firewall + returns real rows -- best + if first_executed is None: + first_executed = sql # remember first successful-but-empty + return first_executed or candidate_sqls[0] + async def _table_kinds_by_name( self, table_names: list[str], @@ -886,11 +1088,19 @@ async def _compiled_metric_sql( *, tenant_id: str, workspace_id: uuid.UUID, + extra_filter: str | None = None, ) -> tuple[str | None, int | None]: """Fetch + bind the compiled SQL for a PUBLISHED metric. Returns ``(bound_sql, current_version)`` so the version can be pinned in the query record, or ``(None, None)`` when no usable metric is found. + + ``extra_filter`` is the per-question slice the grounding agent derived + (e.g. ``Market = 'Brazil' AND Year = 'FY24'``) to be appended to the + metric's WHERE via the compiler's ``{extra_filter_clause}`` slot. It is + an LLM-supplied predicate, so it is re-run through the publish-time + firewall before binding; an unsafe filter is dropped (the metric still + returns its unfiltered value) rather than executed. """ if self._semantic_repo is None: return None, None @@ -903,9 +1113,31 @@ async def _compiled_metric_sql( return None, None if not row or not row.get("compiled_sql_template"): return None, None - bound = SemanticCompiler.bind(row["compiled_sql_template"]) + + safe_filter = self._firewall_extra_filter(row["compiled_sql_template"], extra_filter) + bound = SemanticCompiler.bind(row["compiled_sql_template"], extra_filter=safe_filter) return bound, row.get("current_version") + @staticmethod + def _firewall_extra_filter(template: str, extra_filter: str | None) -> str | None: + """Validate an LLM-supplied metric filter via the publish-time firewall. + + Returns the filter when the bound SQL passes ``assert_safe_template``, + else ``None`` (filter dropped). Defensive: any firewall/parse failure + also drops the filter rather than risking an unsafe predicate. + """ + if not extra_filter: + return None + try: + from flyquery.core.services.semantic.firewall import assert_safe_template + + probe = SemanticCompiler.bind(template, extra_filter=extra_filter) + assert_safe_template(probe) + return extra_filter + except Exception as exc: # noqa: BLE001 -- any failure → drop the filter + logger.warning("dropping unsafe semantic extra_filter %r: %s", extra_filter, exc) + return None + def _clarification(self, grounded) -> Any: """Build a ClarificationFrame if grounding confidence is low.""" from flyquery.interfaces.query import ClarificationFrame @@ -917,11 +1149,19 @@ def _clarification(self, grounded) -> Any: ) return None - def _grounded_summary(self, grounded) -> dict: - """Convert GroundedContext to a summary dict for the response.""" + def _grounded_summary(self, grounded, confidence_cap: float | None = None) -> dict: + """Convert GroundedContext to a summary dict for the response. + + ``confidence_cap`` lets the caller lower the reported confidence when + the executed result is suspicious (e.g. 0 rows from a wrong literal), + so a confidently-wrong empty answer is not surfaced at high confidence. + """ + confidence = grounded.confidence + if confidence_cap is not None: + confidence = min(confidence, confidence_cap) return { "path": grounded.path, - "confidence": grounded.confidence, + "confidence": confidence, "table_count": len(grounded.tables), "missing_info": grounded.missing_info, } diff --git a/src/flyquery/core/services/retrieval/reranker.py b/src/flyquery/core/services/retrieval/reranker.py index d3a9fa4..6c117f7 100644 --- a/src/flyquery/core/services/retrieval/reranker.py +++ b/src/flyquery/core/services/retrieval/reranker.py @@ -21,10 +21,16 @@ from __future__ import annotations +import logging from typing import Any, Protocol from flyquery.core.services.retrieval.search_index import Hit +logger = logging.getLogger(__name__) + +# Guard so the "reranking disabled" warning is emitted at most once per process. +_warned_noop_fallback = False + class Reranker(Protocol): """Protocol for a reranking step in the retrieval pipeline.""" @@ -88,10 +94,21 @@ def build_reranker(settings: Any) -> NoopReranker | CrossEncoderReranker: :param settings: ``FlyquerySettings`` instance :return: a ready-to-use reranker """ + global _warned_noop_fallback model_name = getattr(settings, "reranker_model", "") or "" if not model_name: return NoopReranker() try: return CrossEncoderReranker(model_name) - except Exception: # noqa: BLE001 + except Exception as exc: # noqa: BLE001 + if not _warned_noop_fallback: + _warned_noop_fallback = True + logger.warning( + "reranker model=%s unavailable (%s) -- falling back to NoopReranker. " + "Relevance reranking is DISABLED; results are truncated by retrieval " + "order only (install sentence-transformers / make the cross-encoder " + "model loadable to enable it).", + model_name, + exc, + ) return NoopReranker() diff --git a/src/flyquery/core/services/retrieval/search_index.py b/src/flyquery/core/services/retrieval/search_index.py index 8d10f5c..beb828b 100644 --- a/src/flyquery/core/services/retrieval/search_index.py +++ b/src/flyquery/core/services/retrieval/search_index.py @@ -42,6 +42,115 @@ class Hit: metadata: dict = field(default_factory=dict) +def value_fingerprint( + data_type: str | None, + sample_values_json: object, + profile_json: object, + *, + max_values: int = 40, + max_chars: int = 400, +) -> str: + """Compact, human-readable summary of a column's ACTUAL values. + + Surfaced into the grounding/generation prompts so the agents copy + WHERE / CASE literals verbatim from real values instead of guessing + -- e.g. a fiscal year stored as ``FY23`` (not ``2023``), the members + of a tall/EAV category column (``P&L Line System`` rows like + ``Total Revenue`` / ``Manpower``), or the magnitude gap between a + scaled-duplicate measure (``FY`` ~0.02 vs ``FY (Real)`` ~25748). + + Returns ``""`` when there is nothing useful to show. + """ + prof = profile_json if isinstance(profile_json, dict) else {} + + # NOTE: profiling stores ``subtotal_values`` (name-based candidates for + # pre-aggregated rows). We deliberately do NOT surface them as an + # exclusion directive: a "Total_*" value in a dimension is just as often a + # legitimate additive bucket (e.g. unallocated/corporate) as a true rollup, + # and any hint makes the agent wrongly drop it. Whether to exclude requires + # the structural test (does the value's aggregate == the sum of the others?) + # which is not available per-column at profile time. The general "aggregate + # across ALL values / don't drop a value" prompt rule handles this safely. + subtotal_note = "" + + # Self-referencing hierarchy hint: this column's values are entities from + # another (higher-cardinality) column -- e.g. a manager column whose values + # are people from the employee column. Surfaced even for high-cardinality + # columns that have no listable values, because that is exactly when the + # agent cannot otherwise tell who a person reports to. + ref_col = prof.get("references_column") + ref_note = ( + f" | HIERARCHY: holds entities/people from column '{ref_col}' (e.g. each row's " + f"manager/owner/parent). To get a given person's group/team/reports, filter THIS " + f"column to that person (case-insensitive LIKE), not the person's own row." + if ref_col + else "" + ) + + # Categorical: the stored distinct value set (low-cardinality columns). + top_values = prof.get("top_values") or [] + if top_values: + seen: set[str] = set() + uniq: list[str] = [] + for tv in top_values: + v = tv.get("value") if isinstance(tv, dict) else tv + if v is None: + continue + s = str(v) + if s not in seen: + seen.add(s) + uniq.append(s) + shown = uniq[:max_values] + body = " | ".join(shown) + if len(body) > max_chars: + body = body[:max_chars].rsplit("|", 1)[0].strip() + " | …" + more = "" if len(uniq) <= len(shown) else f" (+{len(uniq) - len(shown)} more)" + return f"values: {body}{more}{subtotal_note}{ref_note}" if body else ref_note.strip(" |") + + # Numeric / temporal: range + cardinality (exposes scaled duplicates). + col_min, col_max = prof.get("min"), prof.get("max") + if col_min is not None or col_max is not None: + rng = f"range: {col_min} .. {col_max}" + dist = prof.get("distinct_estimate") + if dist is not None: + rng += f" (~{dist} distinct)" + return rng + ref_note + + # Fallback: a few raw sample values (high-cardinality columns). + samples = sample_values_json if isinstance(sample_values_json, list) else [] + if samples: + seen2: set[str] = set() + uniq2: list[str] = [] + for v in samples: + s = str(v) + if s not in seen2: + seen2.add(s) + uniq2.append(s) + if uniq2: + return "e.g.: " + " | ".join(uniq2[:8]) + ref_note + return ref_note.strip(" |") + + +def _column_hit(r, score: float) -> Hit: + """Build a ranked schema-object Hit, enriched with a value fingerprint. + + Used by the BM25 + vector column searches so the ranked "Top-ranked + column matches" the grounding agent sees carry real values, not just + name + description. + """ + fp = value_fingerprint(r.data_type, r.sample_values_json, r.profile_json) + text = f"{r.qualified_name}: {r.data_type}\n{r.description or ''}" + if fp: + text += f"\n{fp}" + return Hit( + source_kind="schema_object", + id=r.id, + text=text, + score=score, + metadata={"qualified_name": r.qualified_name, "table_id": str(r.table_id), "values": fp}, + ) + + class SearchIndex: """Read-only query helpers that operate on a shared ``AsyncSession``.""" @@ -60,10 +169,12 @@ async def bm25_schema_objects(self, query: str, dataset_id: uuid.UUID, limit: in sa.text( """ SELECT o.id, o.qualified_name, o.description, o.data_type, o.table_id, + o.sample_values_json, o.profile_json, ts_rank(o.content_tsv, plainto_tsquery('english', :q)) AS score FROM flyquery_schema_objects o JOIN flyquery_tables t ON t.id = o.table_id WHERE t.dataset_id = :ds AND o.is_active = true + AND o.snapshot_id = t.current_snapshot_id AND o.content_tsv @@ plainto_tsquery('english', :q) ORDER BY score DESC LIMIT :lim @@ -71,16 +182,7 @@ async def bm25_schema_objects(self, query: str, dataset_id: uuid.UUID, limit: in ), {"q": query, "ds": dataset_id, "lim": limit}, ) - return [ - Hit( - source_kind="schema_object", - id=r.id, - text=f"{r.qualified_name}: {r.data_type}\n{r.description or ''}", - score=float(r.score), - metadata={"qualified_name": r.qualified_name, "table_id": str(r.table_id)}, - ) - for r in rows.mappings() - ] + return [_column_hit(r, float(r.score)) for r in rows.mappings()] async def vector_schema_objects( self, @@ -99,26 +201,19 @@ async def vector_schema_objects( sa.text( """ SELECT o.id, o.qualified_name, o.description, o.data_type, o.table_id, + o.sample_values_json, o.profile_json, 1 - (o.embedding <=> CAST(:emb AS vector)) AS score FROM flyquery_schema_objects o JOIN flyquery_tables t ON t.id = o.table_id - WHERE t.dataset_id = :ds AND o.is_active = true AND o.embedding IS NOT NULL + WHERE t.dataset_id = :ds AND o.is_active = true + AND o.snapshot_id = t.current_snapshot_id AND o.embedding IS NOT NULL ORDER BY o.embedding <=> CAST(:emb AS vector) LIMIT :lim """ ), {"emb": str(query_embedding), "ds": dataset_id, "lim": limit}, ) - return [ - Hit( - source_kind="schema_object", - id=r.id, - text=f"{r.qualified_name}: {r.data_type}\n{r.description or ''}", - score=float(r.score), - metadata={"qualified_name": r.qualified_name, "table_id": str(r.table_id)}, - ) - for r in rows.mappings() - ] + return [_column_hit(r, float(r.score)) for r in rows.mappings()] async def all_schema_objects( self, @@ -158,7 +253,9 @@ async def all_schema_objects( ON c.table_id = o.table_id AND c.kind = 'COLUMN' AND c.is_active = true + AND c.snapshot_id = t.current_snapshot_id WHERE t.dataset_id = :ds AND o.is_active = true AND o.kind = 'TABLE' + AND o.snapshot_id = t.current_snapshot_id GROUP BY o.id, o.qualified_name, o.description, o.table_id, o.kind ORDER BY o.qualified_name """ @@ -175,10 +272,12 @@ async def all_schema_objects( await self._session.execute( sa.text( """ - SELECT o.id, o.qualified_name, o.description, o.data_type, o.table_id, o.kind + SELECT o.id, o.qualified_name, o.description, o.data_type, o.table_id, o.kind, + o.sample_values_json, o.profile_json FROM flyquery_schema_objects o JOIN flyquery_tables t ON t.id = o.table_id WHERE t.dataset_id = :ds AND o.is_active = true AND o.kind = 'COLUMN' + AND o.snapshot_id = t.current_snapshot_id ORDER BY o.qualified_name LIMIT :lim """ @@ -223,16 +322,21 @@ async def all_schema_objects( ) for r in column_rows: + fp = value_fingerprint(r["data_type"], r["sample_values_json"], r["profile_json"]) + text = f"{r['qualified_name']}: {r['data_type'] or ''}\n{r['description'] or ''}" + if fp: + text += f"\n{fp}" hits.append( Hit( source_kind="schema_object", id=r["id"], - text=f"{r['qualified_name']}: {r['data_type'] or ''}\n{r['description'] or ''}", + text=text, score=1.0, metadata={ "qualified_name": r["qualified_name"], "table_id": str(r["table_id"]), "kind": "COLUMN", + "values": fp, }, ) ) diff --git a/src/flyquery/web/controllers/query_controller.py b/src/flyquery/web/controllers/query_controller.py index 93169d6..532cfbf 100644 --- a/src/flyquery/web/controllers/query_controller.py +++ b/src/flyquery/web/controllers/query_controller.py @@ -51,7 +51,12 @@ from flyquery.core.services.execution.scope_guard import ScopeGuard, ScopeGuardError from flyquery.core.services.execution.table_resolver import TableResolver from flyquery.core.services.query.query_repository import QueryRepository -from flyquery.core.services.query.query_service import QueryService +from flyquery.core.services.query.query_service import ( + QueryService, + _render_critic_prompt, + _render_generation_prompt, + _render_grounding_prompt, +) from flyquery.core.services.query.result_uploader import ResultUploader from flyquery.core.services.retrieval.embedder import Embedder from flyquery.core.services.retrieval.hybrid_retriever import HybridRetriever @@ -131,6 +136,54 @@ def __init__( self._scope_guard = ScopeGuard() self._executor = DuckDBExecutor(settings) + async def _guarded_execute( + self, + sql: str, + ast: Any, + table_kinds: dict[str, str], + resolver: TableResolver, + dataset_id: uuid.UUID, + bundle: dict, + ) -> Any: + """Apply ScopeGuard + bad-tables firewall, then execute. + + Mirrors the guards in ``QueryService.answer`` so the streaming path + enforces the same dataset isolation and table-existence checks. + Returns an ``ExecutionResult`` or an ``ExecutionError`` (which the + caller's critic loop can attempt to refine). + """ + from flyquery.core.services.execution.duckdb_executor import ExecutionError + + try: + self._scope_guard.check( + classification=ast, + scopes=_DEFAULT_USER_SCOPES, + table_kinds_by_name=table_kinds, + dataset_allowlist=None, + dataset_of_table={t: str(dataset_id) for t in ast.table_refs}, + ) + except ScopeGuardError as exc: + return ExecutionError(message=f"Rejected by firewall: {exc}") + + ref_set = {t for t in ast.table_refs if t} + bad_tables = sorted(ref_set - set(table_kinds.keys())) + if bad_tables: + real_tables = sorted(table_kinds.keys()) + [ + (getattr(h, "metadata", {}) or {}).get("qualified_name", "").rsplit(".", 1)[-1] + for h in (bundle.get("schema_inventory") or []) + if (getattr(h, "metadata", {}) or {}).get("kind") == "TABLE" + ] + real_tables = [t for t in dict.fromkeys(real_tables) if t] + return ExecutionError( + message=( + f"Table(s) {bad_tables!r} do not exist in this dataset. " + f"Pick ONLY from: {real_tables[:80]!r}." + ) + ) + + attached = await resolver.resolve(dataset_id, list(ast.table_refs)) + return await self._executor.execute(sql, attached) + def _build_service(self, db_session: AsyncSession) -> QueryService: """Build a per-request QueryService around the provided session.""" index = SearchIndex(db_session) @@ -335,24 +388,28 @@ async def explain( bundle["schema_objects"] = reranked grounding_agent = build_grounding_agent(self._settings) - grounded = await grounding_agent.run( - {"question": body.question, "bundle": bundle, "starting_point_sql": None} + grounded_run = await grounding_agent.run( + _render_grounding_prompt(question=body.question, bundle=bundle, starting_point_sql=None) ) + grounded = getattr(grounded_run, "output", grounded_run) generation_agent = build_generation_agent(self._settings) - gen_out = await generation_agent.run( - {"grounded": grounded, "question": body.question, "starting_point_sql": None} + gen_run = await generation_agent.run( + _render_generation_prompt( + body.question, grounded, None, schema_inventory=bundle.get("schema_inventory") + ) ) - candidate = gen_out.candidates[0] + gen_out = getattr(gen_run, "output", gen_run) + candidate = gen_out.candidates[0] if gen_out.candidates else None clarification: ClarificationFrame | None = None if grounded.confidence < self._settings.grounding_min_confidence and grounded.missing_info: clarification = ClarificationFrame(questions=grounded.missing_info, reasons=[]) return ExplainResponse( - sql=candidate.sql, - reasoning=candidate.reasoning, - confidence=candidate.confidence, + sql=candidate.sql if candidate else "", + reasoning=candidate.reasoning if candidate else "generation produced no candidate", + confidence=candidate.confidence if candidate else 0.0, grounded_summary={ "path": grounded.path, "confidence": grounded.confidence, @@ -398,15 +455,19 @@ async def validate( bundle["schema_objects"] = reranked grounding_agent = build_grounding_agent(self._settings) - grounded = await grounding_agent.run( - {"question": body.question, "bundle": bundle, "starting_point_sql": None} + grounded_run = await grounding_agent.run( + _render_grounding_prompt(question=body.question, bundle=bundle, starting_point_sql=None) ) + grounded = getattr(grounded_run, "output", grounded_run) generation_agent = build_generation_agent(self._settings) - gen_out = await generation_agent.run( - {"grounded": grounded, "question": body.question, "starting_point_sql": None} + gen_run = await generation_agent.run( + _render_generation_prompt( + body.question, grounded, None, schema_inventory=bundle.get("schema_inventory") + ) ) - chosen_sql = gen_out.candidates[0].sql + gen_out = getattr(gen_run, "output", gen_run) + chosen_sql = gen_out.candidates[0].sql if gen_out.candidates else "" ast = self._ast_classifier.classify(chosen_sql) @@ -494,21 +555,26 @@ async def _stream_events( retriever = HybridRetriever(index=index, embedder=self._embedder, rrf_k=self._settings.rrf_k) reranker = build_reranker(self._settings) - # Stage 1: retrieve + ground + # Stage 1: retrieve + ground (same retrieval params + rendered + # prompt as the sync POST /query path, so the streaming path is + # not blind to the column-value catalogue, examples, metrics). bundle = await retriever.retrieve( request.question, dataset_id=request.dataset_id, workspace_id=workspace_id, top_k_schema=self._settings.top_k_schema * 3, + top_k_examples=self._settings.top_k_examples, + top_k_metrics=self._settings.top_k_metrics, ) schema_hits = bundle.get("schema_objects", []) reranked = await reranker.rerank(request.question, schema_hits, top_n=self._settings.top_k_schema) bundle["schema_objects"] = reranked grounding_agent = build_grounding_agent(self._settings) - grounded = await grounding_agent.run( - {"question": request.question, "bundle": bundle, "starting_point_sql": None} + grounded_run = await grounding_agent.run( + _render_grounding_prompt(question=request.question, bundle=bundle, starting_point_sql=None) ) + grounded = getattr(grounded_run, "output", grounded_run) yield _sse_frame( "schema_linked", @@ -537,11 +603,14 @@ async def _stream_events( # Stage 3: generate SQL generation_agent = build_generation_agent(self._settings) - gen_out = await generation_agent.run( - {"grounded": grounded, "question": request.question, "starting_point_sql": None} + gen_run = await generation_agent.run( + _render_generation_prompt( + request.question, grounded, None, schema_inventory=bundle.get("schema_inventory") + ) ) + gen_out = getattr(gen_run, "output", gen_run) candidates = gen_out.candidates - chosen_sql = candidates[0].sql + chosen_sql = candidates[0].sql if candidates else "" yield _sse_frame( "sql_generated", @@ -554,28 +623,38 @@ async def _stream_events( }, ) - # Stage 4: execute (with critic loop) + # Stage 4: AST + firewall/scope guards + execute (with critic loop). + # _guarded_execute applies the SAME ScopeGuard + bad-tables firewall + # the sync path enforces, so streaming callers cannot bypass dataset + # isolation or run SQL against a non-existent/cross-dataset table. ast = self._ast_classifier.classify(chosen_sql) table_resolver = TableResolver(session=db_session, settings=self._settings) - attached = await table_resolver.resolve(request.dataset_id, list(ast.table_refs)) - - exec_result = await self._executor.execute(chosen_sql, attached) + table_kinds = await table_resolver.table_kinds_by_name(request.dataset_id, list(ast.table_refs)) + exec_result = await self._guarded_execute( + chosen_sql, ast, table_kinds, table_resolver, request.dataset_id, bundle + ) retries = 0 while isinstance(exec_result, ExecutionError) and retries < self._settings.max_refine_retries: critic_agent = build_critic_agent(self._settings) - refined = await critic_agent.run( - { - "sql": chosen_sql, - "error": exec_result.message, - "grounded": grounded, - "question": request.question, - } + refined_run = await critic_agent.run( + _render_critic_prompt( + question=request.question, + failing_sql=chosen_sql, + error_message=exec_result.message, + grounded=grounded, + schema_inventory=bundle.get("schema_inventory"), + ) ) + refined = getattr(refined_run, "output", refined_run) chosen_sql = refined.sql ast = self._ast_classifier.classify(chosen_sql) - attached = await table_resolver.resolve(request.dataset_id, list(ast.table_refs)) - exec_result = await self._executor.execute(chosen_sql, attached) + table_kinds = await table_resolver.table_kinds_by_name( + request.dataset_id, list(ast.table_refs) + ) + exec_result = await self._guarded_execute( + chosen_sql, ast, table_kinds, table_resolver, request.dataset_id, bundle + ) retries += 1 snapshot_pins: dict = {} diff --git a/tests/integration/test_hybrid_retrieval.py b/tests/integration/test_hybrid_retrieval.py index dc9abd3..54bbdb6 100644 --- a/tests/integration/test_hybrid_retrieval.py +++ b/tests/integration/test_hybrid_retrieval.py @@ -78,6 +78,14 @@ async def _seed_table_with_column( ), {"id": snap_id, "t": tenant, "ws": ws_id, "ds": ds_id, "tbl": tbl_id, "hash": "testhash"}, ) + # Publish: point the table at this snapshot. Retrieval is scoped to + # ``current_snapshot_id`` (so re-ingests don't return stale/duplicate + # columns), so a seeded table must have its current snapshot set -- exactly + # as the publish stage does in real ingestion. + await s.execute( + sa.text("UPDATE flyquery_tables SET current_snapshot_id = :snap WHERE id = :tbl"), + {"snap": snap_id, "tbl": tbl_id}, + ) if embedding is not None: vec = str(embedding) await s.execute( diff --git a/tests/unit/test_excel_period_header_inheritance.py b/tests/unit/test_excel_period_header_inheritance.py new file mode 100644 index 0000000..6e7fc34 --- /dev/null +++ b/tests/unit/test_excel_period_header_inheritance.py @@ -0,0 +1,115 @@ +# Copyright 2024-2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for period-header inheritance in the XLSX section extractor. + +Financial-report exports (Orbis/BvD) place ONE date header above a block of +sub-sections; section-splitting orphans that header above each sub-section's +title. These tests pin the inheritance behavior + its guards (no false +inherit, leftmost-label-only, legacy path round-trip). +""" + +from __future__ import annotations + +import datetime + +from flyquery.core.services.ingestion.readers.excel_reader import ( + ExcelReader, + _is_period_header_row, + _is_period_value, +) + + +def test_is_period_value_full_match_rejects_embedded_year_codes() -> None: + for ok in ("2024", "2024-12-31", "FY2024", "Q1 2024", "H1 2023"): + assert _is_period_value(ok), ok + assert _is_period_value(datetime.datetime(2024, 12, 31)) + # Codes / addresses that merely *contain* a year must NOT be period values. + for bad in ("INV-2024-0007", "Form 2020", "2001 Main St", "REG-2019-X", "86.21", "Andalucia"): + assert not _is_period_value(bad), bad + + +def test_is_period_header_row_requires_majority_period_cells() -> None: + assert _is_period_header_row(["", "2024", "2023", "2022"]) + assert _is_period_header_row([datetime.date(2024, 12, 31), datetime.date(2023, 12, 31)]) + # A single year among labels is not a header. + assert not _is_period_header_row(["Founded 2024", "CEO", "Revenue"]) + + +def _orbis_like_rows() -> list[list[object]]: + """A miniature Orbis-style sheet: + + one date header governs a data-first sub-section introduced by its own + title; then a label-headed table closes the band; then another data-first + section that must NOT inherit the (now stale) date header. + """ + return [ + ["Financial data", "", "", ""], # 0 title + ["", "2024", "2023", "2022"], # 1 period header (cols 1-3) + ["Profit & Loss", "", "", ""], # 2 title + ["Revenue", 100, 90, 80], # 3 data-first -> inherits row 1 + ["Costs", 40, 30, 20], # 4 data + ["", "", "", ""], # 5 blank + ["", "", "", ""], # 6 blank (section break) + ["Board", "", "", ""], # 7 title + ["Name", "Role", "", ""], # 8 label header (closes the period band) + ["Alice", "CEO", "", ""], # 9 data + ["", "", "", ""], # 10 blank + ["", "", "", ""], # 11 blank + ["Extra metrics", "", "", ""], # 12 title + ["Metric A", 1, 2, 3], # 13 data-first, columns OVERLAP the date header + ["Metric B", 4, 5, 6], # 14 data + ] + + +def test_data_first_section_inherits_period_header() -> None: + secs = ExcelReader._extract_sections(_orbis_like_rows()) + pnl = next(s for s in secs if s["label"] == "Profit & Loss") + # Header is the date row (1), data is non-contiguous (starts at 3). + assert pnl["header_row_idx"] == 1 + assert pnl["data_start_idx"] == 3 + assert pnl["data_end_idx"] == 5 + + +def test_label_headed_table_does_not_inherit_and_closes_the_band() -> None: + secs = ExcelReader._extract_sections(_orbis_like_rows()) + board = next(s for s in secs if s["label"] == "Board") + # Board has its OWN label header -> contiguous, no inheritance. + assert board["data_start_idx"] == board["header_row_idx"] + 1 + + # The later "Extra metrics" section shares column positions with the date + # header, but the band was CLOSED by the Board table -> it must NOT inherit. + extra = next(s for s in secs if s["label"] == "Extra metrics") + assert extra["data_start_idx"] == extra["header_row_idx"] + 1 + + +def test_section_path_round_trip_contiguous_and_inherited() -> None: + # Contiguous -> 2-index form (byte-identical to legacy). + s_contig = {"header_row_idx": 5, "data_start_idx": 6, "data_end_idx": 9} + p = ExcelReader._section_path("Sheet1", s_contig) + assert p == "Sheet1#section[5:9]" + assert ExcelReader._parse_section_path(p) == ("Sheet1", 5, 6, 9) + + # Inherited (non-contiguous) -> 3-index form. + s_inh = {"header_row_idx": 1, "data_start_idx": 3, "data_end_idx": 5} + p2 = ExcelReader._section_path("Sheet1", s_inh) + assert p2 == "Sheet1#section[1:3:5]" + assert ExcelReader._parse_section_path(p2) == ("Sheet1", 1, 3, 5) + + +def test_parse_section_path_legacy_and_plain() -> None: + # Legacy 2-index path still parses (already-stored tables). + assert ExcelReader._parse_section_path("S#section[2:7]") == ("S", 2, 3, 7) + # Plain sheet name -> all None. + assert ExcelReader._parse_section_path("JustASheet") == ("JustASheet", None, None, None) diff --git a/tests/unit/test_query_controller_sse.py b/tests/unit/test_query_controller_sse.py index 9f137c0..c786e09 100644 --- a/tests/unit/test_query_controller_sse.py +++ b/tests/unit/test_query_controller_sse.py @@ -89,7 +89,13 @@ class _FakeTableResolver: def __init__(self): self._session = _FakeSession() - async def resolve(self, dataset_id, table_names, object_store_base=None): + async def resolve(self, dataset_id, table_names, object_store_base=None, pins=None): + return {} + + async def table_kinds_by_name(self, dataset_id, table_names): + return {} + + async def current_snapshots(self, dataset_id, table_names): return {} diff --git a/tests/unit/test_query_service.py b/tests/unit/test_query_service.py index 94667aa..6cd6167 100644 --- a/tests/unit/test_query_service.py +++ b/tests/unit/test_query_service.py @@ -88,7 +88,10 @@ class _FakeTableResolver: def __init__(self): self._session = _FakeSession() - async def resolve(self, dataset_id, table_names, object_store_base=None): + async def resolve(self, dataset_id, table_names, object_store_base=None, pins=None): + return {} + + async def current_snapshots(self, dataset_id, table_names): return {}