diff --git a/statemachine/statemachine.py b/statemachine/statemachine.py index e5fe262..af03f53 100644 --- a/statemachine/statemachine.py +++ b/statemachine/statemachine.py @@ -147,6 +147,7 @@ def __setstate__(self, state): self._register_callbacks([]) self.add_listener(*listeners.keys()) self._engine = self._get_engine(rtc) + self._engine.start() def _get_initial_state(self): initial_state_value = self.start_value if self.start_value else self.initial_state.value diff --git a/tests/test_copy.py b/tests/test_copy.py index 2f5a981..e4276f5 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -5,11 +5,11 @@ from enum import auto import pytest +from statemachine.exceptions import TransitionNotAllowed +from statemachine.states import States from statemachine import State from statemachine import StateMachine -from statemachine.exceptions import TransitionNotAllowed -from statemachine.states import States logger = logging.getLogger(__name__) DEBUG = logging.DEBUG @@ -181,3 +181,33 @@ def test_copy_with_custom_init_and_vars(copy_method): assert sm2.custom == 1 assert sm2.value == [1, 2, 3] assert sm2.current_state == MyStateMachine.started + + +class _AsyncTrafficLightForPickleTest(StateMachine): + """Defined at module level to be picklable for test_pickle_async_statemachine.""" + + green = State(initial=True) + yellow = State() + red = State() + + cycle = green.to(yellow) | yellow.to(red) | red.to(green) + + async def on_enter_state(self, target): + pass + + +def test_pickle_async_statemachine(): + """Regression test for issue #544: async SM fails after pickle.""" + import asyncio + + sm = _AsyncTrafficLightForPickleTest() + + sm_copy = pickle.loads(pickle.dumps(sm)) + + async def verify(): + await sm_copy.activate_initial_state() # type: ignore[awaitable] + assert sm_copy.current_state == _AsyncTrafficLightForPickleTest.green + await sm_copy.cycle() + assert sm_copy.current_state == _AsyncTrafficLightForPickleTest.yellow + + asyncio.run(verify())