diff --git a/src/py/kaleido/__init__.py b/src/py/kaleido/__init__.py index 900e40b9..3dd2e857 100644 --- a/src/py/kaleido/__init__.py +++ b/src/py/kaleido/__init__.py @@ -162,29 +162,24 @@ async def write_fig_from_object( ) +def _ensure_server() -> None: + if not _global_server.is_running(): + _global_server.open(silence_warnings=True) + + def calc_fig_sync(*args: Any, **kwargs: Any): """Call `calc_fig` but blocking.""" - if _global_server.is_running(): - return _global_server.call_function("calc_fig", *args, **kwargs) - else: - return _sync_server.oneshot_async_run(calc_fig, args=args, kwargs=kwargs) + _ensure_server() + return _global_server.call_function("calc_fig", *args, **kwargs) def write_fig_sync(*args: Any, **kwargs: Any): """Call `write_fig` but blocking.""" - if _global_server.is_running(): - return _global_server.call_function("write_fig", *args, **kwargs) - else: - return _sync_server.oneshot_async_run(write_fig, args=args, kwargs=kwargs) + _ensure_server() + _global_server.call_function("write_fig", *args, **kwargs) def write_fig_from_object_sync(*args: Any, **kwargs: Any): """Call `write_fig_from_object` but blocking.""" - if _global_server.is_running(): - return _global_server.call_function("write_fig_from_object", *args, **kwargs) - else: - return _sync_server.oneshot_async_run( - write_fig_from_object, - args=args, - kwargs=kwargs, - ) + _ensure_server() + _global_server.call_function("write_fig_from_object", *args, **kwargs) diff --git a/src/py/kaleido/_sync_server.py b/src/py/kaleido/_sync_server.py index 740bd902..57cee43a 100644 --- a/src/py/kaleido/_sync_server.py +++ b/src/py/kaleido/_sync_server.py @@ -1,9 +1,7 @@ from __future__ import annotations import asyncio -import atexit import warnings -from functools import partial from queue import Queue from threading import Thread from typing import TYPE_CHECKING, NamedTuple @@ -18,6 +16,7 @@ class Task(NamedTuple): fn: str args: Any kwargs: Any + result_queue: Queue # per-caller mailbox class _BadFunctionName(BaseException): @@ -37,13 +36,11 @@ async def _server(self, *args, **kwargs): if not hasattr(k, task.fn): raise _BadFunctionName(f"Kaleido has no attribute {task.fn}") try: - self._return_queue.put( + task.result_queue.put( await getattr(k, task.fn)(*task.args, **task.kwargs), ) except Exception as e: # noqa: BLE001 - self._return_queue.put(e) - - self._task_queue.task_done() + task.result_queue.put(e) def __new__(cls): # Create the singleton on first instantiation @@ -72,11 +69,8 @@ def open(self, *args: Any, silence_warnings=False, **kwargs: Any) -> None: daemon=True, ) self._task_queue: Queue[Task | None] = Queue() - self._return_queue: Queue[Any] = Queue() self._thread.start() self._initialized = True - close = partial(self.close, silence_warnings=True) - atexit.register(close) def close(self, *, silence_warnings=False): """Reset the singleton back to an uninitialized state.""" @@ -92,7 +86,6 @@ def close(self, *, silence_warnings=False): self._thread.join() del self._thread del self._task_queue - del self._return_queue self._initialized = False def call_function(self, cmd: str, *args: Any, **kwargs: Any): @@ -117,9 +110,9 @@ def call_function(self, cmd: str, *args: Any, **kwargs: Any): UserWarning, stacklevel=3, ) - self._task_queue.put(Task(cmd, args, kwargs)) - self._task_queue.join() - res = self._return_queue.get() + my_queue: Queue[Any] = Queue(maxsize=1) + self._task_queue.put(Task(cmd, args, kwargs, my_queue)) + res = my_queue.get() if isinstance(res, BaseException): raise res else: