-
Notifications
You must be signed in to change notification settings - Fork 2.9k
enhanced MCP session management with mcp-session-id support #4466
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -221,9 +221,9 @@ def __init__( | |||||
| self._connection_params = connection_params | ||||||
| self._errlog = errlog | ||||||
|
|
||||||
| # Session pool: maps session keys to (session, exit_stack, loop) tuples | ||||||
| # Session pool: maps session keys to (session, exit_stack, loop, mcp_session_id) tuples | ||||||
| self._sessions: Dict[ | ||||||
| str, tuple[ClientSession, AsyncExitStack, asyncio.AbstractEventLoop] | ||||||
| str, tuple[ClientSession, AsyncExitStack, asyncio.AbstractEventLoop, Optional[str]] | ||||||
| ] = {} | ||||||
|
|
||||||
| # Map of event loops to their respective locks to prevent race conditions | ||||||
|
|
@@ -241,39 +241,54 @@ def _session_lock(self) -> asyncio.Lock: | |||||
| return self._session_lock_map[current_loop] | ||||||
|
|
||||||
| def _generate_session_key( | ||||||
| self, merged_headers: Optional[Dict[str, str]] = None | ||||||
| self, merged_headers: Optional[Dict[str, str]] = None, | ||||||
| mcp_session_id: Optional[str] = None | ||||||
| ) -> str: | ||||||
| """Generates a session key based on connection params and merged headers. | ||||||
| """Generates a session key based on connection params, headers, and mcp-session-id. | ||||||
|
|
||||||
| For StdioConnectionParams, returns a constant key since headers are not | ||||||
| supported. For SSE and StreamableHTTP connections, generates a key based | ||||||
| on the provided merged headers. | ||||||
| on the provided merged headers and mcp-session-id if available. | ||||||
|
|
||||||
| Args: | ||||||
| merged_headers: Already merged headers (base + additional). | ||||||
| mcp_session_id: Optional MCP session ID from server initialization. | ||||||
|
|
||||||
| Returns: | ||||||
| A unique session key string. | ||||||
| """ | ||||||
| if isinstance(self._connection_params, StdioConnectionParams): | ||||||
| # For stdio connections, headers are not supported, so use constant key | ||||||
| # For stdio connections, headers are not supported | ||||||
| # But we should still include mcp_session_id if available | ||||||
| if mcp_session_id: | ||||||
| return f'stdio_session_{mcp_session_id}' | ||||||
| return 'stdio_session' | ||||||
|
|
||||||
| # For SSE and StreamableHTTP connections, use merged headers | ||||||
| # For SSE and StreamableHTTP connections, use merged headers and mcp_session_id | ||||||
| key_parts = [] | ||||||
|
|
||||||
| if mcp_session_id: | ||||||
| key_parts.append(f'mcp_id:{mcp_session_id}') | ||||||
|
|
||||||
| if merged_headers: | ||||||
| headers_json = json.dumps(merged_headers, sort_keys=True) | ||||||
| headers_hash = hashlib.md5(headers_json.encode()).hexdigest() | ||||||
| return f'session_{headers_hash}' | ||||||
| key_parts.append(f'headers:{headers_hash}') | ||||||
|
|
||||||
| if key_parts: | ||||||
| return 'session_' + '_'.join(key_parts) | ||||||
| else: | ||||||
| return 'session_no_headers' | ||||||
|
|
||||||
| def _merge_headers( | ||||||
| self, additional_headers: Optional[Dict[str, str]] = None | ||||||
| self, additional_headers: Optional[Dict[str, str]] = None, | ||||||
| mcp_session_id: Optional[str] = None | ||||||
| ) -> Optional[Dict[str, str]]: | ||||||
| """Merges base connection headers with additional headers. | ||||||
| """Merges base connection headers with additional headers and mcp-session-id. | ||||||
|
|
||||||
| Args: | ||||||
| additional_headers: Optional headers to merge with connection headers. | ||||||
| mcp_session_id: Optional MCP session ID to include as a header. | ||||||
|
|
||||||
| Returns: | ||||||
| Merged headers dictionary, or None if no headers are provided. | ||||||
|
|
@@ -293,8 +308,12 @@ def _merge_headers( | |||||
|
|
||||||
| if additional_headers: | ||||||
| base_headers.update(additional_headers) | ||||||
|
|
||||||
| # Add mcp-session-id header if available | ||||||
| if mcp_session_id: | ||||||
| base_headers['mcp-session-id'] = mcp_session_id | ||||||
|
|
||||||
| return base_headers | ||||||
| return base_headers if base_headers else None | ||||||
|
|
||||||
| def _is_session_disconnected(self, session: ClientSession) -> bool: | ||||||
| """Checks if a session is disconnected or closed. | ||||||
|
|
@@ -414,7 +433,9 @@ async def create_session( | |||||
|
|
||||||
| This method will check if an existing session for the given headers | ||||||
| is still connected. If it's disconnected, it will be cleaned up and | ||||||
| a new session will be created. | ||||||
| a new session will be created. The mcp-session-id returned by the | ||||||
| server during initialization is captured and reused in subsequent | ||||||
| requests to maintain stateful server sessions. | ||||||
|
|
||||||
| Args: | ||||||
| headers: Optional headers to include in the session. These will be | ||||||
|
|
@@ -424,17 +445,38 @@ async def create_session( | |||||
| Returns: | ||||||
| ClientSession: The initialized MCP client session. | ||||||
| """ | ||||||
| # Merge headers once at the beginning | ||||||
| merged_headers = self._merge_headers(headers) | ||||||
|
|
||||||
| # Generate session key using merged headers | ||||||
| session_key = self._generate_session_key(merged_headers) | ||||||
| # First, try to find an existing session with a stored mcp-session-id | ||||||
| # This allows us to reuse sessions across invocations | ||||||
| existing_mcp_session_id = None | ||||||
|
|
||||||
| # Check all sessions to find one with matching base characteristics | ||||||
| # (ignoring mcp-session-id in the key for this lookup) | ||||||
| temp_merged_headers = self._merge_headers(headers, mcp_session_id=None) | ||||||
| temp_key = self._generate_session_key(temp_merged_headers, mcp_session_id=None) | ||||||
|
|
||||||
| # Look for existing sessions that match the base key pattern | ||||||
| async with self._session_lock: | ||||||
| for key, (session, exit_stack, stored_loop, mcp_id) in list(self._sessions.items()): | ||||||
| # Check if this session matches our connection params (ignoring mcp-session-id) | ||||||
| if key.startswith(temp_key.split('_mcp_id')[0]) and mcp_id: | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This logic for finding a matching session is incorrect. With the suggested change to
Suggested change
|
||||||
| current_loop = asyncio.get_running_loop() | ||||||
| if stored_loop is current_loop and not self._is_session_disconnected(session): | ||||||
| # Found a valid session with mcp-session-id, use it | ||||||
| existing_mcp_session_id = mcp_id | ||||||
| logger.debug('Reusing existing mcp-session-id: %s', existing_mcp_session_id) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||
| break | ||||||
|
|
||||||
| # Merge headers with mcp-session-id if we have one | ||||||
| merged_headers = self._merge_headers(headers, mcp_session_id=existing_mcp_session_id) | ||||||
|
|
||||||
| # Generate session key using merged headers and mcp-session-id | ||||||
| session_key = self._generate_session_key(merged_headers, mcp_session_id=existing_mcp_session_id) | ||||||
|
|
||||||
| # Use async lock to prevent race conditions | ||||||
| async with self._session_lock: | ||||||
| # Check if we have an existing session | ||||||
| # Check if we have an existing session with this exact key | ||||||
| if session_key in self._sessions: | ||||||
| session, exit_stack, stored_loop = self._sessions[session_key] | ||||||
| session, exit_stack, stored_loop, mcp_id = self._sessions[session_key] | ||||||
|
|
||||||
| # Check if the existing session is still connected and bound to the current loop | ||||||
| current_loop = asyncio.get_running_loop() | ||||||
|
|
@@ -468,7 +510,7 @@ async def create_session( | |||||
| client = self._create_client(merged_headers) | ||||||
| is_stdio = isinstance(self._connection_params, StdioConnectionParams) | ||||||
|
|
||||||
| session = await asyncio.wait_for( | ||||||
| session_context = await asyncio.wait_for( | ||||||
| exit_stack.enter_async_context( | ||||||
| SessionContext( | ||||||
| client=client, | ||||||
|
|
@@ -479,12 +521,28 @@ async def create_session( | |||||
| ), | ||||||
| timeout=timeout_in_seconds, | ||||||
| ) | ||||||
|
|
||||||
| # Store session, exit stack, and loop in the pool | ||||||
|
|
||||||
| session = session_context.session | ||||||
|
|
||||||
| # Extract mcp-session-id from the session context if available | ||||||
| new_mcp_session_id = session_context.mcp_session_id | ||||||
|
|
||||||
| # If we got a new mcp-session-id, update the session key and re-store | ||||||
| if new_mcp_session_id and new_mcp_session_id != existing_mcp_session_id: | ||||||
| logger.info('Received new mcp-session-id from server: %s', new_mcp_session_id) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||
| # Remove old session entry if it exists | ||||||
| if session_key in self._sessions: | ||||||
| del self._sessions[session_key] | ||||||
| # Generate new key with the mcp-session-id | ||||||
| merged_headers_with_id = self._merge_headers(headers, mcp_session_id=new_mcp_session_id) | ||||||
| session_key = self._generate_session_key(merged_headers_with_id, mcp_session_id=new_mcp_session_id) | ||||||
|
|
||||||
| # Store session, exit stack, loop, and mcp_session_id in the pool | ||||||
| self._sessions[session_key] = ( | ||||||
| session, | ||||||
| exit_stack, | ||||||
| asyncio.get_running_loop(), | ||||||
| new_mcp_session_id or existing_mcp_session_id, | ||||||
| ) | ||||||
| logger.debug('Created new session: %s', session_key) | ||||||
| return session | ||||||
|
|
@@ -504,7 +562,7 @@ async def close(self): | |||||
| """Closes all sessions and cleans up resources.""" | ||||||
| async with self._session_lock: | ||||||
| for session_key in list(self._sessions.keys()): | ||||||
| _, exit_stack, stored_loop = self._sessions[session_key] | ||||||
| _, exit_stack, stored_loop, _ = self._sessions[session_key] | ||||||
| await self._cleanup_session(session_key, exit_stack, stored_loop) | ||||||
|
|
||||||
|
|
||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -394,6 +394,23 @@ async def close(self) -> None: | |||||
| # Log the error but don't re-raise to avoid blocking shutdown | ||||||
| print(f"Warning: Error during McpToolset cleanup: {e}", file=self._errlog) | ||||||
|
|
||||||
| def get_session_info(self) -> dict[str, Any]: | ||||||
| """Returns information about the current MCP session state. | ||||||
|
|
||||||
| This is useful for debugging and understanding session reuse. | ||||||
|
|
||||||
| Returns: | ||||||
| Dictionary with session information including active session keys | ||||||
| and mcp-session-id values. | ||||||
| """ | ||||||
| session_info = {} | ||||||
| for key, (_, _, _, mcp_id) in self._mcp_session_manager._sessions.items(): | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Accessing the private member |
||||||
| session_info[key] = { | ||||||
| 'mcp_session_id': mcp_id, | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||
| 'active': True | ||||||
| } | ||||||
| return session_info | ||||||
|
|
||||||
| @override | ||||||
| def get_auth_config(self) -> Optional[AuthConfig]: | ||||||
| """Returns the auth config for this toolset. | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -69,6 +69,7 @@ def __init__( | |
| self._sse_read_timeout = sse_read_timeout | ||
| self._is_stdio = is_stdio | ||
| self._session: Optional[ClientSession] = None | ||
| self._mcp_session_id: Optional[str] = None | ||
| self._ready_event = asyncio.Event() | ||
| self._close_event = asyncio.Event() | ||
| self._task: Optional[asyncio.Task] = None | ||
|
|
@@ -79,6 +80,11 @@ def session(self) -> Optional[ClientSession]: | |
| """Get the managed ClientSession, if available.""" | ||
| return self._session | ||
|
|
||
| @property | ||
| def mcp_session_id(self) -> Optional[str]: | ||
| """Get the MCP session ID returned by the server during initialization.""" | ||
| return self._mcp_session_id | ||
|
|
||
| async def start(self) -> ClientSession: | ||
| """Start the runner and wait for the session to be ready. | ||
|
|
||
|
|
@@ -178,8 +184,21 @@ async def _run(self): | |
| else None, | ||
| ) | ||
| ) | ||
| await asyncio.wait_for(session.initialize(), timeout=self._timeout) | ||
|
|
||
| # Initialize the session and capture the response | ||
| init_result = await asyncio.wait_for( | ||
| session.initialize(), timeout=self._timeout | ||
| ) | ||
| logger.debug('Session has been successfully initialized') | ||
|
|
||
| # Extract mcp-session-id from initialization result if present | ||
| # The MCP protocol returns this in the InitializeResult's meta field | ||
| if hasattr(init_result, '_meta') and init_result._meta: | ||
| self._mcp_session_id = init_result._meta.get('mcp-session-id') | ||
| if self._mcp_session_id: | ||
| logger.debug( | ||
| 'Captured mcp-session-id from server: %s', self._mcp_session_id | ||
| ) | ||
|
Comment on lines
+199
to
+201
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| self._session = session | ||
| self._ready_event.set() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To enable correct and simple session reuse logic in
create_session, the constant part of the session key (derived from headers) should precede themcp_session_idpart. This allows for simple prefix matching. Please swap the order of theseifblocks.