@@ -117,10 +117,11 @@ async def _get_or_create_session(
117117 session_id : str ,
118118 websocket : WebSocketCommonProtocol ,
119119 ) -> ServerSession :
120+ new_session : ServerSession | None = None
121+ old_session : ServerSession | None = None
120122 async with self ._session_lock :
121- session_to_close : Session | None = None
122- new_session : ServerSession | None = None
123- if to_id not in self ._sessions :
123+ old_session = self ._sessions .get (to_id )
124+ if not old_session :
124125 logger .info (
125126 'Creating new session with "%s" using ws: %s' , to_id , websocket .id
126127 )
@@ -134,7 +135,6 @@ async def _get_or_create_session(
134135 close_session_callback = self ._delete_session ,
135136 )
136137 else :
137- old_session = self ._sessions [to_id ]
138138 if old_session .session_id != session_id :
139139 logger .info (
140140 'Create new session with "%s" for session id %s'
@@ -143,7 +143,6 @@ async def _get_or_create_session(
143143 session_id ,
144144 old_session .session_id ,
145145 )
146- session_to_close = old_session
147146 new_session = ServerSession (
148147 transport_id ,
149148 to_id ,
@@ -167,10 +166,12 @@ async def _get_or_create_session(
167166 except FailedSendingMessageException as e :
168167 raise e
169168
170- if session_to_close :
171- logger .info ("Closing stale session %s" , session_to_close .session_id )
172- await session_to_close .close ()
173169 self ._sessions [new_session ._to_id ] = new_session
170+
171+ if old_session and new_session != old_session :
172+ logger .info ("Closing stale session %s" , old_session .session_id )
173+ await old_session .close ()
174+
174175 return new_session
175176
176177 async def _send_handshake_response (
@@ -247,7 +248,7 @@ async def _establish_handshake(
247248 raise InvalidMessageException ("handshake request to wrong server" )
248249
249250 async with self ._session_lock :
250- old_session = self ._sessions .get (request_message .from_ , None )
251+ old_session = self ._sessions .get (request_message .from_ )
251252 client_next_expected_seq = (
252253 handshake_request .expectedSessionState .nextExpectedSeq
253254 )
@@ -285,10 +286,6 @@ async def _establish_handshake(
285286 )
286287 raise SessionStateMismatchException (message )
287288 elif old_session :
288- # we have an old session but the session id is different
289- # just delete the old session
290- await old_session .close ()
291- await self ._delete_session (old_session )
292289 old_session = None
293290
294291 if not old_session and (
0 commit comments