diff --git a/src/replit_river/rate_limiter.py b/src/replit_river/rate_limiter.py index 5e742ce9..e7593e19 100644 --- a/src/replit_river/rate_limiter.py +++ b/src/replit_river/rate_limiter.py @@ -2,6 +2,7 @@ import logging import random from contextvars import Context +from typing import Protocol from replit_river.error_schema import RiverException from replit_river.transport_options import ConnectionRetryOptions @@ -15,6 +16,13 @@ def __init__(self, code: str, message: str, client_id: str) -> None: self.client_id = client_id +class RateLimiter(Protocol): + def start_restoring_budget(self, user: str) -> None: ... + def get_backoff_ms(self, user: str) -> float: ... + def has_budget(self, user: str) -> bool: ... + def consume_budget(self, user: str) -> None: ... + + class LeakyBucketRateLimit: """Asynchronous leaky bucket rate limiter. diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index 3dc96522..2bfc60ee 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -54,8 +54,10 @@ async def get_or_create_session(self) -> Session: call ensure_connected on whatever session is active. """ existing_session = self._session - if not existing_session or existing_session.is_closed(): + if not existing_session or existing_session.is_terminal(): logger.info("Creating new session") + if existing_session: + await existing_session.close() new_session = Session( client_id=self._client_id, server_id=self._server_id, @@ -80,7 +82,7 @@ async def _retry_connection(self) -> Session: logger.debug("Triggering get_or_create_session") return await self.get_or_create_session() - async def _delete_session(self, session: Session) -> None: + def _delete_session(self, session: Session) -> None: if self._session is session: self._session = None else: diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 46ab151d..3fe948e8 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -26,6 +26,7 @@ from opentelemetry.trace import Span, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from pydantic import ValidationError +from websockets import ConnectionClosedOK from websockets.asyncio.client import ClientConnection from websockets.exceptions import ConnectionClosed @@ -54,7 +55,7 @@ parse_transport_msg, send_transport_message, ) -from replit_river.rate_limiter import LeakyBucketRateLimit +from replit_river.rate_limiter import RateLimiter from replit_river.rpc import ( ACK_BIT, STREAM_OPEN_BIT, @@ -83,6 +84,8 @@ STREAM_CLOSED_BIT: STREAM_CLOSED_BIT_TYPE = 0b01000 +SESSION_CLOSE_TIMEOUT_SEC = 2 + _BackpressuredWaiter: TypeAlias = Callable[[], Awaitable[None]] @@ -109,7 +112,7 @@ class ResultError(TypedDict): trace_propagator = TraceContextTextMapPropagator() trace_setter = TransportMessageTracingSetter() -CloseSessionCallback: TypeAlias = Callable[["Session"], Coroutine[Any, Any, Any]] +CloseSessionCallback: TypeAlias = Callable[["Session"], None] RetryConnectionCallback: TypeAlias = Callable[ [], Coroutine[Any, Any, Any], @@ -141,7 +144,7 @@ class Session[HandshakeMetadata]: _wait_for_connected: asyncio.Event _client_id: str - _rate_limiter: LeakyBucketRateLimit + _rate_limiter: RateLimiter _uri_and_metadata_factory: Callable[ [], Awaitable[UriAndMetadata[HandshakeMetadata]] ] @@ -166,6 +169,7 @@ class Session[HandshakeMetadata]: # Terminating _terminating_task: asyncio.Task[None] | None + _closing_waiter: asyncio.Event | None def __init__( self, @@ -174,7 +178,7 @@ def __init__( transport_options: TransportOptions, close_session_callback: CloseSessionCallback, client_id: str, - rate_limiter: LeakyBucketRateLimit, + rate_limiter: RateLimiter, uri_and_metadata_factory: Callable[ [], Awaitable[UriAndMetadata[HandshakeMetadata]] ], @@ -225,6 +229,7 @@ def __init__( # Terminating self._terminating_task = None + self._closing_waiter = None self._start_recv_from_ws() self._start_buffered_message_sender() @@ -247,25 +252,7 @@ def get_next_sent_seq() -> int: return self._send_buffer[0].seq return self.seq - def close_session(reason: Exception | None) -> None: - # If we're already closing, just let whoever's currently doing it handle it. - if self._state in TerminalStates: - return - - # Avoid closing twice - if self._terminating_task is None: - current_state = self._state - self._state = SessionState.CLOSING - - # We can't just call self.close() directly because - # we're inside a thread that will eventually be awaited - # during the cleanup procedure. - - self._terminating_task = asyncio.create_task( - self.close(reason, current_state=current_state), - ) - - def transition_connecting() -> None: + def transition_connecting(ws: ClientConnection) -> None: if self._state in TerminalStates: return logger.debug("transition_connecting") @@ -273,6 +260,9 @@ def transition_connecting() -> None: # "Clear" here means observers should wait until we are connected. self._wait_for_connected.clear() + # Expose the current ws to be collected by close() + self._ws = ws + def transition_connected(ws: ClientConnection) -> None: if self._state in TerminalStates: return @@ -293,11 +283,11 @@ def unbind_connecting_task() -> None: # This is safe because each individual function that is waiting on this # function completeing already has a reference, so we'll last a few ticks # before GC. - # - # Let's do our best to avoid clobbering other tasks by comparing the .name current_task = asyncio.current_task() if self._connecting_task is current_task: self._connecting_task = None + else: + logger.debug("unbind_connecting_task failed, id did not match") if not self._connecting_task: self._connecting_task = asyncio.create_task( @@ -316,13 +306,15 @@ def unbind_connecting_task() -> None: close_ws_in_background=close_ws_in_background, transition_connected=transition_connected, unbind_connecting_task=unbind_connecting_task, - close_session=close_session, + close_session=self._close_internal_nowait, ) ) await self._connecting_task + if self._terminating_task: + await self._terminating_task - def is_closed(self) -> bool: + def is_terminal(self) -> bool: """ If the session is in a terminal state. Do not send messages, do not expect any more messages to be emitted, @@ -397,80 +389,127 @@ async def _enqueue_message( self._process_messages.set() async def close( - self, reason: Exception | None = None, current_state: SessionState | None = None + self, + reason: Exception | None = None, ) -> None: """Close the session and all associated streams.""" - logger.info( - f"{self.session_id} closing session to {self._server_id}, ws: {self._ws}" - ) - if (current_state or self._state) in TerminalStates: - # already closing + if self._closing_waiter: + try: + logger.debug("Session already closing, waiting...") + async with asyncio.timeout(SESSION_CLOSE_TIMEOUT_SEC): + await self._closing_waiter.wait() + except asyncio.TimeoutError: + logger.warning( + f"Session took longer than {SESSION_CLOSE_TIMEOUT_SEC} " + "seconds to close, leaking", + ) return - self._state = SessionState.CLOSING + await self._close_internal(reason) - # We're closing, so we need to wake up... - # ... tasks waiting for connection to be established - self._wait_for_connected.set() - # ... consumers waiting to enqueue messages - self._space_available.set() - # ... message processor so it can exit cleanly - self._process_messages.set() + def _close_internal_nowait(self, reason: Exception | None = None) -> None: + """ + When calling close() from asyncio Tasks, we must not block. + + This function does so, deferring to the underlying infrastructure for + creating self._terminating_task. + """ + self._close_internal(reason) + + def _close_internal(self, reason: Exception | None = None) -> asyncio.Task[None]: + """ + Internal close method. Subsequent calls past the first do not block. + + This is intended to be the primary driver of a session being torn down + and returned to its initial state. + + NB: This function is intended to be the sole lifecycle manager of + self._terminating_task. Waiting on the completion of that task is optional, + but the population of that property is critical. + + NB: We must not await the task returned from this function from chained tasks + inside this session, otherwise we will create a thread loop. + """ - # Wait a tick to permit the waiting tasks to shut down gracefully - await asyncio.sleep(0.01) + async def do_close() -> None: + logger.info( + f"{self.session_id} closing session to {self._server_id}, " + f"ws: {self._ws}" + ) + self._state = SessionState.CLOSING + self._closing_waiter = asyncio.Event() + + # We're closing, so we need to wake up... + # ... tasks waiting for connection to be established + self._wait_for_connected.set() + # ... consumers waiting to enqueue messages + self._space_available.set() + # ... message processor so it can exit cleanly + self._process_messages.set() + + # Wait to permit the waiting tasks to shut down gracefully + await asyncio.sleep(0.25) - await self._task_manager.cancel_all_tasks() + await self._task_manager.cancel_all_tasks() - for stream_meta in self._streams.values(): - stream_meta["output"].close() - # Wake up backpressured writers + for stream_meta in self._streams.values(): + stream_meta["output"].close() + # Wake up backpressured writers + try: + stream_meta["error_channel"].put_nowait( + reason + or SessionClosedRiverServiceException( + "river session is closed", + ) + ) + except ChannelFull: + logger.exception( + "Unable to tell the caller that the session is going away", + ) + stream_meta["release_backpressured_waiter"]() + # Before we GC the streams, let's wait for all tasks to be closed gracefully try: - stream_meta["error_channel"].put_nowait( - reason - or SessionClosedRiverServiceException( - "river session is closed", + async with asyncio.timeout( + self._transport_options.shutdown_all_streams_timeout_ms + ): + # Block for backpressure and emission errors from the ws + await asyncio.gather( + *[ + stream_meta["output"].join() + for stream_meta in self._streams.values() + ] ) - ) - except ChannelFull: + except asyncio.TimeoutError: + spans: list[Span] = [ + stream_meta["span"] + for stream_meta in self._streams.values() + if not stream_meta["output"].closed() + ] + span_ids = [span.get_span_context().span_id for span in spans] logger.exception( - "Unable to tell the caller that the session is going away", + "Timeout waiting for output streams to finallize", + extra={"span_ids": span_ids}, ) - stream_meta["release_backpressured_waiter"]() - # Before we GC the streams, let's wait for all tasks to be closed gracefully. - try: - async with asyncio.timeout( - self._transport_options.shutdown_all_streams_timeout_ms - ): - # Block for backpressure and emission errors from the ws - await asyncio.gather( - *[ - stream_meta["output"].join() - for stream_meta in self._streams.values() - ] - ) - except asyncio.TimeoutError: - spans: list[Span] = [ - stream_meta["span"] - for stream_meta in self._streams.values() - if not stream_meta["output"].closed() - ] - span_ids = [span.get_span_context().span_id for span in spans] - logger.exception( - "Timeout waiting for output streams to finallize", - extra={"span_ids": span_ids}, - ) - self._streams.clear() + self._streams.clear() - if self._ws: - # The Session isn't guaranteed to live much longer than this close() - # invocation, so let's await this close to avoid dropping the socket. - await self._ws.close() + if self._ws: + # The Session isn't guaranteed to live much longer than this close() + # invocation, so let's await this close to avoid dropping the socket. + await self._ws.close() + + self._state = SessionState.CLOSED + + # Clear the session in transports + # This will get us GC'd, so this should be the last thing. + self._close_session_callback(self) - self._state = SessionState.CLOSED + # Release waiters, then release the event + self._closing_waiter.set() + self._closing_waiter = None - # Clear the session in transports - # This will get us GC'd, so this should be the last thing. - await self._close_session_callback(self) + if self._terminating_task: + return self._terminating_task + + return asyncio.create_task(do_close()) def _start_buffered_message_sender( self, @@ -479,7 +518,8 @@ def _start_buffered_message_sender( Building on buffered_message_sender's documentation, we implement backpressure per-stream by way of self._streams' - error_channel: Channel[Exception | None] + error_channel: Channel[Exception] + backpressured_waiter: Callable[[], Awaitable[None]] This is accomplished via the following strategy: - If buffered_message_sender encounters an error, we transition back to @@ -491,8 +531,11 @@ def _start_buffered_message_sender( - Alternately, if buffered_message_sender successfully writes back to the - Finally, if _recv_from_ws encounters an error (transport or deserialization), - we emit an informative error to close_session which gets emitted to all - backpressured client methods. + it transitions to NO_CONNECTION and defers to the client_transport to + reestablish a connection. + + The in-flight messages are still valid, as if we can reconnect to the server + in time, those responses can be marshalled to their respective callbacks. """ async def commit(msg: TransportMessage) -> None: @@ -609,7 +652,7 @@ async def block_until_connected() -> None: get_state=lambda: self._state, get_ws=lambda: self._ws, transition_no_connection=transition_no_connection, - close_session=self.close, + close_session=self._close_internal_nowait, assert_incoming_seq_bookkeeping=assert_incoming_seq_bookkeeping, get_stream=lambda stream_id: self._streams.get(stream_id), enqueue_message=self._enqueue_message, @@ -774,7 +817,7 @@ async def send_upload[I, R, A]( # If this request is not closed and the session is killed, we should # throw exception here async for item in request: - # Block for backpressure and emission errors from the ws + # Block for backpressure await backpressured_waiter() try: payload = request_serializer(item) @@ -935,9 +978,9 @@ async def _encode_stream() -> None: assert request_serializer, "send_stream missing request_serializer" async for item in request: - # Block for backpressure (or errors) + # Block for backpressure await backpressured_waiter() - # If there are any errors so far, raise them + await self._enqueue_message( stream_id=stream_id, control_flags=0, @@ -1015,7 +1058,7 @@ async def _do_ensure_connected[HandshakeMetadata]( client_id: str, session_id: str, server_id: str, - rate_limiter: LeakyBucketRateLimit, + rate_limiter: RateLimiter, uri_and_metadata_factory: Callable[ [], Awaitable[UriAndMetadata[HandshakeMetadata]] ], @@ -1023,7 +1066,7 @@ async def _do_ensure_connected[HandshakeMetadata]( get_next_sent_seq: Callable[[], int], get_current_ack: Callable[[], int], get_state: Callable[[], SessionState], - transition_connecting: Callable[[], None], + transition_connecting: Callable[[ClientConnection], None], close_ws_in_background: Callable[[ClientConnection], None], transition_connected: Callable[[ClientConnection], None], unbind_connecting_task: Callable[[], None], @@ -1043,12 +1086,12 @@ async def _do_ensure_connected[HandshakeMetadata]( attempt_count += 1 rate_limiter.consume_budget(client_id) - transition_connecting() ws: ClientConnection | None = None try: uri_and_metadata = await uri_and_metadata_factory() ws = await websockets.asyncio.client.connect(uri_and_metadata["uri"]) + transition_connecting(ws) try: handshake_request = ControlMessageHandshakeRequest[HandshakeMetadata]( @@ -1100,6 +1143,12 @@ async def websocket_closed_callback() -> None: try: data = await ws.recv(decode=False) + except ConnectionClosedOK: + # In the case of a normal connection closure, we defer to + # the outer loop to determine next steps. + # A call to close(...) should set the SessionState to a terminal one, + # otherwise we should try again. + continue except ConnectionClosed as e: logger.debug( "_do_ensure_connected: Connection closed during waiting " @@ -1150,21 +1199,23 @@ async def websocket_closed_callback() -> None: raise err + logger.debug("Connected") # We did it! We're connected! last_error = None rate_limiter.start_restoring_budget(client_id) transition_connected(ws) break except Exception as e: - if ws: - close_ws_in_background(ws) - ws = None - last_error = e backoff_time = rate_limiter.get_backoff_ms(client_id) logger.exception( f"Error connecting, retrying with {backoff_time}ms backoff" ) + if ws: + close_ws_in_background(ws) + ws = None + last_error = e await asyncio.sleep(backoff_time / 1000) + logger.debug("Here, about to retry") unbind_connecting_task() if last_error is not None: @@ -1184,7 +1235,7 @@ async def _recv_from_ws( get_state: Callable[[], SessionState], get_ws: Callable[[], ClientConnection | None], transition_no_connection: Callable[[], Awaitable[None]], - close_session: Callable[[Exception | None], Awaitable[None]], + close_session: Callable[[Exception | None], None], assert_incoming_seq_bookkeeping: Callable[ [str, int, int], Literal[True] | _IgnoreMessage ], @@ -1200,8 +1251,8 @@ async def _recv_from_ws( """ our_task = asyncio.current_task() connection_attempts = 0 - try: - while our_task and not our_task.cancelling() and not our_task.cancelled(): + while our_task and not our_task.cancelling() and not our_task.cancelled(): + try: logger.debug(f"_recv_from_ws loop count={connection_attempts}") connection_attempts += 1 ws = None @@ -1232,6 +1283,9 @@ async def _recv_from_ws( # is no @overrides in `websockets` to hint this. try: message = await ws.recv(decode=False) + except ConnectionClosedOK as e: + close_session(e) + continue except ConnectionClosed: # This triggers a break in the inner loop so we can get back to # the outer loop. @@ -1316,7 +1370,7 @@ async def _recv_from_ws( stream_meta["output"].close() except OutOfOrderMessageException: logger.exception("Out of order message, closing connection") - await close_session( + close_session( SessionClosedRiverServiceException( "Out of order message, closing connection" ) @@ -1326,7 +1380,7 @@ async def _recv_from_ws( logger.exception( "Got invalid transport message, closing session", ) - await close_session( + close_session( SessionClosedRiverServiceException( "Out of order message, closing connection" ) @@ -1338,21 +1392,22 @@ async def _recv_from_ws( logger.debug( "FailedSendingMessageException while serving", exc_info=True ) - break + break # Inner loop except Exception: logger.exception("caught exception at message iterator") - break + await transition_no_connection() + break # Inner loop logger.debug("_handle_messages_from_ws exiting") - except ExceptionGroup as eg: - _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) - if unhandled: - # We're in a task, there's not that much that can be done. - unhandled = ExceptionGroup( - "Unhandled exceptions on River server", unhandled.exceptions - ) - logger.exception( - "caught exception at message iterator", - exc_info=unhandled, - ) - raise unhandled + except ExceptionGroup as eg: + _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) + if unhandled: + # We're in a task, there's not that much that can be done. + unhandled = ExceptionGroup( + "Unhandled exceptions on River server", unhandled.exceptions + ) + logger.exception( + "caught exception at message iterator", + exc_info=unhandled, + ) + raise unhandled logger.debug(f"_recv_from_ws exiting normally after {connection_attempts} loops") diff --git a/tests/v2/test_v2_session_lifecycle.py b/tests/v2/test_v2_session_lifecycle.py new file mode 100644 index 00000000..fb5982cc --- /dev/null +++ b/tests/v2/test_v2_session_lifecycle.py @@ -0,0 +1,135 @@ +import asyncio +from typing import AsyncIterator, Awaitable, Callable, TypeAlias, TypedDict + +import pytest +from websockets import ConnectionClosedOK +from websockets.asyncio.server import ServerConnection, serve +from websockets.typing import Data + +from replit_river.messages import parse_transport_msg +from replit_river.rate_limiter import RateLimiter +from replit_river.rpc import TransportMessage +from replit_river.transport_options import TransportOptions, UriAndMetadata +from replit_river.v2.session import Session + + +class _PermissiveRateLimiter(RateLimiter): + def start_restoring_budget(self, user: str) -> None: + pass + + def get_backoff_ms(self, user: str) -> float: + return 0 + + def has_budget(self, user: str) -> bool: + return True + + def consume_budget(self, user: str) -> None: + pass + + +WsServerFixture: TypeAlias = tuple[ + Callable[[], Awaitable[UriAndMetadata[None]]], + asyncio.Queue[bytes], + Callable[[], ServerConnection | None], +] + + +class _WsServerState(TypedDict): + ipv4_laddr: tuple[str, int] | None + + +async def _ws_server_internal( + recv: asyncio.Queue[bytes], + set_conn: Callable[[ServerConnection], None], + state: _WsServerState, +) -> AsyncIterator[None]: + async def handle(websocket: ServerConnection) -> None: + set_conn(websocket) + datagram: Data + try: + while datagram := await websocket.recv(decode=False): + if isinstance(datagram, str): + continue + await recv.put(datagram) + except ConnectionClosedOK: + pass + + port: int | None = None + if state["ipv4_laddr"]: + port = state["ipv4_laddr"][1] + async with serve(handle, "localhost", port=port) as server: + for sock in server.sockets: + if (pair := sock.getsockname())[0] == "127.0.0.1": + if state["ipv4_laddr"] is None: + state["ipv4_laddr"] = pair + serve_forever = asyncio.create_task(server.serve_forever()) + yield None + serve_forever.cancel() + + +@pytest.fixture +async def ws_server() -> AsyncIterator[WsServerFixture]: + recv: asyncio.Queue[bytes] = asyncio.Queue(maxsize=1) + connection: ServerConnection | None = None + state: _WsServerState = {"ipv4_laddr": None} + + def set_conn(new_conn: ServerConnection) -> None: + nonlocal connection + connection = new_conn + + server_generator = _ws_server_internal(recv, set_conn, state) + await anext(server_generator) + + async def urimeta() -> UriAndMetadata[None]: + ipv4_laddr = state["ipv4_laddr"] + assert ipv4_laddr + return UriAndMetadata(uri="ws://%s:%d" % ipv4_laddr, metadata=None) + + yield (urimeta, recv, lambda: connection) + + try: + await anext(server_generator) + except StopAsyncIteration: + pass + + +async def test_connect(ws_server: WsServerFixture) -> None: + (urimeta, recv, conn) = ws_server + + session = Session( + server_id="SERVER", + session_id="SESSION1", + transport_options=TransportOptions(), + close_session_callback=lambda _: None, + client_id="CLIENT1", + rate_limiter=_PermissiveRateLimiter(), + uri_and_metadata_factory=urimeta, + ) + + connecting = asyncio.create_task(session.ensure_connected()) + msg = parse_transport_msg(await recv.get()) + assert isinstance(msg, TransportMessage) + assert msg.payload["type"] == "HANDSHAKE_REQ" + await session.close() + await connecting + + +async def test_reconnect(ws_server: WsServerFixture) -> None: + (urimeta, recv, conn) = ws_server + + session = Session( + server_id="SERVER", + session_id="SESSION1", + transport_options=TransportOptions(), + close_session_callback=lambda _: None, + client_id="CLIENT1", + rate_limiter=_PermissiveRateLimiter(), + uri_and_metadata_factory=urimeta, + ) + + connecting = asyncio.create_task(session.ensure_connected()) + msg = parse_transport_msg(await recv.get()) + assert isinstance(msg, TransportMessage) + assert msg.payload["type"] == "HANDSHAKE_REQ" + await session.close() + await connecting