diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index 0e9b938609..9f2d482e34 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -221,9 +221,9 @@ def __init__( self._connection_params = connection_params self._errlog = errlog - # Session pool: maps session keys to (session, exit_stack, loop) tuples + # Session pool: maps session keys to (session, exit_stack, loop, mcp_session_id) tuples self._sessions: Dict[ - str, tuple[ClientSession, AsyncExitStack, asyncio.AbstractEventLoop] + str, tuple[ClientSession, AsyncExitStack, asyncio.AbstractEventLoop, Optional[str]] ] = {} # Map of event loops to their respective locks to prevent race conditions @@ -241,39 +241,54 @@ def _session_lock(self) -> asyncio.Lock: return self._session_lock_map[current_loop] def _generate_session_key( - self, merged_headers: Optional[Dict[str, str]] = None + self, merged_headers: Optional[Dict[str, str]] = None, + mcp_session_id: Optional[str] = None ) -> str: - """Generates a session key based on connection params and merged headers. + """Generates a session key based on connection params, headers, and mcp-session-id. For StdioConnectionParams, returns a constant key since headers are not supported. For SSE and StreamableHTTP connections, generates a key based - on the provided merged headers. + on the provided merged headers and mcp-session-id if available. Args: merged_headers: Already merged headers (base + additional). + mcp_session_id: Optional MCP session ID from server initialization. Returns: A unique session key string. """ if isinstance(self._connection_params, StdioConnectionParams): - # For stdio connections, headers are not supported, so use constant key + # For stdio connections, headers are not supported + # But we should still include mcp_session_id if available + if mcp_session_id: + return f'stdio_session_{mcp_session_id}' return 'stdio_session' - # For SSE and StreamableHTTP connections, use merged headers + # For SSE and StreamableHTTP connections, use merged headers and mcp_session_id + key_parts = [] + + if mcp_session_id: + key_parts.append(f'mcp_id:{mcp_session_id}') + if merged_headers: headers_json = json.dumps(merged_headers, sort_keys=True) headers_hash = hashlib.md5(headers_json.encode()).hexdigest() - return f'session_{headers_hash}' + key_parts.append(f'headers:{headers_hash}') + + if key_parts: + return 'session_' + '_'.join(key_parts) else: return 'session_no_headers' def _merge_headers( - self, additional_headers: Optional[Dict[str, str]] = None + self, additional_headers: Optional[Dict[str, str]] = None, + mcp_session_id: Optional[str] = None ) -> Optional[Dict[str, str]]: - """Merges base connection headers with additional headers. + """Merges base connection headers with additional headers and mcp-session-id. Args: additional_headers: Optional headers to merge with connection headers. + mcp_session_id: Optional MCP session ID to include as a header. Returns: Merged headers dictionary, or None if no headers are provided. @@ -293,8 +308,12 @@ def _merge_headers( if additional_headers: base_headers.update(additional_headers) + + # Add mcp-session-id header if available + if mcp_session_id: + base_headers['mcp-session-id'] = mcp_session_id - return base_headers + return base_headers if base_headers else None def _is_session_disconnected(self, session: ClientSession) -> bool: """Checks if a session is disconnected or closed. @@ -414,7 +433,9 @@ async def create_session( This method will check if an existing session for the given headers is still connected. If it's disconnected, it will be cleaned up and - a new session will be created. + a new session will be created. The mcp-session-id returned by the + server during initialization is captured and reused in subsequent + requests to maintain stateful server sessions. Args: headers: Optional headers to include in the session. These will be @@ -424,17 +445,38 @@ async def create_session( Returns: ClientSession: The initialized MCP client session. """ - # Merge headers once at the beginning - merged_headers = self._merge_headers(headers) - - # Generate session key using merged headers - session_key = self._generate_session_key(merged_headers) + # First, try to find an existing session with a stored mcp-session-id + # This allows us to reuse sessions across invocations + existing_mcp_session_id = None + + # Check all sessions to find one with matching base characteristics + # (ignoring mcp-session-id in the key for this lookup) + temp_merged_headers = self._merge_headers(headers, mcp_session_id=None) + temp_key = self._generate_session_key(temp_merged_headers, mcp_session_id=None) + + # Look for existing sessions that match the base key pattern + async with self._session_lock: + for key, (session, exit_stack, stored_loop, mcp_id) in list(self._sessions.items()): + # Check if this session matches our connection params (ignoring mcp-session-id) + if key.startswith(temp_key.split('_mcp_id')[0]) and mcp_id: + current_loop = asyncio.get_running_loop() + if stored_loop is current_loop and not self._is_session_disconnected(session): + # Found a valid session with mcp-session-id, use it + existing_mcp_session_id = mcp_id + logger.debug('Reusing existing mcp-session-id: %s', existing_mcp_session_id) + break + + # Merge headers with mcp-session-id if we have one + merged_headers = self._merge_headers(headers, mcp_session_id=existing_mcp_session_id) + + # Generate session key using merged headers and mcp-session-id + session_key = self._generate_session_key(merged_headers, mcp_session_id=existing_mcp_session_id) # Use async lock to prevent race conditions async with self._session_lock: - # Check if we have an existing session + # Check if we have an existing session with this exact key if session_key in self._sessions: - session, exit_stack, stored_loop = self._sessions[session_key] + session, exit_stack, stored_loop, mcp_id = self._sessions[session_key] # Check if the existing session is still connected and bound to the current loop current_loop = asyncio.get_running_loop() @@ -468,7 +510,7 @@ async def create_session( client = self._create_client(merged_headers) is_stdio = isinstance(self._connection_params, StdioConnectionParams) - session = await asyncio.wait_for( + session_context = await asyncio.wait_for( exit_stack.enter_async_context( SessionContext( client=client, @@ -479,12 +521,28 @@ async def create_session( ), timeout=timeout_in_seconds, ) - - # Store session, exit stack, and loop in the pool + + session = session_context.session + + # Extract mcp-session-id from the session context if available + new_mcp_session_id = session_context.mcp_session_id + + # If we got a new mcp-session-id, update the session key and re-store + if new_mcp_session_id and new_mcp_session_id != existing_mcp_session_id: + logger.info('Received new mcp-session-id from server: %s', new_mcp_session_id) + # Remove old session entry if it exists + if session_key in self._sessions: + del self._sessions[session_key] + # Generate new key with the mcp-session-id + merged_headers_with_id = self._merge_headers(headers, mcp_session_id=new_mcp_session_id) + session_key = self._generate_session_key(merged_headers_with_id, mcp_session_id=new_mcp_session_id) + + # Store session, exit stack, loop, and mcp_session_id in the pool self._sessions[session_key] = ( session, exit_stack, asyncio.get_running_loop(), + new_mcp_session_id or existing_mcp_session_id, ) logger.debug('Created new session: %s', session_key) return session @@ -504,7 +562,7 @@ async def close(self): """Closes all sessions and cleans up resources.""" async with self._session_lock: for session_key in list(self._sessions.keys()): - _, exit_stack, stored_loop = self._sessions[session_key] + _, exit_stack, stored_loop, _ = self._sessions[session_key] await self._cleanup_session(session_key, exit_stack, stored_loop) diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index fb4e992dfd..712956bdc4 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -394,6 +394,23 @@ async def close(self) -> None: # Log the error but don't re-raise to avoid blocking shutdown print(f"Warning: Error during McpToolset cleanup: {e}", file=self._errlog) + def get_session_info(self) -> dict[str, Any]: + """Returns information about the current MCP session state. + + This is useful for debugging and understanding session reuse. + + Returns: + Dictionary with session information including active session keys + and mcp-session-id values. + """ + session_info = {} + for key, (_, _, _, mcp_id) in self._mcp_session_manager._sessions.items(): + session_info[key] = { + 'mcp_session_id': mcp_id, + 'active': True + } + return session_info + @override def get_auth_config(self) -> Optional[AuthConfig]: """Returns the auth config for this toolset. diff --git a/src/google/adk/tools/mcp_tool/session_context.py b/src/google/adk/tools/mcp_tool/session_context.py index ca637d0489..5255c95e0a 100644 --- a/src/google/adk/tools/mcp_tool/session_context.py +++ b/src/google/adk/tools/mcp_tool/session_context.py @@ -69,6 +69,7 @@ def __init__( self._sse_read_timeout = sse_read_timeout self._is_stdio = is_stdio self._session: Optional[ClientSession] = None + self._mcp_session_id: Optional[str] = None self._ready_event = asyncio.Event() self._close_event = asyncio.Event() self._task: Optional[asyncio.Task] = None @@ -79,6 +80,11 @@ def session(self) -> Optional[ClientSession]: """Get the managed ClientSession, if available.""" return self._session + @property + def mcp_session_id(self) -> Optional[str]: + """Get the MCP session ID returned by the server during initialization.""" + return self._mcp_session_id + async def start(self) -> ClientSession: """Start the runner and wait for the session to be ready. @@ -178,8 +184,21 @@ async def _run(self): else None, ) ) - await asyncio.wait_for(session.initialize(), timeout=self._timeout) + + # Initialize the session and capture the response + init_result = await asyncio.wait_for( + session.initialize(), timeout=self._timeout + ) logger.debug('Session has been successfully initialized') + + # Extract mcp-session-id from initialization result if present + # The MCP protocol returns this in the InitializeResult's meta field + if hasattr(init_result, '_meta') and init_result._meta: + self._mcp_session_id = init_result._meta.get('mcp-session-id') + if self._mcp_session_id: + logger.debug( + 'Captured mcp-session-id from server: %s', self._mcp_session_id + ) self._session = session self._ready_event.set()