Skip to content
Open
27 changes: 26 additions & 1 deletion src/database/tasks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from collections.abc import Sequence
from typing import cast

from sqlalchemy import Row, text
from sqlalchemy import Row, RowMapping, text
from sqlalchemy.ext.asyncio import AsyncConnection

ALLOWED_LOOKUP_TABLES = ["estimation_procedure", "evaluation_measure", "task_type", "dataset"]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How did you determine which tables to allow? It's been a little while since I've looked at this code so correct me if I am wrong, but based on the docstring of fill_template it seems that LOOKUP statements always must reference a table that is a input in task_inputs.

PK_MAPPING = {
"task_type": "ttid",
"dataset": "did",
}


async def get(id_: int, expdb: AsyncConnection) -> Row | None:
row = await expdb.execute(
Expand Down Expand Up @@ -115,3 +121,22 @@ async def get_tags(id_: int, expdb: AsyncConnection) -> list[str]:
)
tag_rows = rows.all()
return [row.tag for row in tag_rows]


async def get_lookup_data(table: str, id_: int, expdb: AsyncConnection) -> RowMapping | None:
if table not in ALLOWED_LOOKUP_TABLES:
msg = f"Table {table} is not allowed for lookup."
raise ValueError(msg)

pk = PK_MAPPING.get(table, "id")
result = await expdb.execute(
text(
f"""
SELECT *
FROM {table}
WHERE `{pk}` = :id_
""", # noqa: S608
),
parameters={"id_": id_},
)
return result.mappings().one_or_none()
43 changes: 25 additions & 18 deletions src/routers/openml/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import Annotated, cast

import xmltodict
from fastapi import APIRouter, Depends
from sqlalchemy import RowMapping, text
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import RowMapping
from sqlalchemy.ext.asyncio import AsyncConnection

import config
Expand All @@ -17,6 +17,7 @@
router = APIRouter(prefix="/tasks", tags=["tasks"])

type JSON = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None
ALLOWED_LOOKUP_TABLES = {"estimation_procedure", "evaluation_measure", "task_type", "dataset"}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ALLOWED_LOOKUP_TABLES = {"estimation_procedure", "evaluation_measure", "task_type", "dataset"}

This variable does not seem to be used.



def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]:
Expand Down Expand Up @@ -95,7 +96,7 @@ async def fill_template(
)


async def _fill_json_template( # noqa: C901
async def _fill_json_template( # noqa: C901, PLR0912
template: JSON,
task: RowMapping,
task_inputs: dict[str, str | int],
Expand Down Expand Up @@ -128,23 +129,29 @@ async def _fill_json_template( # noqa: C901
(field,) = match.groups()
if field not in fetched_data:
table, _ = field.split(".")
result = await connection.execute(
text(
f"""
SELECT *
FROM {table}
WHERE `id` = :id_
""", # noqa: S608
),
# Not sure how parametrize table names, as the parametrization adds
# quotes which is not legal.
parameters={"id_": int(task_inputs[table])},
)
rows = result.mappings()
row_data = next(rows, None)
# List of tables allowed for [LOOKUP:table.column] directive.
# This is a security measure to prevent SQL injection via table names.
if table not in task_inputs or not task_inputs[table]:
msg = f"Missing or empty input for lookup table: {table}"
raise HTTPException(status_code=400, detail=msg)

try:
id_val = int(task_inputs[table])
except ValueError:
msg = f"Invalid integer id for table {table}: {task_inputs[table]}"
raise HTTPException(status_code=400, detail=msg) from None

try:
row_data = await database.tasks.get_lookup_data(
table=table,
id_=id_val,
expdb=connection,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
if row_data is None:
msg = f"No data found for table {table} with id {task_inputs[table]}"
raise ValueError(msg)
raise HTTPException(status_code=400, detail=msg)
for column, value in row_data.items():
fetched_data[f"{table}.{column}"] = value
if match.string == template:
Expand Down