Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions diracx-cli/src/diracx/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from diracx.core.extensions import select_from_extension
from diracx.core.extensions import DiracEntryPoint, select_from_extension

from .auth import app

Expand All @@ -9,21 +9,26 @@

# Load all the sub commands
cli_names = set(
[entry_point.name for entry_point in select_from_extension(group="diracx.cli")]
[
entry_point.name
for entry_point in select_from_extension(group=DiracEntryPoint.CLI)
]
)
for cli_name in cli_names:
entry_point = select_from_extension(group="diracx.cli", name=cli_name)[0]
entry_point = select_from_extension(group=DiracEntryPoint.CLI, name=cli_name)[0]
app.add_typer(entry_point.load(), name=entry_point.name)


cli_hidden_names = set(
[
entry_point.name
for entry_point in select_from_extension(group="diracx.cli.hidden")
for entry_point in select_from_extension(group=DiracEntryPoint.HIDDEN_CLI)
]
)
for cli_name in cli_hidden_names:
entry_point = select_from_extension(group="diracx.cli.hidden", name=cli_name)[0]
entry_point = select_from_extension(
group=DiracEntryPoint.HIDDEN_CLI, name=cli_name
)[0]
app.add_typer(entry_point.load(), name=entry_point.name, hidden=True)


Expand Down
12 changes: 6 additions & 6 deletions diracx-cli/src/diracx/cli/internal/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from diracx.core.config import Config
from diracx.core.config.schema import Field, SupportInfo
from diracx.core.extensions import select_from_extension
from diracx.core.extensions import DiracEntryPoint, select_from_extension

from ..utils import AsyncTyper

Expand Down Expand Up @@ -77,9 +77,9 @@ def cs_sync(old_file: Path, new_file: Path):
)

_apply_fixes(raw)
config_class: Config = select_from_extension(group="diracx", name="config")[
0
].load()
config_class: Config = select_from_extension(
group=DiracEntryPoint.CORE, name="config"
)[0].load()
config = config_class.model_validate(raw)
new_file.write_text(
yaml.safe_dump(config.model_dump(exclude_unset=True, mode="json"))
Expand Down Expand Up @@ -264,7 +264,7 @@ def generate_helm_values(

from diracx.core.extensions import select_from_extension

for entry_point in select_from_extension(group="diracx.dbs.sql"):
for entry_point in select_from_extension(group=DiracEntryPoint.SQL_DB):
db_name = entry_point.name
db_config = all_db_configs.get(db_name, {})

Expand Down Expand Up @@ -310,7 +310,7 @@ def generate_helm_values(
},
}

for entry_point in select_from_extension(group="diracx.dbs.os"):
for entry_point in select_from_extension(group=DiracEntryPoint.OS_DB):
db_name = entry_point.name
db_config = all_db_configs.get(db_name, {})

Expand Down
8 changes: 4 additions & 4 deletions diracx-core/src/diracx/core/config/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pydantic import AnyUrl, BeforeValidator, TypeAdapter, UrlConstraints

from ..exceptions import BadConfigurationVersionError
from ..extensions import select_from_extension
from ..extensions import DiracEntryPoint, select_from_extension
from ..utils import TwoLevelCache
from .schema import Config

Expand Down Expand Up @@ -214,9 +214,9 @@ def read_raw(self, hexsha: str, modified: datetime) -> Config:
f"Error reading configuration: {e}"
) from e

config_class: Config = select_from_extension(group="diracx", name="config")[
0
].load()
config_class: Config = select_from_extension(
group=DiracEntryPoint.CORE, name="config"
)[0].load()
config = config_class.model_validate(raw_obj)
config._hexsha = hexsha
config._modified = modified
Expand Down
22 changes: 19 additions & 3 deletions diracx-core/src/diracx/core/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
"extensions_by_priority",
"select_from_extension",
"supports_extending",
"DiracEntryPoint",
]

from collections import defaultdict
from enum import StrEnum
from importlib.metadata import EntryPoint, entry_points
from typing import Callable, ParamSpec, TypeVar, cast

Expand All @@ -16,6 +18,20 @@
P = ParamSpec("P")


class DiracEntryPoint(StrEnum):
"""Available entrypoint group values."""

CORE = "diracx"
ACCESS_POLICY = "diracx.access_policies"
CLI = "diracx.cli"
HIDDEN_CLI = "diracx.cli.hidden"
OS_DB = "diracx.dbs.os"
SQL_DB = "diracx.dbs.sql"
MIN_CLIENT_VERSION = "diracx.min_client_version"
RESOURCES = "diracx.resources"
SERVICES = "diracx.services"


@cached(cache=LRUCache(maxsize=1))
def extensions_by_priority() -> list[str]:
"""Yield extension module names in order of priority.
Expand All @@ -24,10 +40,10 @@ def extensions_by_priority() -> list[str]:
importing diracx in the MetaPathFinder as part of unrelated imports
(e.g. http.client).
"""
selected = entry_points().select(group="diracx")
selected = entry_points().select(group=DiracEntryPoint.CORE)
if selected is None:
raise NotImplementedError(
"No entry points found for group 'diracx'. Do you have it installed?"
f"No entry points found for group {DiracEntryPoint.CORE}. Do you have it installed?"
)
extensions = set()
for entry_point in selected.select(name="extension"):
Expand Down Expand Up @@ -73,7 +89,7 @@ def supports_extending(
name: The entry point name to search for

Example:
@supports_extending("diracx.resources", "find_compatible_platforms")
@supports_extending(DiracEntryPoint.RESOURCES, "find_compatible_platforms")
def my_function():
return "default implementation"

Expand Down
4 changes: 2 additions & 2 deletions diracx-core/src/diracx/core/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from DIRACCommon.Core.Utilities.ReturnValues import returnValueOrRaise

from diracx.core.config import Config
from diracx.core.extensions import supports_extending
from diracx.core.extensions import DiracEntryPoint, supports_extending


@supports_extending("diracx.resources", "find_compatible_platforms")
@supports_extending(DiracEntryPoint.RESOURCES, "find_compatible_platforms")
def find_compatible_platforms(job_platforms: list[str], config: Config) -> list[str]:
"""Find compatible platforms for the given job platforms.

Expand Down
6 changes: 4 additions & 2 deletions diracx-core/tests/test_entry_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from importlib.metadata import entry_points

from diracx.core.extensions import DiracEntryPoint


def test_diracx_resources_entry_point():
"""Test that the diracx.resources entry point is properly configured."""
# Get all entry points for the diracx.resources group
resources_eps = entry_points().select(group="diracx.resources")
resources_eps = entry_points().select(group=DiracEntryPoint.RESOURCES)

# Check that find_compatible_platforms entry point exists
find_platforms_ep = None
Expand All @@ -29,7 +31,7 @@ def test_diracx_resources_entry_point():
def test_entry_point_functionality():
"""Test that the entry point points to the correct function."""
# Get the entry point
resources_eps = entry_points().select(group="diracx.resources")
resources_eps = entry_points().select(group=DiracEntryPoint.RESOURCES)
find_platforms_ep = None
for ep in resources_eps:
if ep.name == "find_compatible_platforms":
Expand Down
5 changes: 3 additions & 2 deletions diracx-core/tests/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from diracx.core.extensions import (
DiracEntryPoint,
extensions_by_priority,
select_from_extension,
supports_extending,
Expand All @@ -26,9 +27,9 @@ def test_extensions_by_priority():
def test_select_from_extension():
"""Test the select_from_extension function."""
# Test with existing group
result = select_from_extension(group="diracx", name="extension")
result = select_from_extension(group=DiracEntryPoint.CORE, name="extension")
assert len(result) >= 1
assert all(ep.group == "diracx" for ep in result)
assert all(ep.group == DiracEntryPoint.CORE for ep in result)
assert all(ep.name == "extension" for ep in result)

# Test with non-existent group
Expand Down
6 changes: 3 additions & 3 deletions diracx-db/src/diracx/db/os/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from opensearchpy import AsyncOpenSearch

from diracx.core.exceptions import InvalidQueryError
from diracx.core.extensions import select_from_extension
from diracx.core.extensions import DiracEntryPoint, select_from_extension
from diracx.db.exceptions import DBUnavailableError

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -93,7 +93,7 @@ def available_implementations(cls, db_name: str) -> list[type[BaseOSDB]]:
db_classes: list[type[BaseOSDB]] = [
entry_point.load()
for entry_point in select_from_extension(
group="diracx.dbs.os", name=db_name
group=DiracEntryPoint.OS_DB, name=db_name
)
]
if not db_classes:
Expand All @@ -108,7 +108,7 @@ def available_urls(cls) -> dict[str, dict[str, Any]]:
prefixed with ``DIRACX_OS_DB_{DB_NAME}``.
"""
conn_kwargs: dict[str, dict[str, Any]] = {}
for entry_point in select_from_extension(group="diracx.dbs.os"):
for entry_point in select_from_extension(group=DiracEntryPoint.OS_DB):
db_name = entry_point.name
var_name = f"DIRACX_OS_DB_{entry_point.name.upper()}"
if var_name in os.environ:
Expand Down
6 changes: 3 additions & 3 deletions diracx-db/src/diracx/db/sql/utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from uuid_utils import UUID, uuid7

from diracx.core.exceptions import InvalidQueryError
from diracx.core.extensions import select_from_extension
from diracx.core.extensions import DiracEntryPoint, select_from_extension
from diracx.core.models import (
SearchSpec,
SortDirection,
Expand Down Expand Up @@ -111,7 +111,7 @@ def available_implementations(cls, db_name: str) -> list[type["BaseSQLDB"]]:
db_classes: list[type[BaseSQLDB]] = [
entry_point.load()
for entry_point in select_from_extension(
group="diracx.dbs.sql", name=db_name
group=DiracEntryPoint.SQL_DB, name=db_name
)
]
if not db_classes:
Expand All @@ -126,7 +126,7 @@ def available_urls(cls) -> dict[str, str]:
prefixed with ``DIRACX_DB_URL_{DB_NAME}``.
"""
db_urls: dict[str, str] = {}
for entry_point in select_from_extension(group="diracx.dbs.sql"):
for entry_point in select_from_extension(group=DiracEntryPoint.SQL_DB):
db_name = entry_point.name
var_name = f"DIRACX_DB_URL_{entry_point.name.upper()}"
if var_name in os.environ:
Expand Down
4 changes: 2 additions & 2 deletions diracx-routers/src/diracx/routers/access_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from fastapi import Depends

from diracx.core.extensions import select_from_extension
from diracx.core.extensions import DiracEntryPoint, select_from_extension
from diracx.core.models import (
AccessTokenPayload,
RefreshTokenPayload,
Expand Down Expand Up @@ -68,7 +68,7 @@ def available_implementations(cls, access_policy_name: str):
policy_classes: list[type["BaseAccessPolicy"]] = [
entry_point.load()
for entry_point in select_from_extension(
group="diracx.access_policies", name=access_policy_name
group=DiracEntryPoint.ACCESS_POLICY, name=access_policy_name
)
]
if not policy_classes:
Expand Down
18 changes: 10 additions & 8 deletions diracx-routers/src/diracx/routers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from diracx.core.config import ConfigSource
from diracx.core.exceptions import DiracError, DiracHttpResponseError, NotReadyError
from diracx.core.extensions import select_from_extension
from diracx.core.extensions import DiracEntryPoint, select_from_extension
from diracx.core.settings import ServiceSettingsBase
from diracx.core.utils import dotenv_files_from_environment
from diracx.db.exceptions import DBUnavailableError
Expand Down Expand Up @@ -235,7 +235,7 @@ def create_app_inner(
for system_name in sorted(enabled_systems):
assert system_name not in routers
for entry_point in select_from_extension(
group="diracx.services", name=system_name
group=DiracEntryPoint.SERVICES, name=system_name
):
routers[system_name] = entry_point.load()
break
Expand Down Expand Up @@ -351,7 +351,7 @@ def create_app() -> DiracFastAPI:
# Load all available routers
enabled_systems = set()
settings_classes = set()
for entry_point in select_from_extension(group="diracx.services"):
for entry_point in select_from_extension(group=DiracEntryPoint.SERVICES):
env_var = f"DIRACX_SERVICE_{entry_point.name.upper()}_ENABLED"
enabled = TypeAdapter(bool).validate_json(os.environ.get(env_var, "true"))
logger.debug("Found service %r: enabled=%s", entry_point, enabled)
Expand All @@ -370,7 +370,7 @@ def create_app() -> DiracFastAPI:

available_access_policy_names = {
entry_point.name
for entry_point in select_from_extension(group="diracx.access_policies")
for entry_point in select_from_extension(group=DiracEntryPoint.ACCESS_POLICY)
}

all_access_policies = {}
Expand Down Expand Up @@ -540,15 +540,17 @@ def is_version_too_old(self, client_version: str) -> bool | None:

def get_min_client_version():
"""Extracting min client version from entry_points and searching for extension."""
matched_entry_points: EntryPoints = entry_points(group="diracx.min_client_version")
matched_entry_points: EntryPoints = entry_points(
group=DiracEntryPoint.MIN_CLIENT_VERSION
)
# Searching for an extension:
entry_points_dict: dict[str, EntryPoint] = {
ep.name: ep for ep in matched_entry_points
}
for ep_name, ep in entry_points_dict.items():
if ep_name != "diracx":
if ep_name != DiracEntryPoint.CORE:
return ep.load()

# Taking diracx if no extension:
if "diracx" in entry_points_dict:
return entry_points_dict["diracx"].load()
if DiracEntryPoint.CORE in entry_points_dict:
return entry_points_dict[DiracEntryPoint.CORE].load()
4 changes: 2 additions & 2 deletions diracx-routers/tests/test_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import defaultdict
from typing import TYPE_CHECKING

from diracx.core.extensions import select_from_extension
from diracx.core.extensions import DiracEntryPoint, select_from_extension
from diracx.routers.access_policies import (
BaseAccessPolicy,
)
Expand All @@ -22,7 +22,7 @@ def test_all_routes_have_policy():

"""
missing_security: defaultdict[list[str]] = defaultdict(list)
for entry_point in select_from_extension(group="diracx.services"):
for entry_point in select_from_extension(group=DiracEntryPoint.SERVICES):
router: DiracxRouter = entry_point.load()

# If the router was created with the
Expand Down
4 changes: 3 additions & 1 deletion diracx-testing/src/diracx/testing/entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@

import pytest

from diracx.core.extensions import DiracEntryPoint


def get_installed_entry_points():
"""Retrieve the installed entry points from the environment."""
entry_pts = entry_points()
diracx_eps = defaultdict(dict)
for group in entry_pts.groups:
if "diracx" in group:
if DiracEntryPoint.CORE in group:
for ep in entry_pts.select(group=group):
diracx_eps[group][ep.name] = ep.value
return dict(diracx_eps)
Expand Down
Loading
Loading