Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, List, Optional, Type
import re
from typing import Any, Dict, List, Optional, Type

from httpx import Response
from uipath.core.tracing import traced
Expand All @@ -13,6 +14,31 @@
EntityRecordsBatchResponse,
)

# Statement types that write data or change schema. The SQL validator rejects
# a query as soon as one of these appears among its extracted keywords.
_FORBIDDEN_SQL_KEYWORDS = {
    # Data modification
    "INSERT",
    "UPDATE",
    "DELETE",
    "MERGE",
    "REPLACE",
    # Schema modification
    "CREATE",
    "ALTER",
    "DROP",
    "TRUNCATE",
}

# Read-only constructs that entity queries do not accept. Kept as a list
# because the validator walks it in order and reports the first match.
_DISALLOWED_SQL_OPERATORS = [
    # Common table expressions
    "WITH",
    # Set operations
    "UNION",
    "INTERSECT",
    "EXCEPT",
    # Window functions
    "OVER",
    # Grouping extensions
    "ROLLUP",
    "CUBE",
    "GROUPING SETS",
    # Window partitioning
    "PARTITION BY",
]

# One keyword token: a run of uppercase letters, optionally extended by a
# following "BY" or "SETS" so that two-word constructs are captured whole.
_SQL_KEYWORD_PATTERN = re.compile(r"\b([A-Z]+(?:\s+BY|\s+SETS)?)\b")

# A single-quoted string literal containing no quote characters.
_QUOTED_STRING_PATTERN = re.compile(r"'[^']*'")


class EntitiesService(BaseService):
"""Service for managing UiPath Data Service entities.
Expand Down Expand Up @@ -391,6 +417,66 @@ class CustomerRecord:
EntityRecord.from_data(data=record, model=schema) for record in records_data
]

@traced(name="entity_query_records", run_type="uipath")
def query_entity_records(
    self,
    sql_query: str,
) -> List[Dict[str, Any]]:
    """Run a read-only SQL query against Data Service entities.

    PREVIEW: This method is in preview and may change in future releases.

    The query is validated locally first; nothing is sent to the service
    when validation fails.

    Args:
        sql_query (str): A single SQL SELECT statement. Data- and
            schema-modifying keywords, CTEs, set operations, window and
            grouping-extension constructs, and subqueries are rejected.
            A query without WHERE must carry a LIMIT clause, may not
            select ``*``, and may select at most 4 columns.

    Returns:
        List[Dict[str, Any]]: The records found under the ``results`` key of
            the response body, or an empty list when that key is absent.

    Raises:
        ValueError: If the SQL query fails validation (e.g., non-SELECT, missing
            WHERE/LIMIT, forbidden keywords, subqueries).
    """
    records = self._query_entities_for_records(sql_query)
    return records

@traced(name="entity_query_records", run_type="uipath")
async def query_entity_records_async(
    self,
    sql_query: str,
) -> List[Dict[str, Any]]:
    """Asynchronously run a read-only SQL query against Data Service entities.

    PREVIEW: This method is in preview and may change in future releases.

    The query is validated locally first; nothing is sent to the service
    when validation fails.

    Args:
        sql_query (str): A single SQL SELECT statement. Data- and
            schema-modifying keywords, CTEs, set operations, window and
            grouping-extension constructs, and subqueries are rejected.
            A query without WHERE must carry a LIMIT clause, may not
            select ``*``, and may select at most 4 columns.

    Returns:
        List[Dict[str, Any]]: The records found under the ``results`` key of
            the response body, or an empty list when that key is absent.

    Raises:
        ValueError: If the SQL query fails validation (e.g., non-SELECT, missing
            WHERE/LIMIT, forbidden keywords, subqueries).
    """
    records = await self._query_entities_for_records_async(sql_query)
    return records

def _query_entities_for_records(self, sql_query: str) -> List[Dict[str, Any]]:
    """Validate ``sql_query``, execute it, and return the result rows.

    Args:
        sql_query: SQL text to validate and send.

    Returns:
        The value under ``results`` in the JSON response, or an empty list
        when the key is missing.

    Raises:
        ValueError: If validation rejects the query; no request is made.
    """
    # Validation comes first so a rejected query never reaches the network.
    self._validate_sql_query(sql_query)

    request_spec = self._query_entity_records_spec(sql_query)
    http_response = self.request(
        request_spec.method,
        request_spec.endpoint,
        json=request_spec.json,
    )
    payload = http_response.json()
    return payload.get("results", [])

async def _query_entities_for_records_async(
    self, sql_query: str
) -> List[Dict[str, Any]]:
    """Validate ``sql_query``, execute it asynchronously, and return the rows.

    Args:
        sql_query: SQL text to validate and send.

    Returns:
        The value under ``results`` in the JSON response, or an empty list
        when the key is missing.

    Raises:
        ValueError: If validation rejects the query; no request is made.
    """
    # Validation comes first so a rejected query never reaches the network.
    self._validate_sql_query(sql_query)

    request_spec = self._query_entity_records_spec(sql_query)
    http_response = await self.request_async(
        request_spec.method,
        request_spec.endpoint,
        json=request_spec.json,
    )
    payload = http_response.json()
    return payload.get("results", [])

@traced(name="entity_record_insert_batch", run_type="uipath")
def insert_records(
self,
Expand Down Expand Up @@ -874,6 +960,16 @@ def _list_records_spec(
params=({"start": start, "limit": limit}),
)

def _query_entity_records_spec(
    self,
    sql_query: str,
) -> RequestSpec:
    """Build the request that submits ``sql_query`` for execution.

    Args:
        sql_query: SQL text placed under the ``query`` key of the JSON body.

    Returns:
        RequestSpec: A POST to ``datafabric_/api/v1/query/execute`` carrying
            the query as its JSON body.
    """
    target = Endpoint("datafabric_/api/v1/query/execute")
    body = {"query": sql_query}
    return RequestSpec(method="POST", endpoint=target, json=body)

def _insert_batch_spec(self, entity_key: str, records: List[Any]) -> RequestSpec:
return RequestSpec(
method="POST",
Expand Down Expand Up @@ -902,3 +998,58 @@ def _delete_batch_spec(self, entity_key: str, record_ids: List[str]) -> RequestS
),
json=record_ids,
)

def _validate_sql_query(self, sql_query: str) -> None:
    """Reject any SQL text that is not a single, bounded, read-only SELECT.

    Every keyword and structure check runs on a copy of the query in which
    single-quoted string literals have been blanked out. Literal values such
    as ``WHERE status = 'DELETE pending'`` or ``WHERE note = 'a, b'``
    therefore neither trigger a rejection nor satisfy a requirement (a
    ``'WHERE'`` literal does not count as a filter).

    Args:
        sql_query: Raw SQL text supplied by the caller. Surrounding
            whitespace and trailing semicolons are ignored.

    Raises:
        ValueError: If the query is empty; contains more than one statement;
            does not start with SELECT; uses a data- or schema-modifying
            keyword; uses a disallowed construct (CTE, set operation, window
            or grouping extension); contains a subquery; has neither WHERE
            nor ``LIMIT <n>``; or, without WHERE, selects ``*`` or more than
            4 columns.
    """
    query = sql_query.strip().rstrip(";").strip()
    if not query:
        raise ValueError("SQL query cannot be empty.")

    # Blank out string literals once, up front, so that nothing inside a
    # quoted value (semicolons, keywords, commas, parentheses, '*') can
    # influence any of the checks below.
    unquoted = _QUOTED_STRING_PATTERN.sub("''", query)
    if ";" in unquoted:
        raise ValueError("Only a single SELECT statement is allowed.")

    normalized_query = re.sub(r"\s+", " ", unquoted).strip()
    normalized_upper = normalized_query.upper()

    # "WITH " is let through here on purpose: it is rejected just below with
    # the more specific "SQL construct 'WITH' is not allowed" message.
    if not normalized_upper.startswith(("SELECT ", "WITH ")):
        raise ValueError("Only SELECT statements are allowed.")

    tokens = _SQL_KEYWORD_PATTERN.findall(normalized_upper)
    extracted_keywords = set(tokens)
    # The keyword pattern folds a trailing "BY"/"SETS" into the preceding
    # word, so "UNION BY NAME" yields the token "UNION BY". Also record the
    # leading word of every token so such keywords cannot slip past.
    extracted_keywords.update(token.split(" ", 1)[0] for token in tokens)

    # Sorted so the reported keyword is deterministic when several match.
    forbidden_hits = sorted(extracted_keywords.intersection(_FORBIDDEN_SQL_KEYWORDS))
    if forbidden_hits:
        raise ValueError(f"SQL keyword '{forbidden_hits[0]}' is not allowed.")

    for operator in _DISALLOWED_SQL_OPERATORS:
        if operator in extracted_keywords:
            raise ValueError(
                f"SQL construct '{operator}' is not allowed in entity queries."
            )

    if re.search(r"\(\s*SELECT\b", normalized_upper):
        raise ValueError("Subqueries are not allowed.")

    has_where = bool(re.search(r"\bWHERE\b", normalized_upper))
    has_limit = bool(re.search(r"\bLIMIT\s+\d+\b", normalized_upper))
    if not has_where and not has_limit:
        raise ValueError("Queries without WHERE must include a LIMIT clause.")

    # The remaining rules only constrain unfiltered queries.
    if has_where:
        return

    projection = self._projection_segment(normalized_query)
    if "*" in projection:
        raise ValueError("SELECT * without filtering is not allowed.")
    if self._projection_column_count(projection) > 4:
        raise ValueError(
            "Selecting more than 4 columns without filtering is not allowed."
        )

def _projection_segment(self, normalized_query: str) -> str:
    """Return the text between the leading SELECT and the first FROM.

    Args:
        normalized_query: Query text to inspect; matching is
            case-insensitive.

    Returns:
        The projection text (including a leading ``DISTINCT`` if present),
        or an empty string when the text does not have the
        ``SELECT ... FROM ...`` shape.
    """
    found = re.match(r"(?is)\s*SELECT\s+(.*?)\s+FROM\s+", normalized_query)
    if found is None:
        return ""
    return found.group(1)

def _projection_column_count(self, projection: str) -> int:
    """Count the top-level, non-empty expressions in a SELECT projection.

    Only commas outside parentheses and outside single-quoted literals
    separate columns, so ``COALESCE(a, b), c`` counts as 2 and
    ``CONCAT(x, ', ', y)`` counts as 1.

    Args:
        projection: The text between SELECT and FROM.

    Returns:
        The number of non-blank comma-separated expressions; 0 for empty or
        whitespace-only input.
    """
    cleaned = projection.strip()
    if not cleaned:
        return 0

    count = 0
    depth = 0  # current parenthesis nesting level
    in_string = False  # inside a single-quoted literal
    has_content = False  # current segment holds a non-blank character

    for char in cleaned:
        if in_string:
            # A quote closes the literal; everything else inside is opaque.
            in_string = char != "'"
            continue
        if char == "'":
            in_string = True
            has_content = True
        elif char == "(":
            depth += 1
            has_content = True
        elif char == ")":
            # Clamp so a stray ')' cannot hide later top-level commas.
            depth = max(depth - 1, 0)
            has_content = True
        elif char == "," and depth == 0:
            if has_content:
                count += 1
            has_content = False
        elif not char.isspace():
            has_content = True

    if has_content:
        count += 1
    return count
123 changes: 123 additions & 0 deletions packages/uipath-platform/tests/services/test_entities_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import re
import uuid
from dataclasses import make_dataclass
from typing import Optional
from unittest.mock import AsyncMock, MagicMock

import pytest
from pytest_httpx import HTTPXMock
Expand Down Expand Up @@ -260,3 +262,124 @@ def test_retrieve_records_with_optional_fields(
start=0,
limit=1,
)

@pytest.mark.parametrize(
    "sql_query",
    [
        # Filtered by WHERE
        "SELECT id FROM Customers WHERE id = 1",
        # Unfiltered but bounded by LIMIT
        "SELECT id, name FROM Customers LIMIT 10",
        # SELECT * is acceptable once a WHERE is present
        "SELECT * FROM Customers WHERE status = 'Active'",
        # Exactly 4 columns without WHERE is still within the cap
        "SELECT id, name, email, phone FROM Customers LIMIT 5",
        "SELECT DISTINCT id FROM Customers WHERE id > 100",
        # A semicolon inside a string literal is not a statement separator
        "SELECT id FROM Customers WHERE name = 'foo;bar'",
        # A single trailing semicolon is tolerated
        "SELECT id FROM Customers WHERE id = 1;",
    ],
)
def test_validate_sql_query_allows_supported_select_queries(
    self, sql_query: str, service: EntitiesService
) -> None:
    """Each supported SELECT shape must pass validation without raising."""
    # The validator signals success by returning None rather than raising.
    assert service._validate_sql_query(sql_query) is None

@pytest.mark.parametrize(
    "sql_query,error_message",
    [
        # Empty input
        ("", "SQL query cannot be empty."),
        (" ", "SQL query cannot be empty."),
        # Multiple statements
        (
            "SELECT id FROM Customers; SELECT id FROM Orders",
            "Only a single SELECT statement is allowed.",
        ),
        # Non-SELECT statement
        ("INSERT INTO Customers VALUES (1)", "Only SELECT statements are allowed."),
        # Disallowed read-only constructs
        (
            "WITH cte AS (SELECT id FROM Customers) SELECT id FROM cte",
            "SQL construct 'WITH' is not allowed in entity queries.",
        ),
        (
            "SELECT id FROM Customers UNION SELECT id FROM Orders",
            "SQL construct 'UNION' is not allowed in entity queries.",
        ),
        (
            "SELECT id, SUM(amount) OVER (PARTITION BY id) FROM Orders LIMIT 10",
            "SQL construct 'OVER' is not allowed in entity queries.",
        ),
        # Subquery
        (
            "SELECT id FROM (SELECT id FROM Customers) c",
            "Subqueries are not allowed.",
        ),
        # Unbounded or overly broad unfiltered queries
        (
            "SELECT id FROM Customers",
            "Queries without WHERE must include a LIMIT clause.",
        ),
        (
            "SELECT * FROM Customers LIMIT 10",
            "SELECT * without filtering is not allowed.",
        ),
        (
            "SELECT id, name, email, phone, address FROM Customers LIMIT 10",
            "Selecting more than 4 columns without filtering is not allowed.",
        ),
    ],
)
def test_validate_sql_query_rejects_disallowed_queries(
    self, sql_query: str, error_message: str, service: EntitiesService
) -> None:
    """Each disallowed query must raise ValueError with its specific message."""
    # Messages contain regex metacharacters ('.', '*'), so match them literally.
    expected_pattern = re.escape(error_message)

    with pytest.raises(ValueError, match=expected_pattern):
        service._validate_sql_query(sql_query)

def test_query_entity_records_rejects_invalid_sql_before_network_call(
    self,
    service: EntitiesService,
) -> None:
    """A non-SELECT statement must fail validation without issuing a request."""
    request_mock = MagicMock()
    service.request = request_mock  # type: ignore[method-assign]
    expected_pattern = re.escape("Only SELECT statements are allowed.")

    with pytest.raises(ValueError, match=expected_pattern):
        service.query_entity_records("UPDATE Customers SET name = 'X'")

    # Validation runs before the request is built, so the mock stays untouched.
    request_mock.assert_not_called()

def test_query_entity_records_calls_request_for_valid_sql(
    self,
    service: EntitiesService,
) -> None:
    """A valid query is sent exactly once and its ``results`` are returned.

    Besides the call count, the JSON body is checked so a regression that
    sends the wrong (or no) query text is caught.
    """
    sql_query = "SELECT id FROM Customers WHERE id > 0"
    response = MagicMock()
    response.json.return_value = {"results": [{"id": 1}, {"id": 2}]}

    request_mock = MagicMock(return_value=response)
    service.request = request_mock  # type: ignore[method-assign]

    result = service.query_entity_records(sql_query)

    assert result == [{"id": 1}, {"id": 2}]
    request_mock.assert_called_once()
    # The query text travels under the "query" key of the JSON body.
    assert request_mock.call_args.kwargs["json"] == {"query": sql_query}

@pytest.mark.anyio
async def test_query_entity_records_async_rejects_invalid_sql_before_network_call(
    self,
    service: EntitiesService,
) -> None:
    """A subquery must fail validation without issuing an async request."""
    request_mock = AsyncMock()
    service.request_async = request_mock  # type: ignore[method-assign]
    expected_pattern = re.escape("Subqueries are not allowed.")

    with pytest.raises(ValueError, match=expected_pattern):
        await service.query_entity_records_async(
            "SELECT id FROM Customers WHERE id IN (SELECT id FROM Orders)"
        )

    # Validation runs before the request is built, so the mock stays untouched.
    request_mock.assert_not_called()

@pytest.mark.anyio
async def test_query_entity_records_async_calls_request_for_valid_sql(
    self,
    service: EntitiesService,
) -> None:
    """A valid query is awaited exactly once and its ``results`` are returned.

    The mock is checked for being awaited (not merely called), and the JSON
    body is checked so a regression that sends the wrong query text is caught.
    """
    sql_query = "SELECT id FROM Customers WHERE id = 'c1'"
    response = MagicMock()
    response.json.return_value = {"results": [{"id": "c1"}]}

    request_mock = AsyncMock(return_value=response)
    service.request_async = request_mock  # type: ignore[method-assign]

    result = await service.query_entity_records_async(sql_query)

    assert result == [{"id": "c1"}]
    request_mock.assert_awaited_once()
    # The query text travels under the "query" key of the JSON body.
    assert request_mock.call_args.kwargs["json"] == {"query": sql_query}
28 changes: 23 additions & 5 deletions packages/uipath-platform/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions packages/uipath/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
"mermaid-builder==0.0.3",
"graphtty==0.1.8",
"applicationinsights>=0.11.10",
"sqlparse>=0.4.4",
]
classifiers = [
"Intended Audience :: Developers",
Expand Down
Loading
Loading