Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/core/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ class AuthenticationRequiredError(ProblemDetailError):
uri = "https://openml.org/problems/authentication-required"
title = "Authentication Required"
_default_status_code = HTTPStatus.UNAUTHORIZED
_default_code = 103 # PHP API doesn't differentiate


class AuthenticationFailedError(ProblemDetailError):
Expand Down
15 changes: 11 additions & 4 deletions src/routers/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncConnection

from core.errors import AuthenticationFailedError
from core.errors import AuthenticationFailedError, AuthenticationRequiredError
from database.setup import expdb_database, user_database
from database.users import APIKey, User

Expand All @@ -26,15 +26,22 @@ async def fetch_user(
api_key: APIKey | None = None,
user_data: Annotated[AsyncConnection | None, Depends(userdb_connection)] = None,
) -> User | None:
return await User.fetch(api_key, user_data) if api_key and user_data else None
if not (api_key and user_data):
return None

user = await User.fetch(api_key, user_data)
if user:
return user
msg = "Invalid API key provided."
raise AuthenticationFailedError(msg)


def fetch_user_or_raise(
user: Annotated[User | None, Depends(fetch_user)] = None,
) -> User:
if user is None:
msg = "Authentication failed"
raise AuthenticationFailedError(msg)
msg = "No API key provided."
raise AuthenticationRequiredError(msg)
return user


Expand Down
Empty file added tests/dependencies/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions tests/dependencies/fetch_user_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest
from sqlalchemy.ext.asyncio import AsyncConnection

from core.errors import AuthenticationFailedError, AuthenticationRequiredError
from database.users import User
from routers.dependencies import fetch_user, fetch_user_or_raise
from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey


@pytest.mark.parametrize(
("api_key", "user"),
[
(ApiKey.ADMIN, ADMIN_USER),
(ApiKey.OWNER_USER, OWNER_USER),
(ApiKey.SOME_USER, SOME_USER),
],
)
async def test_fetch_user(api_key: str, user: User, user_test: AsyncConnection) -> None:
db_user = await fetch_user(api_key, user_data=user_test)
assert isinstance(db_user, User)
assert user.user_id == db_user.user_id
assert set(await user.get_groups()) == set(await db_user.get_groups())


async def test_fetch_user_no_key_no_user() -> None:
assert await fetch_user(api_key=None) is None


async def test_fetch_user_invalid_key_raises(user_test: AsyncConnection) -> None:
with pytest.raises(AuthenticationFailedError):
await fetch_user(api_key=ApiKey.INVALID, user_data=user_test)


async def test_fetch_user_or_raise_raises_if_no_user() -> None:
# This function calls `fetch_user` through dependency injection,
# so it only needs to correctly handle possible output of `fetch_user`.
with pytest.raises(AuthenticationRequiredError):
fetch_user_or_raise(user=None)
Comment on lines +34 to +38
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Add a positive test for fetch_user_or_raise when a valid user is provided.

Right now we only cover the error path (no user → AuthenticationRequiredError). Please also add a test passing a valid User instance and asserting it’s returned unchanged and no exception is raised. This will guard against regressions where the function might start altering or re-validating the user.

8 changes: 1 addition & 7 deletions tests/routers/openml/migration/datasets_migration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,10 @@ async def test_error_unknown_dataset(
assert error["detail"].startswith("No dataset")


@pytest.mark.parametrize(
"api_key",
[None, ApiKey.INVALID],
)
async def test_private_dataset_no_user_no_access(
py_api: httpx.AsyncClient,
api_key: str | None,
) -> None:
query = f"?api_key={api_key}" if api_key else ""
response = await py_api.get(f"/datasets/130{query}")
response = await py_api.get("/datasets/130")

# New response is 403: Forbidden instead of 412: PRECONDITION FAILED
assert response.status_code == HTTPStatus.FORBIDDEN
Expand Down
2 changes: 1 addition & 1 deletion tests/routers/openml/setups_tag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ async def test_setup_tag_missing_auth(py_api: httpx.AsyncClient) -> None:
response = await py_api.post("/setup/tag", json={"setup_id": 1, "tag": "test_tag"})
assert response.status_code == HTTPStatus.UNAUTHORIZED
assert response.json()["code"] == "103"
assert response.json()["detail"] == "Authentication failed"
assert response.json()["detail"] == "No API key provided."


async def test_setup_tag_unknown_setup(py_api: httpx.AsyncClient) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/routers/openml/setups_untag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ async def test_setup_untag_missing_auth(py_api: httpx.AsyncClient) -> None:
response = await py_api.post("/setup/untag", json={"setup_id": 1, "tag": "test_tag"})
assert response.status_code == HTTPStatus.UNAUTHORIZED
assert response.json()["code"] == "103"
assert response.json()["detail"] == "Authentication failed"
assert response.json()["detail"] == "No API key provided."


async def test_setup_untag_unknown_setup(py_api: httpx.AsyncClient) -> None:
Expand Down
27 changes: 0 additions & 27 deletions tests/routers/openml/users_test.py

This file was deleted.

Loading