diff --git a/src/google/adk/tools/spanner/search_tool.py b/src/google/adk/tools/spanner/search_tool.py index 6fb4a93f0a..93eb729882 100644 --- a/src/google/adk/tools/spanner/search_tool.py +++ b/src/google/adk/tools/spanner/search_tool.py @@ -16,6 +16,7 @@ import asyncio import json +import re from typing import Any from typing import Dict from typing import List @@ -31,6 +32,89 @@ from .settings import EXACT_NEAREST_NEIGHBORS from .settings import SpannerToolSettings +# Pattern for valid SQL identifiers: alphanumeric, underscores, dots (for +# schema-qualified names), and backtick/double-quote quoting. +_SAFE_IDENTIFIER_RE = re.compile( + r'^(?:[A-Za-z_][A-Za-z0-9_]*' # unquoted identifier + r'(?:\.[A-Za-z_][A-Za-z0-9_]*)*' # optional schema.table + r'|`[^`]+`' # backtick-quoted + r'|"[^"]+")$' # double-quote-quoted +) + +# Patterns that should never appear in an additional_filter value when +# the filter is populated by the LLM at runtime. +_FILTER_DENY_PATTERNS = re.compile( + r';\s*' # statement separator + r'|--' # single-line comment + r'|/\*' # block comment start + r'|\*/' # block comment end + r'|\bUNION\b' # UNION-based injection + r'|\bINTO\b\s+\bOUTFILE\b' # INTO OUTFILE + , re.IGNORECASE +) + + +def _validate_identifier(value: str, param_name: str) -> str: + """Validate that a value is a safe SQL identifier. + + Args: + value: The identifier string to validate. + param_name: Name of the parameter (for error messages). + + Returns: + The validated identifier string. + + Raises: + ValueError: If the identifier contains unsafe characters. + """ + if not value or not _SAFE_IDENTIFIER_RE.match(value.strip()): + raise ValueError( + f"Invalid SQL identifier for {param_name}: {value!r}. " + "Identifiers must contain only alphanumeric characters, underscores, " + "and dots, or be quoted with backticks or double quotes." + ) + return value.strip() + + +def _validate_column_list(columns: List[str], param_name: str) -> List[str]: + """Validate that each column name in a list is a safe SQL identifier.""" + validated = [] + for col in columns: + _validate_identifier(col, param_name) + validated.append(col) + return validated + + +def _validate_additional_filter( + filter_value: Optional[str], +) -> Optional[str]: + """Validate that an additional_filter does not contain injection patterns. + + This is a defense-in-depth measure. The additional_filter field is + documented as a developer-trusted value, but since it can be populated + by the LLM at runtime via tool calls, we reject common injection + patterns. + + Args: + filter_value: The filter string to validate. + + Returns: + The validated filter string, or None. + + Raises: + ValueError: If the filter contains dangerous patterns. + """ + if filter_value is None: + return None + if _FILTER_DENY_PATTERNS.search(filter_value): + raise ValueError( + f"additional_filter contains a disallowed pattern: {filter_value!r}. " + "Semicolons, comments (-- or /* */), and UNION keywords are not " + "permitted in filter expressions." + ) + return filter_value + + # Embedding model settings. # Only for Spanner GoogleSQL dialect database, and use Spanner ML.PREDICT # function. @@ -62,6 +146,10 @@ def _generate_googlesql_for_embedding_query( spanner_gsql_embedding_model_name: str, ) -> str: + _validate_identifier( + spanner_gsql_embedding_model_name, + "spanner_googlesql_embedding_model_name", + ) return f""" SELECT embeddings.values FROM ML.PREDICT( @@ -75,6 +163,16 @@ def _generate_postgresql_for_embedding_query( vertex_ai_embedding_model_endpoint: str, output_dimensionality: Optional[int], ) -> str: + # Validate endpoint format: projects/.../locations/.../publishers/.../models/... + if not re.match( + r'^projects/[\w-]+/locations/[\w-]+/publishers/[\w-]+/models/[\w.-]+$', + vertex_ai_embedding_model_endpoint, + ): + raise ValueError( + f"Invalid Vertex AI endpoint format: " + f"{vertex_ai_embedding_model_endpoint!r}. Expected format: " + f"projects/$project/locations/$location/publishers/google/models/$model" + ) instances_json = f""" 'instances', JSONB_BUILD_ARRAY( @@ -166,6 +264,11 @@ def _generate_sql_for_knn( top_k: int, ) -> str: """Generates a SQL query for kNN search.""" + _validate_identifier(table_name, "table_name") + _validate_identifier(embedding_column_to_search, "embedding_column_to_search") + columns = _validate_column_list(columns, "columns") + additional_filter = _validate_additional_filter(additional_filter) + top_k = int(top_k) if dialect == DatabaseDialect.POSTGRESQL: distance_function = _get_postgresql_distance_function(distance_type) embedding_parameter = f"${_POSTGRESQL_PARAMETER_QUERY_EMBEDDING}" @@ -205,6 +308,12 @@ def _generate_sql_for_ann( num_leaves_to_search: int, ): """Generates a SQL query for ANN search.""" + _validate_identifier(table_name, "table_name") + _validate_identifier(embedding_column_to_search, "embedding_column_to_search") + columns = _validate_column_list(columns, "columns") + additional_filter = _validate_additional_filter(additional_filter) + top_k = int(top_k) + num_leaves_to_search = int(num_leaves_to_search) if dialect == DatabaseDialect.POSTGRESQL: raise NotImplementedError( f"{APPROXIMATE_NEAREST_NEIGHBORS} is not supported for PostgreSQL" diff --git a/tests/unittests/tools/spanner/test_spanner_sql_validation.py b/tests/unittests/tools/spanner/test_spanner_sql_validation.py new file mode 100644 index 0000000000..ee511c8322 --- /dev/null +++ b/tests/unittests/tools/spanner/test_spanner_sql_validation.py @@ -0,0 +1,186 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for SQL identifier validation in Spanner search tool. + +Verifies that malicious SQL identifiers and filter patterns are rejected +before being interpolated into SQL queries (defense against SQL injection +via LLM-populated tool parameters). +""" + +import pytest +from google.cloud.spanner_admin_database_v1.types import DatabaseDialect + +from google.adk.tools.spanner.search_tool import _generate_sql_for_knn +from google.adk.tools.spanner.search_tool import _generate_sql_for_ann +from google.adk.tools.spanner.search_tool import _validate_additional_filter +from google.adk.tools.spanner.search_tool import _validate_column_list +from google.adk.tools.spanner.search_tool import _validate_identifier + + +class TestValidateIdentifier: + """Tests for _validate_identifier.""" + + def test_simple_identifier(self): + assert _validate_identifier("documents", "test") == "documents" + + def test_schema_qualified_identifier(self): + assert _validate_identifier("my_schema.my_table", "test") == "my_schema.my_table" + + def test_identifier_with_underscores(self): + assert _validate_identifier("embedding_col_1", "test") == "embedding_col_1" + + def test_backtick_quoted_identifier(self): + assert _validate_identifier("`my table`", "test") == "`my table`" + + def test_double_quote_quoted_identifier(self): + assert _validate_identifier('"my column"', "test") == '"my column"' + + def test_rejects_join_injection(self): + with pytest.raises(ValueError, match="Invalid SQL identifier"): + _validate_identifier( + "documents JOIN admin_credentials ac ON TRUE", "table_name" + ) + + def test_rejects_subquery_in_column(self): + with pytest.raises(ValueError, match="Invalid SQL identifier"): + _validate_identifier( + "(SELECT STRING_AGG(table_name, ',') FROM INFORMATION_SCHEMA.TABLES) AS schema_dump", + "columns", + ) + + def test_rejects_semicolon(self): + with pytest.raises(ValueError, match="Invalid SQL identifier"): + _validate_identifier("table; DROP TABLE users", "table_name") + + def test_rejects_empty(self): + with pytest.raises(ValueError, match="Invalid SQL identifier"): + _validate_identifier("", "table_name") + + def test_rejects_sql_comment(self): + with pytest.raises(ValueError, match="Invalid SQL identifier"): + _validate_identifier("table -- comment", "table_name") + + +class TestValidateColumnList: + """Tests for _validate_column_list.""" + + def test_valid_columns(self): + result = _validate_column_list(["col1", "col2", "col3"], "columns") + assert result == ["col1", "col2", "col3"] + + def test_rejects_subquery_column(self): + with pytest.raises(ValueError, match="Invalid SQL identifier"): + _validate_column_list( + [ + "(SELECT STRING_AGG(table_name, ',') FROM INFORMATION_SCHEMA.TABLES) AS dump", + "content", + ], + "columns", + ) + + +class TestValidateAdditionalFilter: + """Tests for _validate_additional_filter.""" + + def test_none_filter(self): + assert _validate_additional_filter(None) is None + + def test_simple_filter(self): + assert _validate_additional_filter("price_in_cents < 100000") == "price_in_cents < 100000" + + def test_rejects_union(self): + with pytest.raises(ValueError, match="UNION"): + _validate_additional_filter( + "1=1 UNION ALL SELECT password, 0.0 FROM admin_credentials" + ) + + def test_rejects_semicolon(self): + with pytest.raises(ValueError, match="disallowed pattern"): + _validate_additional_filter("1=1; SELECT * FROM secrets") + + def test_rejects_line_comment(self): + with pytest.raises(ValueError, match="disallowed pattern"): + _validate_additional_filter("1=1 -- bypass") + + def test_rejects_block_comment(self): + with pytest.raises(ValueError, match="disallowed pattern"): + _validate_additional_filter("1=1 /* bypass */") + + +class TestGenerateSqlForKnn: + """Tests for _generate_sql_for_knn with validation.""" + + def test_valid_query_googlesql(self): + sql = _generate_sql_for_knn( + dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, + table_name="documents", + embedding_column_to_search="embedding", + columns=["content"], + additional_filter=None, + distance_type="COSINE", + top_k=10, + ) + assert "FROM documents" in sql + assert "COSINE_DISTANCE" in sql + + def test_rejects_union_in_filter(self): + with pytest.raises(ValueError, match="UNION"): + _generate_sql_for_knn( + dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, + table_name="documents", + embedding_column_to_search="embedding", + columns=["content"], + additional_filter="1=1 UNION ALL SELECT password, 0.0 FROM admin_credentials", + distance_type="COSINE", + top_k=10, + ) + + def test_rejects_join_in_table_name(self): + with pytest.raises(ValueError, match="Invalid SQL identifier"): + _generate_sql_for_knn( + dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, + table_name="documents JOIN admin_credentials ac ON TRUE", + embedding_column_to_search="embedding", + columns=["content"], + additional_filter=None, + distance_type="COSINE", + top_k=10, + ) + + def test_rejects_subquery_in_columns(self): + with pytest.raises(ValueError, match="Invalid SQL identifier"): + _generate_sql_for_knn( + dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, + table_name="documents", + embedding_column_to_search="embedding", + columns=[ + "(SELECT STRING_AGG(table_name, ',') FROM INFORMATION_SCHEMA.TABLES) AS schema_dump", + ], + additional_filter=None, + distance_type="COSINE", + top_k=1, + ) + + def test_top_k_string_coerced_to_int(self): + sql = _generate_sql_for_knn( + dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, + table_name="documents", + embedding_column_to_search="embedding", + columns=["content"], + additional_filter=None, + distance_type="COSINE", + top_k="10", # String input + ) + assert "LIMIT 10" in sql