diff --git a/dataframely/config.py b/dataframely/config.py index c597745..425ff0f 100644 --- a/dataframely/config.py +++ b/dataframely/config.py @@ -2,9 +2,10 @@ # SPDX-License-Identifier: BSD-3-Clause import contextlib +import os import sys from types import TracebackType -from typing import TypedDict +from typing import Any, TypedDict, cast, get_type_hints if sys.version_info >= (3, 11): from typing import Unpack @@ -17,22 +18,37 @@ class Options(TypedDict): max_sampling_iterations: int -def default_options() -> Options: +_ENV_PREFIX = "DATAFRAMELY_" + + +def _builtin_defaults() -> Options: return { "max_sampling_iterations": 10_000, } +def _init_options() -> Options: + options: dict[str, Any] = dict(_builtin_defaults()) + for key, target_type in get_type_hints(Options).items(): + env_name = f"{_ENV_PREFIX}{key.upper()}" + if env_name in os.environ: + options[key] = target_type(os.environ[env_name]) + return cast(Options, options) + + +_DEFAULT_OPTIONS = _init_options() + + class Config(contextlib.ContextDecorator): """An object to track global configuration for operations in dataframely.""" #: The currently valid config options. - options: Options = default_options() + options: Options = _DEFAULT_OPTIONS.copy() #: Singleton stack to track where to go back after exiting a context. _stack: list[Options] = [] def __init__(self, **options: Unpack[Options]) -> None: - self._local_options: Options = {**default_options(), **options} + self._local_options: Options = {**_DEFAULT_OPTIONS, **options} @staticmethod def set_max_sampling_iterations(iterations: int) -> None: @@ -43,7 +59,7 @@ def set_max_sampling_iterations(iterations: int) -> None: @staticmethod def restore_defaults() -> None: """Restore the defaults of the configuration.""" - Config.options = default_options() + Config.options = _DEFAULT_OPTIONS.copy() # ------------------------------------ CONTEXT ----------------------------------- # diff --git a/tests/test_config.py b/tests/test_config.py index c63c847..a9a8765 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,7 +1,12 @@ # Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause +import importlib + +import pytest + import dataframely as dy +from dataframely import config as _config def test_config_global() -> None: @@ -40,3 +45,23 @@ def test_config_global_local() -> None: assert dy.Config.options["max_sampling_iterations"] == 50 finally: dy.Config.restore_defaults() + + +def test_config_env_var_override(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("DATAFRAMELY_MAX_SAMPLING_ITERATIONS", "123") + try: + importlib.reload(_config) + assert _config.Config.options["max_sampling_iterations"] == 123 + finally: + monkeypatch.delenv("DATAFRAMELY_MAX_SAMPLING_ITERATIONS") + importlib.reload(_config) + # Re-bind dy.Config to the reloaded module's class to keep state consistent. + dy.Config = _config.Config # type: ignore + dy.Config.restore_defaults() + + +def test_config_env_var_not_reread_after_startup( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("DATAFRAMELY_MAX_SAMPLING_ITERATIONS", "777") + assert dy.Config.options["max_sampling_iterations"] == 10_000