diff --git a/docs/design/thread-safe-mode.md b/docs/design/thread-safe-mode.md new file mode 100644 index 000000000..297cb619b --- /dev/null +++ b/docs/design/thread-safe-mode.md @@ -0,0 +1,373 @@ +# Thread-Safe Mode Specification + +## Problem + +DataJoint uses global state (`dj.config`, `dj.conn()`) that is not thread-safe. Multi-tenant applications (web servers, async workers) need isolated connections per request/task. + +## Solution + +Introduce **Instance** objects that encapsulate config and connection. The `dj` module provides a global config that can be modified before connecting, and a lazily-loaded singleton connection. New isolated instances are created with `dj.Instance()`. + +## API + +### Legacy API (global config + singleton connection) + +```python +import datajoint as dj + +# Configure credentials (no connection yet) +dj.config.database.user = "user" +dj.config.database.password = "password" + +# First call to conn() or Schema() creates the singleton connection +dj.conn() # Creates connection using dj.config credentials +schema = dj.Schema("my_schema") + +@schema +class Mouse(dj.Manual): + definition = "..." +``` + +Alternatively, pass credentials directly to `conn()`: +```python +dj.conn(host="localhost", user="user", password="password") +``` + +Internally: +- `dj.config` → delegates to `_global_config` (with thread-safety check) +- `dj.conn()` → returns `_singleton_connection` (created lazily) +- `dj.Schema()` → uses `_singleton_connection` +- `dj.FreeTable()` → uses `_singleton_connection` + +### New API (isolated instance) + +```python +import datajoint as dj + +inst = dj.Instance( + host="localhost", + user="user", + password="password", +) +schema = inst.Schema("my_schema") + +@schema +class Mouse(dj.Manual): + definition = "..." 
+``` + +### Instance structure + +Each instance has: +- `inst.config` - Config (created fresh at instance creation) +- `inst.connection` - Connection (created at instance creation) +- `inst.Schema()` - Schema factory using instance's connection +- `inst.FreeTable()` - FreeTable factory using instance's connection + +```python +inst = dj.Instance(host="localhost", user="u", password="p") +inst.config # Config instance +inst.connection # Connection instance +inst.Schema("name") # Creates schema using inst.connection +inst.FreeTable("db.tbl") # Access table using inst.connection +``` + +### Table base classes vs instance methods + +**Base classes** (`dj.Manual`, `dj.Lookup`, etc.) - Used with `@schema` decorator: +```python +@schema +class Mouse(dj.Manual): # dj.Manual - schema links to connection + definition = "..." +``` + +**Instance methods** (`inst.Schema()`, `inst.FreeTable()`) - Need connection directly: +```python +schema = inst.Schema("my_schema") # Uses inst.connection +table = inst.FreeTable("db.table") # Uses inst.connection +``` + +### Thread-safe mode + +```bash +export DJ_THREAD_SAFE=true +``` + +`thread_safe` is checked dynamically on each access to global state. 
+ +When `thread_safe=True`, accessing global state raises `ThreadSafetyError`: +- `dj.config` raises `ThreadSafetyError` +- `dj.conn()` raises `ThreadSafetyError` +- `dj.Schema()` raises `ThreadSafetyError` (without explicit connection) +- `dj.FreeTable()` raises `ThreadSafetyError` (without explicit connection) +- `dj.Instance()` works - isolated instances are always allowed + +```python +# thread_safe=True + +dj.config # ThreadSafetyError +dj.conn() # ThreadSafetyError +dj.Schema("name") # ThreadSafetyError + +inst = dj.Instance(host="h", user="u", password="p") # OK +inst.Schema("name") # OK +``` + +## Behavior Summary + +| Operation | `thread_safe=False` | `thread_safe=True` | +|-----------|--------------------|--------------------| +| `dj.config` | `_global_config` | `ThreadSafetyError` | +| `dj.conn()` | `_singleton_connection` | `ThreadSafetyError` | +| `dj.Schema()` | Uses singleton | `ThreadSafetyError` | +| `dj.FreeTable()` | Uses singleton | `ThreadSafetyError` | +| `dj.Instance()` | Works | Works | +| `inst.config` | Works | Works | +| `inst.connection` | Works | Works | +| `inst.Schema()` | Works | Works | + +## Lazy Loading + +The global config is created at module import time. 
The singleton connection is created lazily on first access: + +```python +dj.config.database.user = "user" # Modifies global config (no connection yet) +dj.config.database.password = "pw" +dj.conn() # Creates singleton connection using global config +dj.Schema("name") # Uses existing singleton connection +``` + +## Usage Example + +```python +import datajoint as dj + +# Create isolated instance +inst = dj.Instance( + host="localhost", + user="user", + password="password", +) + +# Create schema +schema = inst.Schema("my_schema") + +@schema +class Mouse(dj.Manual): + definition = """ + mouse_id: int + """ + +# Use tables +Mouse().insert1({"mouse_id": 1}) +Mouse().fetch() +``` + +## Architecture + +### Object graph + +There is exactly **one** global `Config` object created at import time in `settings.py`. Both the legacy API and the `Instance` API hang off `Connection` objects, each of which carries a `_config` reference. + +``` +settings.py + config = _create_config() ← THE single global Config + +instance.py + _global_config = settings.config ← same object (not a copy) + _singleton_connection = None ← lazily created Connection + +__init__.py + dj.config = _ConfigProxy() ← proxy → _global_config (with thread-safety check) + dj.conn() ← returns _singleton_connection + dj.Schema() ← uses _singleton_connection + dj.FreeTable() ← uses _singleton_connection + +Connection (singleton) + _config → _global_config ← same Config that dj.config writes to + +Connection (Instance) + _config → fresh Config ← isolated per-instance +``` + +### Config flow: singleton path + +``` +dj.config["safemode"] = False + ↓ _ConfigProxy.__setitem__ +_global_config["safemode"] = False (same object as settings.config) + ↓ +Connection._config["safemode"] (points to _global_config) + ↓ +schema.drop() reads self.connection._config["safemode"] → False ✓ +``` + +### Config flow: Instance path + +``` +inst = dj.Instance(host=..., user=..., password=...) 
+ ↓ +inst.config = _create_config() (fresh Config, independent) +inst.connection._config = inst.config + ↓ +inst.config["safemode"] = False + ↓ +schema.drop() reads self.connection._config["safemode"] → False ✓ +``` + +### Key invariant + +**All runtime config reads go through `self.connection._config`**, never through the global `config` directly. This ensures both the singleton and Instance paths read the correct config. + +### Connection-scoped config reads + +Every module that previously imported `from .settings import config` now reads config from the connection: + +| Module | What was read | How it's read now | +|--------|--------------|-------------------| +| `schemas.py` | `config["safemode"]`, `config.database.create_tables` | `self.connection._config[...]` | +| `table.py` | `config["safemode"]` in `delete()`, `drop()` | `self.connection._config["safemode"]` | +| `expression.py` | `config["loglevel"]` in `__repr__()` | `self.connection._config["loglevel"]` | +| `preview.py` | `config["display.*"]` (8 reads) | `query_expression.connection._config[...]` | +| `autopopulate.py` | `config.jobs.allow_new_pk_fields`, `auto_refresh` | `self.connection._config.jobs.*` | +| `jobs.py` | `config.jobs.default_priority`, `stale_timeout`, `keep_completed` | `self.connection._config.jobs.*` | +| `declare.py` | `config.jobs.add_job_metadata` | `config` param (threaded from `table.py`) | +| `diagram.py` | `config.display.diagram_direction` | `self._connection._config.display.*` | +| `staged_insert.py` | `config.get_store_spec()` | `self._table.connection._config.get_store_spec()` | + +### Functions that receive config as a parameter + +Some module-level functions cannot access `self.connection`. 
Config is threaded through: + +| Function | Caller | How config arrives | +|----------|--------|--------------------| +| `declare()` in `declare.py` | `Table.declare()` in `table.py` | `config=self.connection._config` kwarg | +| `_get_job_version()` in `jobs.py` | `AutoPopulate._make_tuples()`, `Job.reserve()` | `config=self.connection._config` positional arg | + +Both functions accept `config=None` and fall back to the global `settings.config` for backward compatibility. + +## Implementation + +### 1. Create Instance class + +```python +class Instance: + def __init__(self, host, user, password, port=3306, **kwargs): + self.config = _create_config() # Fresh config with defaults + # Apply any config overrides from kwargs + self.connection = Connection(host, user, password, port, ...) + self.connection._config = self.config + + def Schema(self, name, **kwargs): + return Schema(name, connection=self.connection, **kwargs) + + def FreeTable(self, full_table_name): + return FreeTable(self.connection, full_table_name) +``` + +### 2. Global config and singleton connection + +```python +# settings.py - THE single global config +config = _create_config() # Created at import time + +# instance.py - reuses the same config object +_global_config = settings.config # Same reference, not a copy +_singleton_connection = None # Created lazily + +def _check_thread_safe(): + if _load_thread_safe(): + raise ThreadSafetyError( + "Global DataJoint state is disabled in thread-safe mode. " + "Use dj.Instance() to create an isolated instance." + ) + +def _get_singleton_connection(): + _check_thread_safe() + global _singleton_connection + if _singleton_connection is None: + _singleton_connection = Connection( + host=_global_config.database.host, + user=_global_config.database.user, + password=_global_config.database.password, + ... + ) + _singleton_connection._config = _global_config + return _singleton_connection +``` + +### 3. 
Legacy API with thread-safety checks + +```python +# dj.config -> global config with thread-safety check +class _ConfigProxy: + def __getattr__(self, name): + _check_thread_safe() + return getattr(_global_config, name) + def __setattr__(self, name, value): + _check_thread_safe() + setattr(_global_config, name, value) + +config = _ConfigProxy() + +# dj.conn() -> singleton connection (persistent across calls) +def conn(host=None, user=None, password=None, *, reset=False): + _check_thread_safe() + if reset or (_singleton_connection is None and credentials_provided): + _singleton_connection = Connection(...) + _singleton_connection._config = _global_config + return _get_singleton_connection() + +# dj.Schema() -> uses singleton connection +def Schema(name, connection=None, **kwargs): + if connection is None: + _check_thread_safe() + connection = _get_singleton_connection() + return _Schema(name, connection=connection, **kwargs) + +# dj.FreeTable() -> uses singleton connection +def FreeTable(conn_or_name, full_table_name=None): + if full_table_name is None: + _check_thread_safe() + return _FreeTable(_get_singleton_connection(), conn_or_name) + else: + return _FreeTable(conn_or_name, full_table_name) +``` + +## Global State Audit + +All module-level mutable state was reviewed for thread-safety implications. + +### Guarded (blocked in thread-safe mode) + +| State | Location | Mechanism | +|-------|----------|-----------| +| `config` singleton | `settings.py:979` | `_ConfigProxy` raises `ThreadSafetyError`; use `inst.config` instead | +| `conn()` singleton | `connection.py:108` | `_check_thread_safe()` guard; use `inst.connection` instead | + +These are the two globals that carry connection-scoped state (credentials, database settings) and are the primary source of cross-tenant interference. + +### Safe by design (no guard needed) + +| State | Location | Rationale | +|-------|----------|-----------| +| `_codec_registry` | `codecs.py:47` | Effectively immutable after import. 
Registration runs in `__init_subclass__` under Python's import lock. Runtime mutation (`_load_entry_points`) is idempotent under the GIL. Codecs are part of the type system, not connection-scoped. | +| `_entry_points_loaded` | `codecs.py:48` | Bool flag for idempotent lazy loading; worst case under concurrent access is redundant work, not corruption. | + +### Low risk (no guard needed) + +| State | Location | Rationale | +|-------|----------|-----------| +| Logging side effects | `logging.py:8,17,40-45,56` | Standard Python logging configuration. Monkey-patches `Logger` and replaces `sys.excepthook` at import time. Not DataJoint-specific mutable state. | +| `use_32bit_dims` | `blob.py:65` | Runtime flag affecting deserialization. Rarely changed; not connection-scoped. | +| `compression` dict | `blob.py:61` | Decompressor function registry. Populated at import time, effectively read-only thereafter. | +| `_lazy_modules` | `__init__.py:92` | Import caching via `globals()` mutation. Protected by Python's import lock. | +| `ADAPTERS` dict | `adapters/__init__.py:16` | Backend registry. Populated at import time, read-only in practice. | + +### Design principle + +Only state that is **connection-scoped** (credentials, database settings, connection objects) needs thread-safe guards. State that is **code-scoped** (type registries, import caches, logging configuration) is shared across all threads by design and does not vary between tenants. + +## Error Messages + +- Singleton access: `"Global DataJoint state is disabled in thread-safe mode. 
Use dj.Instance() to create an isolated instance."` diff --git a/src/datajoint/__init__.py b/src/datajoint/__init__.py index 7f809487d..68eac160f 100644 --- a/src/datajoint/__init__.py +++ b/src/datajoint/__init__.py @@ -23,6 +23,7 @@ "config", "conn", "Connection", + "Instance", "Schema", "VirtualModule", "virtual_schema", @@ -52,6 +53,7 @@ "errors", "migrate", "DataJointError", + "ThreadSafetyError", "logger", "cli", "ValidationResult", @@ -72,17 +74,186 @@ NpyRef, ) from .blob import MatCell, MatStruct -from .connection import Connection, conn -from .errors import DataJointError +from .connection import Connection +from .errors import DataJointError, ThreadSafetyError from .expression import AndList, Not, Top, U +from .instance import Instance, _ConfigProxy, _get_singleton_connection, _global_config, _check_thread_safe from .logging import logger from .objectref import ObjectRef -from .schemas import Schema, VirtualModule, list_schemas, virtual_schema -from .settings import config -from .table import FreeTable, Table, ValidationResult +from .schemas import Schema as _Schema, VirtualModule, list_schemas, virtual_schema +from .table import FreeTable as _FreeTable, Table, ValidationResult from .user_tables import Computed, Imported, Lookup, Manual, Part from .version import __version__ +# ============================================================================= +# Singleton-aware API +# ============================================================================= +# config is a proxy that delegates to the singleton instance's config +config = _ConfigProxy() + + +def conn( + host: str | None = None, + user: str | None = None, + password: str | None = None, + *, + reset: bool = False, + use_tls: bool | dict | None = None, +) -> Connection: + """ + Return a persistent connection object. + + When called without arguments, returns the singleton connection using + credentials from dj.config. 
Connection parameters are used only when the singleton is + (re)created (``reset=True`` or no singleton exists yet); if a singleton + already exists, it is returned unchanged. + + Parameters + ---------- + host : str, optional + Database hostname. Used only when the singleton is (re)created. + user : str, optional + Database username. Used only when the singleton is (re)created. + password : str, optional + Database password. Used only when the singleton is (re)created. + reset : bool, optional + If True, reset existing connection. Default False. + use_tls : bool or dict, optional + TLS encryption option. + + Returns + ------- + Connection + Database connection. + + Raises + ------ + ThreadSafetyError + If thread_safe mode is enabled. + """ + import datajoint.instance as instance_module + from pydantic import SecretStr + + _check_thread_safe() + + # If reset requested, always recreate + # If credentials provided and no singleton exists, create one + # If credentials provided and singleton exists, return existing singleton + if reset or ( + instance_module._singleton_connection is None and (host is not None or user is not None or password is not None) + ): + # Use provided values or fall back to config + host = host if host is not None else _global_config.database.host + user = user if user is not None else _global_config.database.user + raw_password = password if password is not None else _global_config.database.password + password = raw_password.get_secret_value() if isinstance(raw_password, SecretStr) else raw_password + port = _global_config.database.port + use_tls = use_tls if use_tls is not None else _global_config.database.use_tls + + if user is None: + from .errors import DataJointError + + raise DataJointError("Database user not configured. Set dj.config['database.user'] or pass user= argument.") + if password is None: + from .errors import DataJointError + + raise DataJointError( + "Database password not configured. Set dj.config['database.password'] or pass password= argument."
+ ) + + instance_module._singleton_connection = Connection(host, user, password, port, use_tls) + instance_module._singleton_connection._config = _global_config + + return _get_singleton_connection() + + +def Schema( + schema_name: str | None = None, + context: dict | None = None, + *, + connection: Connection | None = None, + create_schema: bool = True, + create_tables: bool | None = None, + add_objects: dict | None = None, +) -> _Schema: + """ + Create a Schema for binding table classes to a database schema. + + When connection is not provided, uses the singleton connection. + + Parameters + ---------- + schema_name : str, optional + Database schema name. + context : dict, optional + Namespace for foreign key lookup. + connection : Connection, optional + Database connection. Defaults to singleton connection. + create_schema : bool, optional + If False, raise error if schema doesn't exist. Default True. + create_tables : bool, optional + If False, raise error when accessing missing tables. + add_objects : dict, optional + Additional objects for declaration context. + + Returns + ------- + Schema + A Schema bound to the specified connection. + + Raises + ------ + ThreadSafetyError + If thread_safe mode is enabled and using singleton. + """ + if connection is None: + # Use singleton connection - will raise ThreadSafetyError if thread_safe=True + _check_thread_safe() + connection = _get_singleton_connection() + + return _Schema( + schema_name, + context=context, + connection=connection, + create_schema=create_schema, + create_tables=create_tables, + add_objects=add_objects, + ) + + +def FreeTable(conn_or_name, full_table_name: str | None = None) -> _FreeTable: + """ + Create a FreeTable for accessing a table without a dedicated class. 
+ + Can be called in two ways: + - ``FreeTable("schema.table")`` - uses singleton connection + - ``FreeTable(connection, "schema.table")`` - uses provided connection + + Parameters + ---------- + conn_or_name : Connection or str + Either a Connection object, or the full table name if using singleton. + full_table_name : str, optional + Full table name when first argument is a connection. + + Returns + ------- + FreeTable + A FreeTable instance for the specified table. + + Raises + ------ + ThreadSafetyError + If thread_safe mode is enabled and using singleton. + """ + if full_table_name is None: + # Called as FreeTable("db.table") - use singleton connection + _check_thread_safe() + return _FreeTable(_get_singleton_connection(), conn_or_name) + else: + # Called as FreeTable(conn, "db.table") - use provided connection + return _FreeTable(conn_or_name, full_table_name) + + # ============================================================================= # Lazy imports — heavy dependencies loaded on first access # ============================================================================= diff --git a/src/datajoint/adapters/mysql.py b/src/datajoint/adapters/mysql.py index 88339335f..21aab2908 100644 --- a/src/datajoint/adapters/mysql.py +++ b/src/datajoint/adapters/mysql.py @@ -75,7 +75,6 @@ def connect( Password for authentication. **kwargs : Any Additional MySQL-specific parameters: - - init_command: SQL initialization command - ssl: TLS/SSL configuration dict (deprecated, use use_tls) - use_tls: bool or dict - DataJoint's SSL parameter (preferred) - charset: Character set (default from kwargs) @@ -85,7 +84,6 @@ def connect( pymysql.Connection MySQL connection object. 
""" - init_command = kwargs.get("init_command") # Handle both ssl (old) and use_tls (new) parameter names ssl_config = kwargs.get("use_tls", kwargs.get("ssl")) # Convert boolean True to dict for PyMySQL (PyMySQL expects dict or SSLContext) @@ -99,7 +97,6 @@ def connect( "port": port, "user": user, "passwd": password, - "init_command": init_command, "sql_mode": "NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO," "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY", "charset": charset, diff --git a/src/datajoint/autopopulate.py b/src/datajoint/autopopulate.py index 7660e43ec..ae8be3b82 100644 --- a/src/datajoint/autopopulate.py +++ b/src/datajoint/autopopulate.py @@ -146,10 +146,8 @@ def _declare_check(self, primary_key: list[str], fk_attribute_map: dict[str, tup If native (non-FK) PK attributes are found, unless bypassed via ``dj.config.jobs.allow_new_pk_fields_in_computed_tables = True``. """ - from .settings import config - # Check if validation is bypassed - if config.jobs.allow_new_pk_fields_in_computed_tables: + if self.connection._config.jobs.allow_new_pk_fields_in_computed_tables: return # Check for native (non-FK) primary key attributes @@ -477,8 +475,6 @@ def _populate_distributed( """ from tqdm import tqdm - from .settings import config - # Define a signal handler for SIGTERM def handler(signum, frame): logger.info("Populate terminated by SIGTERM") @@ -489,7 +485,7 @@ def handler(signum, frame): try: # Refresh job queue if configured if refresh is None: - refresh = config.jobs.auto_refresh + refresh = self.connection._config.jobs.auto_refresh if refresh: # Use delay=-1 to ensure jobs are immediately schedulable # (avoids race condition with scheduled_time <= CURRENT_TIMESTAMP(3) check) @@ -659,7 +655,7 @@ def _populate1( key, start_time=datetime.datetime.fromtimestamp(start_time), duration=duration, - version=_get_job_version(), + version=_get_job_version(self.connection._config), ) if jobs is not None: diff --git a/src/datajoint/codecs.py 
b/src/datajoint/codecs.py index 5c192d46e..f4741a5e4 100644 --- a/src/datajoint/codecs.py +++ b/src/datajoint/codecs.py @@ -43,7 +43,15 @@ class MyTable(dj.Manual): logger = logging.getLogger(__name__.split(".")[0]) -# Global codec registry - maps name to Codec instance +# Global codec registry - maps name to Codec instance. +# +# Thread safety: This registry is effectively immutable after import. +# Registration happens in __init_subclass__ during class definition, which is +# serialized by Python's import lock. The only runtime mutation is +# _load_entry_points(), which is idempotent and guarded by a bool flag; +# under CPython's GIL, concurrent calls may do redundant work but cannot +# corrupt the dict. Codecs are part of the type system (tied to code, not to +# any particular connection or tenant), so per-instance isolation is unnecessary. _codec_registry: dict[str, Codec] = {} _entry_points_loaded: bool = False diff --git a/src/datajoint/connection.py b/src/datajoint/connection.py index 21b48e638..827a7a9bd 100644 --- a/src/datajoint/connection.py +++ b/src/datajoint/connection.py @@ -11,7 +11,6 @@ import re import warnings from contextlib import contextmanager -from typing import Callable from . import errors from .adapters import get_adapter @@ -55,7 +54,6 @@ def conn( user: str | None = None, password: str | None = None, *, - init_fun: Callable | None = None, reset: bool = False, use_tls: bool | dict | None = None, ) -> Connection: @@ -73,8 +71,6 @@ def conn( Database username. Required if not set in config. password : str, optional Database password. Required if not set in config. - init_fun : callable, optional - Initialization function called after connection. reset : bool, optional If True, reset existing connection. Default False. use_tls : bool or dict, optional @@ -103,9 +99,8 @@ def conn( raise errors.DataJointError( "Database password not configured. Set datajoint.config['database.password'] or pass password= argument." 
) - init_fun = init_fun if init_fun is not None else config["connection.init_function"] use_tls = use_tls if use_tls is not None else config["database.use_tls"] - conn.connection = Connection(host, user, password, None, init_fun, use_tls) + conn.connection = Connection(host, user, password, None, use_tls) return conn.connection @@ -150,8 +145,6 @@ class Connection: Database password. port : int, optional Port number. Overridden if specified in host. - init_fun : str, optional - SQL initialization command. use_tls : bool or dict, optional TLS encryption option. @@ -169,7 +162,6 @@ def __init__( user: str, password: str, port: int | None = None, - init_fun: str | None = None, use_tls: bool | dict | None = None, ) -> None: if ":" in host: @@ -190,13 +182,15 @@ def __init__( # use_tls=True: enable SSL with default settings self.conn_info["ssl"] = True self.conn_info["ssl_input"] = use_tls - self.init_fun = init_fun self._conn = None self._query_cache = None self._is_closed = True # Mark as closed until connect() succeeds + # Config reference - defaults to global config, but Instance sets its own + self._config = config + # Select adapter based on configured backend - backend = config["database.backend"] + backend = self._config["database.backend"] self.adapter = get_adapter(backend) self.connect() @@ -227,8 +221,7 @@ def connect(self) -> None: port=self.conn_info["port"], user=self.conn_info["user"], password=self.conn_info["passwd"], - init_command=self.init_fun, - charset=config["connection.charset"], + charset=self._config["connection.charset"], use_tls=self.conn_info.get("ssl"), ) except Exception as ssl_error: @@ -244,8 +237,7 @@ def connect(self) -> None: port=self.conn_info["port"], user=self.conn_info["user"], password=self.conn_info["passwd"], - init_command=self.init_fun, - charset=config["connection.charset"], + charset=self._config["connection.charset"], use_tls=False, # Explicitly disable SSL for fallback ) else: @@ -271,8 +263,8 @@ def 
set_query_cache(self, query_cache: str | None = None) -> None: def purge_query_cache(self) -> None: """Delete all cached query results.""" - if isinstance(config.get(cache_key), str) and pathlib.Path(config[cache_key]).is_dir(): - for path in pathlib.Path(config[cache_key]).iterdir(): + if isinstance(self._config.get(cache_key), str) and pathlib.Path(self._config[cache_key]).is_dir(): + for path in pathlib.Path(self._config[cache_key]).iterdir(): if not path.is_dir(): path.unlink() @@ -413,11 +405,11 @@ def query( if use_query_cache and not re.match(r"\s*(SELECT|SHOW)", query): raise errors.DataJointError("Only SELECT queries are allowed when query caching is on.") if use_query_cache: - if not config[cache_key]: + if not self._config[cache_key]: raise errors.DataJointError(f"Provide filepath dj.config['{cache_key}'] when using query caching.") # Cache key is backend-specific (no identifier normalization needed) hash_ = hashlib.md5((str(self._query_cache)).encode() + pack(args) + query.encode()).hexdigest() - cache_path = pathlib.Path(config[cache_key]) / str(hash_) + cache_path = pathlib.Path(self._config[cache_key]) / str(hash_) try: buffer = cache_path.read_bytes() except FileNotFoundError: @@ -426,7 +418,7 @@ def query( return EmulatedCursor(unpack(buffer)) if reconnect is None: - reconnect = config["database.reconnect"] + reconnect = self._config["database.reconnect"] logger.debug("Executing SQL:" + query[:query_log_max_length]) cursor = self.adapter.get_cursor(self._conn, as_dict=as_dict) try: diff --git a/src/datajoint/declare.py b/src/datajoint/declare.py index 375daa07e..fe50e8a66 100644 --- a/src/datajoint/declare.py +++ b/src/datajoint/declare.py @@ -15,7 +15,6 @@ from .codecs import lookup_codec from .condition import translate_attribute from .errors import DataJointError -from .settings import config # Core DataJoint types - scientist-friendly names that are fully supported # These are recorded in field comments using :type: syntax for reconstruction @@ 
-401,7 +400,7 @@ def prepare_declare( def declare( - full_table_name: str, definition: str, context: dict, adapter + full_table_name: str, definition: str, context: dict, adapter, *, config=None ) -> tuple[str, list[str], list[str], dict[str, tuple[str, str]], list[str], list[str]]: r""" Parse a definition and generate SQL CREATE TABLE statement. @@ -416,6 +415,8 @@ def declare( Namespace for resolving foreign key references. adapter : DatabaseAdapter Database adapter for backend-specific SQL generation. + config : Config, optional + Configuration object. If None, falls back to global config. Returns ------- @@ -464,6 +465,10 @@ def declare( ) = prepare_declare(definition, context, adapter) # Add hidden job metadata for Computed/Imported tables (not parts) + if config is None: + from .settings import config as _config + + config = _config if config.jobs.add_job_metadata: # Check if this is a Computed (__) or Imported (_) table, but not a Part (contains __ in middle) is_computed = table_name.startswith("__") and "__" not in table_name[2:] diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index 7034d122b..75e00c21c 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -16,7 +16,6 @@ from .dependencies import topo_sort from .errors import DataJointError -from .settings import config from .table import Table, lookup_class_name from .user_tables import Computed, Imported, Lookup, Manual, Part, _AliasNode, _get_tier @@ -105,6 +104,7 @@ def __init__(self, source, context=None) -> None: self.nodes_to_show = set(source.nodes_to_show) self._expanded_nodes = set(source._expanded_nodes) self.context = source.context + self._connection = source._connection super().__init__(source) return @@ -126,6 +126,7 @@ def __init__(self, source, context=None) -> None: raise DataJointError("Could not find database connection in %s" % repr(source[0])) # initialize graph from dependencies + self._connection = connection connection.dependencies.load() 
super().__init__(connection.dependencies) @@ -584,7 +585,7 @@ def make_dot(self): Tables are grouped by schema, with the Python module name shown as the group label when available. """ - direction = config.display.diagram_direction + direction = self._connection._config.display.diagram_direction graph = self._make_graph() # Apply collapse logic if needed @@ -857,7 +858,7 @@ def make_mermaid(self) -> str: Session --> Neuron """ graph = self._make_graph() - direction = config.display.diagram_direction + direction = self._connection._config.display.diagram_direction # Apply collapse logic if needed graph, collapsed_counts = self._apply_collapse(graph) diff --git a/src/datajoint/errors.py b/src/datajoint/errors.py index 7e10f021d..bba032b23 100644 --- a/src/datajoint/errors.py +++ b/src/datajoint/errors.py @@ -72,3 +72,7 @@ class MissingExternalFile(DataJointError): class BucketInaccessible(DataJointError): """S3 bucket is inaccessible.""" + + +class ThreadSafetyError(DataJointError): + """Global DataJoint state is disabled in thread-safe mode.""" diff --git a/src/datajoint/expression.py b/src/datajoint/expression.py index 883853cd3..9b36cf6d0 100644 --- a/src/datajoint/expression.py +++ b/src/datajoint/expression.py @@ -20,7 +20,6 @@ from .errors import DataJointError from .codecs import decode_attribute from .preview import preview, repr_html -from .settings import config logger = logging.getLogger(__name__.split(".")[0]) @@ -1247,7 +1246,7 @@ def __repr__(self): str String representation of the QueryExpression. 
""" - return super().__repr__() if config["loglevel"].lower() == "debug" else self.preview() + return super().__repr__() if self.connection._config["loglevel"].lower() == "debug" else self.preview() def preview(self, limit=None, width=None): """ diff --git a/src/datajoint/instance.py b/src/datajoint/instance.py new file mode 100644 index 000000000..c60e267e1 --- /dev/null +++ b/src/datajoint/instance.py @@ -0,0 +1,300 @@ +""" +DataJoint Instance for thread-safe operation. + +An Instance encapsulates a config and connection pair, providing isolated +database contexts for multi-tenant applications. +""" + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Any + +from .connection import Connection +from .errors import ThreadSafetyError +from .settings import Config, _create_config, config as _settings_config + +if TYPE_CHECKING: + from .schemas import Schema as SchemaClass + from .table import FreeTable as FreeTableClass + + +def _load_thread_safe() -> bool: + """ + Load thread_safe setting from environment or config file. + + Returns + ------- + bool + True if thread-safe mode is enabled. + """ + # Check environment variable first + env_val = os.environ.get("DJ_THREAD_SAFE", "").lower() + if env_val in ("true", "1", "yes"): + return True + if env_val in ("false", "0", "no"): + return False + + # Default: thread-safe mode is off + return False + + +class Instance: + """ + Encapsulates a DataJoint configuration and connection. + + Each Instance has its own Config and Connection, providing isolation + for multi-tenant applications. Use ``dj.Instance()`` to create isolated + instances, or access the singleton via ``dj.config``, ``dj.conn()``, etc. + + Parameters + ---------- + host : str + Database hostname. + user : str + Database username. + password : str + Database password. + port : int, optional + Database port. Default from config or 3306. + use_tls : bool or dict, optional + TLS configuration. 
+ **kwargs : Any + Additional config overrides applied to this instance's config. + + Attributes + ---------- + config : Config + Configuration for this instance. + connection : Connection + Database connection for this instance. + + Examples + -------- + >>> inst = dj.Instance(host="localhost", user="root", password="secret") + >>> inst.config.safemode = False + >>> schema = inst.Schema("my_schema") + """ + + def __init__( + self, + host: str, + user: str, + password: str, + port: int | None = None, + use_tls: bool | dict | None = None, + **kwargs: Any, + ) -> None: + # Create fresh config with defaults loaded from env/file + self.config = _create_config() + + # Apply any config overrides from kwargs + for key, value in kwargs.items(): + if hasattr(self.config, key): + setattr(self.config, key, value) + elif "__" in key: + # Handle nested keys like database__reconnect + parts = key.split("__") + obj = self.config + for part in parts[:-1]: + obj = getattr(obj, part) + setattr(obj, parts[-1], value) + + # Determine port + if port is None: + port = self.config.database.port + + # Create connection + self.connection = Connection(host, user, password, port, use_tls) + + # Attach config to connection so tables can access it + self.connection._config = self.config + + def Schema( + self, + schema_name: str, + *, + context: dict[str, Any] | None = None, + create_schema: bool = True, + create_tables: bool | None = None, + add_objects: dict[str, Any] | None = None, + ) -> "SchemaClass": + """ + Create a Schema bound to this instance's connection. + + Parameters + ---------- + schema_name : str + Database schema name. + context : dict, optional + Namespace for foreign key lookup. + create_schema : bool, optional + If False, raise error if schema doesn't exist. Default True. + create_tables : bool, optional + If False, raise error when accessing missing tables. + add_objects : dict, optional + Additional objects for declaration context. 
+ + Returns + ------- + Schema + A Schema using this instance's connection. + """ + from .schemas import Schema + + return Schema( + schema_name, + context=context, + connection=self.connection, + create_schema=create_schema, + create_tables=create_tables, + add_objects=add_objects, + ) + + def FreeTable(self, full_table_name: str) -> "FreeTableClass": + """ + Create a FreeTable bound to this instance's connection. + + Parameters + ---------- + full_table_name : str + Full table name as ``'schema.table'`` or ```schema`.`table```. + + Returns + ------- + FreeTable + A FreeTable using this instance's connection. + """ + from .table import FreeTable + + return FreeTable(self.connection, full_table_name) + + def __repr__(self) -> str: + return f"Instance({self.connection!r})" + + +# ============================================================================= +# Singleton management +# ============================================================================= +# The global config is created at module load time and can be modified +# The singleton connection is created lazily when conn() or Schema() is called + +# Reuse the config created in settings.py — there must be exactly one global config +_global_config: Config = _settings_config +_singleton_connection: Connection | None = None + + +def _check_thread_safe() -> None: + """ + Check if thread-safe mode is enabled and raise if so. + + Raises + ------ + ThreadSafetyError + If thread_safe mode is enabled. + """ + if _load_thread_safe(): + raise ThreadSafetyError( + "Global DataJoint state is disabled in thread-safe mode. " "Use dj.Instance() to create an isolated instance." + ) + + +def _get_singleton_connection() -> Connection: + """ + Get or create the singleton Connection. + + Uses credentials from the global config. + + Raises + ------ + ThreadSafetyError + If thread_safe mode is enabled. + DataJointError + If credentials are not configured. 
+ """ + global _singleton_connection + + _check_thread_safe() + + if _singleton_connection is None: + from .errors import DataJointError + + host = _global_config.database.host + user = _global_config.database.user + raw_password = _global_config.database.password + password = raw_password.get_secret_value() if raw_password is not None else None + port = _global_config.database.port + use_tls = _global_config.database.use_tls + + if user is None: + raise DataJointError( + "Database user not configured. Set dj.config['database.user'] or DJ_USER environment variable." + ) + if password is None: + raise DataJointError( + "Database password not configured. Set dj.config['database.password'] or DJ_PASS environment variable." + ) + + _singleton_connection = Connection(host, user, password, port, use_tls) + # Attach global config to connection + _singleton_connection._config = _global_config + + return _singleton_connection + + +class _ConfigProxy: + """ + Proxy that delegates to the global config, with thread-safety checks. + + In thread-safe mode, all access raises ThreadSafetyError. 
+ """ + + def __getattr__(self, name: str) -> Any: + _check_thread_safe() + return getattr(_global_config, name) + + def __setattr__(self, name: str, value: Any) -> None: + _check_thread_safe() + setattr(_global_config, name, value) + + def __getitem__(self, key: str) -> Any: + _check_thread_safe() + return _global_config[key] + + def __setitem__(self, key: str, value: Any) -> None: + _check_thread_safe() + _global_config[key] = value + + def __delitem__(self, key: str) -> None: + _check_thread_safe() + del _global_config[key] + + def get(self, key: str, default: Any = None) -> Any: + _check_thread_safe() + return _global_config.get(key, default) + + def override(self, **kwargs: Any): + _check_thread_safe() + return _global_config.override(**kwargs) + + def load(self, filename: str) -> None: + _check_thread_safe() + return _global_config.load(filename) + + def get_store_spec(self, store: str | None = None, *, use_filepath_default: bool = False) -> dict[str, Any]: + _check_thread_safe() + return _global_config.get_store_spec(store, use_filepath_default=use_filepath_default) + + @staticmethod + def save_template( + path: str = "datajoint.json", + minimal: bool = True, + create_secrets_dir: bool = True, + ): + # save_template is a static method, no thread-safety check needed + return Config.save_template(path, minimal, create_secrets_dir) + + def __repr__(self) -> str: + if _load_thread_safe(): + return "ConfigProxy (thread-safe mode - use dj.Instance())" + return repr(_global_config) diff --git a/src/datajoint/jobs.py b/src/datajoint/jobs.py index e5499eb8e..cf0981836 100644 --- a/src/datajoint/jobs.py +++ b/src/datajoint/jobs.py @@ -24,16 +24,22 @@ logger = logging.getLogger(__name__.split(".")[0]) -def _get_job_version() -> str: +def _get_job_version(config=None) -> str: """ Get version string based on config settings. + Parameters + ---------- + config : Config, optional + Configuration object. If None, falls back to global config. 
+ Returns ------- str Version string, or empty string if version tracking disabled. """ - from .settings import config + if config is None: + from .settings import config method = config.jobs.version_method if method is None or method == "none": @@ -349,17 +355,15 @@ def refresh( 3. Remove stale jobs: jobs older than stale_timeout whose keys not in key_source 4. Remove orphaned jobs: reserved jobs older than orphan_timeout (if specified) """ - from .settings import config - # Ensure jobs table exists if not self.is_declared: self.declare() # Get defaults from config if priority is None: - priority = config.jobs.default_priority + priority = self.connection._config.jobs.default_priority if stale_timeout is None: - stale_timeout = config.jobs.stale_timeout + stale_timeout = self.connection._config.jobs.stale_timeout result = {"added": 0, "removed": 0, "orphaned": 0, "re_pended": 0} @@ -392,7 +396,7 @@ def refresh( pass # Job already exists # 2. Re-pend success jobs if keep_completed=True - if config.jobs.keep_completed: + if self.connection._config.jobs.keep_completed: # Success jobs whose keys are in key_source but not in target # Disable semantic_check for Job table operations (job table PK has different lineage than target) success_to_repend = self.completed.restrict(key_source, semantic_check=False).restrict( @@ -463,7 +467,7 @@ def reserve(self, key: dict) -> bool: "pid": os.getpid(), "connection_id": self.connection.connection_id, "user": self.connection.get_user(), - "version": _get_job_version(), + "version": _get_job_version(self.connection._config), } try: @@ -490,9 +494,7 @@ def complete(self, key: dict, duration: float | None = None) -> None: - If True: updates status to ``'success'`` with completion time and duration - If False: deletes the job entry """ - from .settings import config - - if config.jobs.keep_completed: + if self.connection._config.jobs.keep_completed: # Use server time for completed_time server_now = self.connection.query("SELECT 
CURRENT_TIMESTAMP").fetchone()[0] pk = self._get_pk(key) @@ -550,13 +552,11 @@ def ignore(self, key: dict) -> None: key : dict Primary key dict of the job. """ - from .settings import config - pk = self._get_pk(key) if pk in self: self.update1({**pk, "status": "ignore"}) else: - priority = config.jobs.default_priority + priority = self.connection._config.jobs.default_priority self.insert1({**pk, "status": "ignore", "priority": priority}) def progress(self) -> dict: diff --git a/src/datajoint/preview.py b/src/datajoint/preview.py index 92d09d874..0b80ad15f 100644 --- a/src/datajoint/preview.py +++ b/src/datajoint/preview.py @@ -2,8 +2,6 @@ import json -from .settings import config - def _format_object_display(json_data): """Format object metadata for display in query results.""" @@ -44,6 +42,7 @@ def _get_blob_placeholder(heading, field_name, html_escape=False): def preview(query_expression, limit, width): heading = query_expression.heading rel = query_expression.proj(*heading.non_blobs) + config = query_expression.connection._config # Object fields use codecs - not specially handled in simplified model object_fields = [] if limit is None: @@ -105,6 +104,7 @@ def get_display_value(tup, f, idx): def repr_html(query_expression): heading = query_expression.heading rel = query_expression.proj(*heading.non_blobs) + config = query_expression.connection._config # Object fields use codecs - not specially handled in simplified model object_fields = [] tuples = rel.to_arrays(limit=config["display.limit"] + 1) diff --git a/src/datajoint/schemas.py b/src/datajoint/schemas.py index 2955fd67d..694250c7d 100644 --- a/src/datajoint/schemas.py +++ b/src/datajoint/schemas.py @@ -16,14 +16,13 @@ import warnings from typing import TYPE_CHECKING, Any -from .connection import conn from .errors import AccessError, DataJointError +from .instance import _get_singleton_connection if TYPE_CHECKING: from .connection import Connection from .heading import Heading from .jobs import Job -from 
.settings import config from .table import FreeTable, lookup_class_name from .user_tables import Computed, Imported, Lookup, Manual, Part, _get_tier from .utils import to_camel_case, user_choice @@ -120,7 +119,7 @@ def __init__( self.database = None self.context = context self.create_schema = create_schema - self.create_tables = create_tables if create_tables is not None else config.database.create_tables + self.create_tables = create_tables # None means "use connection config default" self.add_objects = add_objects self.declare_list = [] if schema_name: @@ -174,7 +173,7 @@ def activate( if connection is not None: self.connection = connection if self.connection is None: - self.connection = conn() + self.connection = _get_singleton_connection() self.database = schema_name if create_schema is not None: self.create_schema = create_schema @@ -293,7 +292,10 @@ def _decorate_table(self, table_class: type, context: dict[str, Any], assert_dec # instantiate the class, declare the table if not already instance = table_class() is_declared = instance.is_declared - if not is_declared and not assert_declared and self.create_tables: + create_tables = ( + self.create_tables if self.create_tables is not None else self.connection._config.database.create_tables + ) + if not is_declared and not assert_declared and create_tables: instance.declare(context) self.connection.dependencies.clear() is_declared = is_declared or instance.is_declared @@ -409,7 +411,7 @@ def drop(self, prompt: bool | None = None) -> None: AccessError If insufficient permissions to drop the schema. """ - prompt = config["safemode"] if prompt is None else prompt + prompt = self.connection._config["safemode"] if prompt is None else prompt if not self.exists: logger.info("Schema named `{database}` does not exist. 
Doing nothing.".format(database=self.database)) @@ -858,7 +860,7 @@ def list_schemas(connection: Connection | None = None) -> list[str]: """ return [ r[0] - for r in (connection or conn()).query( + for r in (connection or _get_singleton_connection()).query( 'SELECT schema_name FROM information_schema.schemata WHERE schema_name <> "information_schema"' ) ] diff --git a/src/datajoint/settings.py b/src/datajoint/settings.py index e373ca38f..7019d8345 100644 --- a/src/datajoint/settings.py +++ b/src/datajoint/settings.py @@ -224,7 +224,6 @@ class ConnectionSettings(BaseSettings): model_config = SettingsConfigDict(extra="forbid", validate_assignment=True) - init_function: str | None = None charset: str = "" # pymysql uses '' as default @@ -341,11 +340,8 @@ class Config(BaseSettings): # Top-level settings loglevel: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(default="INFO", validation_alias="DJ_LOG_LEVEL") safemode: bool = True - enable_python_native_blobs: bool = True - filepath_checksum_size_limit: int | None = None - # Cache paths - cache: Path | None = None + # Cache path for query results query_cache: Path | None = None # Download path for attachments and filepaths @@ -362,7 +358,7 @@ def set_logger_level(cls, v: str) -> str: logger.setLevel(v) return v - @field_validator("cache", "query_cache", mode="before") + @field_validator("query_cache", mode="before") @classmethod def convert_path(cls, v: Any) -> Path | None: """Convert string paths to Path objects.""" @@ -819,7 +815,6 @@ def save_template( "use_tls": None, }, "connection": { - "init_function": None, "charset": "", }, "display": { @@ -844,8 +839,6 @@ def save_template( }, "loglevel": "INFO", "safemode": True, - "enable_python_native_blobs": True, - "cache": None, "query_cache": None, "download_path": ".", } diff --git a/src/datajoint/staged_insert.py b/src/datajoint/staged_insert.py index 6ac3819e4..1f6ee7afb 100644 --- a/src/datajoint/staged_insert.py +++ b/src/datajoint/staged_insert.py 
@@ -14,7 +14,6 @@ import fsspec from .errors import DataJointError -from .settings import config from .storage import StorageBackend, build_object_path @@ -69,7 +68,7 @@ def _ensure_backend(self): """Ensure storage backend is initialized.""" if self._backend is None: try: - spec = config.get_store_spec() # Uses stores.default + spec = self._table.connection._config.get_store_spec() # Uses stores.default self._backend = StorageBackend(spec) except DataJointError: raise DataJointError( @@ -110,7 +109,7 @@ def _get_storage_path(self, field: str, ext: str = "") -> str: ) # Get storage spec (uses stores.default) - spec = config.get_store_spec() + spec = self._table.connection._config.get_store_spec() partition_pattern = spec.get("partition_pattern") token_length = spec.get("token_length", 8) diff --git a/src/datajoint/table.py b/src/datajoint/table.py index 59279489e..a6bc7d2c9 100644 --- a/src/datajoint/table.py +++ b/src/datajoint/table.py @@ -23,7 +23,6 @@ ) from .expression import QueryExpression from .heading import Heading -from .settings import config from .staged_insert import staged_insert1 as _staged_insert1 from .utils import get_master, is_camel_case, user_choice @@ -153,7 +152,7 @@ def declare(self, context=None): "Class names must be in CamelCase, starting with a capital letter." 
) sql, _external_stores, primary_key, fk_attribute_map, pre_ddl, post_ddl = declare( - self.full_table_name, self.definition, context, self.connection.adapter + self.full_table_name, self.definition, context, self.connection.adapter, config=self.connection._config ) # Call declaration hook for validation (subclasses like AutoPopulate can override) @@ -1119,7 +1118,7 @@ def strip_quotes(s): raise DataJointError("Exceeded maximum number of delete attempts.") return delete_count - prompt = config["safemode"] if prompt is None else prompt + prompt = self.connection._config["safemode"] if prompt is None else prompt # Start transaction if transaction: @@ -1227,7 +1226,7 @@ def drop(self, prompt: bool | None = None): raise DataJointError( "A table with an applied restriction cannot be dropped. Call drop() on the unrestricted Table." ) - prompt = config["safemode"] if prompt is None else prompt + prompt = self.connection._config["safemode"] if prompt is None else prompt self.connection.dependencies.load() do_drop = True diff --git a/tests/conftest.py b/tests/conftest.py index 4d6adf09c..8efaab745 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -536,13 +536,13 @@ def mock_stores(stores_config): @pytest.fixture def mock_cache(tmpdir_factory): - og_cache = dj.config.get("cache") - dj.config["cache"] = tmpdir_factory.mktemp("cache") + og_cache = dj.config.get("download_path") + dj.config["download_path"] = str(tmpdir_factory.mktemp("cache")) yield if og_cache is None: - del dj.config["cache"] + del dj.config["download_path"] else: - dj.config["cache"] = og_cache + dj.config["download_path"] = og_cache @pytest.fixture(scope="session") diff --git a/tests/integration/test_jobs.py b/tests/integration/test_jobs.py index 20fa3233d..5a9203dca 100644 --- a/tests/integration/test_jobs.py +++ b/tests/integration/test_jobs.py @@ -108,10 +108,9 @@ def test_sigterm(clean_jobs, schema_any): def test_suppress_dj_errors(clean_jobs, schema_any): - """Test that DataJoint errors are 
suppressible without native py blobs.""" + """Test that DataJoint errors are suppressible.""" error_class = schema.ErrorClass() - with dj.config.override(enable_python_native_blobs=False): - error_class.populate(reserve_jobs=True, suppress_errors=True) + error_class.populate(reserve_jobs=True, suppress_errors=True) assert len(schema.DjExceptionName()) == len(error_class.jobs.errors) > 0 diff --git a/tests/integration/test_schema.py b/tests/integration/test_schema.py index ef621765d..cf053df62 100644 --- a/tests/integration/test_schema.py +++ b/tests/integration/test_schema.py @@ -265,5 +265,5 @@ class Recording(dj.Manual): id: smallint """ - schema2.drop() - schema1.drop() + schema2.drop(prompt=False) + schema1.drop(prompt=False) diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py index af5718503..475d96df9 100644 --- a/tests/unit/test_settings.py +++ b/tests/unit/test_settings.py @@ -504,23 +504,23 @@ def test_display_limit(self): class TestCachePaths: """Test cache path settings.""" - def test_cache_path_string(self): - """Test setting cache path as string.""" - original = dj.config.cache + def test_query_cache_path_string(self): + """Test setting query_cache path as string.""" + original = dj.config.query_cache try: - dj.config.cache = "/tmp/cache" - assert dj.config.cache == Path("/tmp/cache") + dj.config.query_cache = "/tmp/cache" + assert dj.config.query_cache == Path("/tmp/cache") finally: - dj.config.cache = original + dj.config.query_cache = original - def test_cache_path_none(self): - """Test cache path can be None.""" - original = dj.config.cache + def test_query_cache_path_none(self): + """Test query_cache path can be None.""" + original = dj.config.query_cache try: - dj.config.cache = None - assert dj.config.cache is None + dj.config.query_cache = None + assert dj.config.query_cache is None finally: - dj.config.cache = original + dj.config.query_cache = original class TestSaveTemplate: diff --git a/tests/unit/test_thread_safe.py 
b/tests/unit/test_thread_safe.py new file mode 100644 index 000000000..bec45e434 --- /dev/null +++ b/tests/unit/test_thread_safe.py @@ -0,0 +1,171 @@ +"""Tests for thread-safe mode functionality.""" + +import pytest + + +class TestThreadSafeMode: + """Test thread-safe mode behavior.""" + + def test_thread_safe_env_var_true(self, monkeypatch): + """DJ_THREAD_SAFE=true enables thread-safe mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "true") + + # Re-import to pick up the new env var + from datajoint.instance import _load_thread_safe + + assert _load_thread_safe() is True + + def test_thread_safe_env_var_false(self, monkeypatch): + """DJ_THREAD_SAFE=false disables thread-safe mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "false") + + from datajoint.instance import _load_thread_safe + + assert _load_thread_safe() is False + + def test_thread_safe_env_var_1(self, monkeypatch): + """DJ_THREAD_SAFE=1 enables thread-safe mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "1") + + from datajoint.instance import _load_thread_safe + + assert _load_thread_safe() is True + + def test_thread_safe_env_var_yes(self, monkeypatch): + """DJ_THREAD_SAFE=yes enables thread-safe mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "yes") + + from datajoint.instance import _load_thread_safe + + assert _load_thread_safe() is True + + def test_thread_safe_default_false(self, monkeypatch): + """Thread-safe mode defaults to False.""" + monkeypatch.delenv("DJ_THREAD_SAFE", raising=False) + + from datajoint.instance import _load_thread_safe + + assert _load_thread_safe() is False + + +class TestConfigProxyThreadSafe: + """Test ConfigProxy behavior in thread-safe mode.""" + + def test_config_access_raises_in_thread_safe_mode(self, monkeypatch): + """Accessing config raises ThreadSafetyError in thread-safe mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "true") + + import datajoint as dj + from datajoint.errors import ThreadSafetyError + + with pytest.raises(ThreadSafetyError): + _ = 
dj.config.database + + def test_config_access_works_in_normal_mode(self, monkeypatch): + """Accessing config works in normal mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "false") + + import datajoint as dj + + # Should not raise + host = dj.config.database.host + assert isinstance(host, str) + + def test_config_set_raises_in_thread_safe_mode(self, monkeypatch): + """Setting config raises ThreadSafetyError in thread-safe mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "true") + + import datajoint as dj + from datajoint.errors import ThreadSafetyError + + with pytest.raises(ThreadSafetyError): + dj.config.safemode = False + + def test_save_template_works_in_thread_safe_mode(self, monkeypatch, tmp_path): + """save_template is a static method and works in thread-safe mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "true") + + import datajoint as dj + + # Should not raise - save_template is static + config_file = tmp_path / "datajoint.json" + dj.config.save_template(str(config_file), create_secrets_dir=False) + assert config_file.exists() + + +class TestConnThreadSafe: + """Test conn() behavior in thread-safe mode.""" + + def test_conn_raises_in_thread_safe_mode(self, monkeypatch): + """conn() raises ThreadSafetyError in thread-safe mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "true") + + import datajoint as dj + from datajoint.errors import ThreadSafetyError + + with pytest.raises(ThreadSafetyError): + dj.conn() + + +class TestSchemaThreadSafe: + """Test Schema behavior in thread-safe mode.""" + + def test_schema_raises_in_thread_safe_mode(self, monkeypatch): + """Schema() raises ThreadSafetyError in thread-safe mode without connection.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "true") + + import datajoint as dj + from datajoint.errors import ThreadSafetyError + + with pytest.raises(ThreadSafetyError): + dj.Schema("test_schema") + + +class TestFreeTableThreadSafe: + """Test FreeTable behavior in thread-safe mode.""" + + def 
test_freetable_raises_in_thread_safe_mode(self, monkeypatch): + """FreeTable() raises ThreadSafetyError in thread-safe mode without connection.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "true") + + import datajoint as dj + from datajoint.errors import ThreadSafetyError + + with pytest.raises(ThreadSafetyError): + dj.FreeTable("test.table") + + +class TestInstance: + """Test Instance class.""" + + def test_instance_import(self): + """Instance class is importable.""" + from datajoint import Instance + + assert Instance is not None + + def test_instance_always_allowed_in_thread_safe_mode(self, monkeypatch): + """Instance() is allowed even in thread-safe mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "true") + + from datajoint import Instance + + # Instance class should be accessible + # (actual creation requires valid credentials) + assert callable(Instance) + + +class TestThreadSafetyError: + """Test ThreadSafetyError exception.""" + + def test_error_is_datajoint_error(self): + """ThreadSafetyError is a subclass of DataJointError.""" + from datajoint.errors import DataJointError, ThreadSafetyError + + assert issubclass(ThreadSafetyError, DataJointError) + + def test_error_in_exports(self): + """ThreadSafetyError is exported from datajoint.""" + import datajoint as dj + + assert hasattr(dj, "ThreadSafetyError")