Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 81 additions & 23 deletions src/google/adk/tools/mcp_tool/mcp_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,9 @@ def __init__(
self._connection_params = connection_params
self._errlog = errlog

# Session pool: maps session keys to (session, exit_stack, loop) tuples
# Session pool: maps session keys to (session, exit_stack, loop, mcp_session_id) tuples
self._sessions: Dict[
str, tuple[ClientSession, AsyncExitStack, asyncio.AbstractEventLoop]
str, tuple[ClientSession, AsyncExitStack, asyncio.AbstractEventLoop, Optional[str]]
] = {}

# Map of event loops to their respective locks to prevent race conditions
Expand All @@ -241,39 +241,54 @@ def _session_lock(self) -> asyncio.Lock:
return self._session_lock_map[current_loop]

def _generate_session_key(
self, merged_headers: Optional[Dict[str, str]] = None
self, merged_headers: Optional[Dict[str, str]] = None,
mcp_session_id: Optional[str] = None
) -> str:
"""Generates a session key based on connection params and merged headers.
"""Generates a session key based on connection params, headers, and mcp-session-id.

For StdioConnectionParams, returns a constant key since headers are not
supported. For SSE and StreamableHTTP connections, generates a key based
on the provided merged headers.
on the provided merged headers and mcp-session-id if available.

Args:
merged_headers: Already merged headers (base + additional).
mcp_session_id: Optional MCP session ID from server initialization.

Returns:
A unique session key string.
"""
if isinstance(self._connection_params, StdioConnectionParams):
# For stdio connections, headers are not supported, so use constant key
# For stdio connections, headers are not supported
# But we should still include mcp_session_id if available
if mcp_session_id:
return f'stdio_session_{mcp_session_id}'
return 'stdio_session'

# For SSE and StreamableHTTP connections, use merged headers
# For SSE and StreamableHTTP connections, use merged headers and mcp_session_id
key_parts = []

if mcp_session_id:
key_parts.append(f'mcp_id:{mcp_session_id}')

if merged_headers:
headers_json = json.dumps(merged_headers, sort_keys=True)
headers_hash = hashlib.md5(headers_json.encode()).hexdigest()
return f'session_{headers_hash}'
key_parts.append(f'headers:{headers_hash}')
Comment on lines +270 to +276
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

To enable correct and simple session reuse logic in create_session, the constant part of the session key (derived from headers) should precede the mcp_session_id part. This allows for simple prefix matching. Please swap the order of these if blocks.

Suggested change
if mcp_session_id:
key_parts.append(f'mcp_id:{mcp_session_id}')
if merged_headers:
headers_json = json.dumps(merged_headers, sort_keys=True)
headers_hash = hashlib.md5(headers_json.encode()).hexdigest()
return f'session_{headers_hash}'
key_parts.append(f'headers:{headers_hash}')
if merged_headers:
headers_json = json.dumps(merged_headers, sort_keys=True)
headers_hash = hashlib.md5(headers_json.encode()).hexdigest()
key_parts.append(f'headers:{headers_hash}')
if mcp_session_id:
key_parts.append(f'mcp_id:{mcp_session_id}')


if key_parts:
return 'session_' + '_'.join(key_parts)
else:
return 'session_no_headers'

def _merge_headers(
self, additional_headers: Optional[Dict[str, str]] = None
self, additional_headers: Optional[Dict[str, str]] = None,
mcp_session_id: Optional[str] = None
) -> Optional[Dict[str, str]]:
"""Merges base connection headers with additional headers.
"""Merges base connection headers with additional headers and mcp-session-id.

Args:
additional_headers: Optional headers to merge with connection headers.
mcp_session_id: Optional MCP session ID to include as a header.

Returns:
Merged headers dictionary, or None if no headers are provided.
Expand All @@ -293,8 +308,12 @@ def _merge_headers(

if additional_headers:
base_headers.update(additional_headers)

# Add mcp-session-id header if available
if mcp_session_id:
base_headers['mcp-session-id'] = mcp_session_id

return base_headers
return base_headers if base_headers else None

def _is_session_disconnected(self, session: ClientSession) -> bool:
"""Checks if a session is disconnected or closed.
Expand Down Expand Up @@ -414,7 +433,9 @@ async def create_session(

This method will check if an existing session for the given headers
is still connected. If it's disconnected, it will be cleaned up and
a new session will be created.
a new session will be created. The mcp-session-id returned by the
server during initialization is captured and reused in subsequent
requests to maintain stateful server sessions.

Args:
headers: Optional headers to include in the session. These will be
Expand All @@ -424,17 +445,38 @@ async def create_session(
Returns:
ClientSession: The initialized MCP client session.
"""
# Merge headers once at the beginning
merged_headers = self._merge_headers(headers)

# Generate session key using merged headers
session_key = self._generate_session_key(merged_headers)
# First, try to find an existing session with a stored mcp-session-id
# This allows us to reuse sessions across invocations
existing_mcp_session_id = None

# Check all sessions to find one with matching base characteristics
# (ignoring mcp-session-id in the key for this lookup)
temp_merged_headers = self._merge_headers(headers, mcp_session_id=None)
temp_key = self._generate_session_key(temp_merged_headers, mcp_session_id=None)

# Look for existing sessions that match the base key pattern
async with self._session_lock:
for key, (session, exit_stack, stored_loop, mcp_id) in list(self._sessions.items()):
# Check if this session matches our connection params (ignoring mcp-session-id)
if key.startswith(temp_key.split('_mcp_id')[0]) and mcp_id:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This logic for finding a matching session is incorrect. temp_key.split('_mcp_id')[0] will not work as _mcp_id is not in the key string, and the original key part order makes prefix matching fail.

With the suggested change to _generate_session_key (placing the headers part first), this check can be simplified and corrected to use startswith.

Suggested change
if key.startswith(temp_key.split('_mcp_id')[0]) and mcp_id:
if key.startswith(temp_key + '_') and mcp_id:

current_loop = asyncio.get_running_loop()
if stored_loop is current_loop and not self._is_session_disconnected(session):
# Found a valid session with mcp-session-id, use it
existing_mcp_session_id = mcp_id
logger.debug('Reusing existing mcp-session-id: %s', existing_mcp_session_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

The mcp_session_id is a sensitive session identifier used to maintain stateful sessions with the MCP server. Logging it at the DEBUG level can lead to session identifier leakage in log files, which could be exploited to interfere with active sessions.

Suggested change
logger.debug('Reusing existing mcp-session-id: %s', existing_mcp_session_id)
logger.debug('Reusing existing mcp-session-id')

break

# Merge headers with mcp-session-id if we have one
merged_headers = self._merge_headers(headers, mcp_session_id=existing_mcp_session_id)

# Generate session key using merged headers and mcp-session-id
session_key = self._generate_session_key(merged_headers, mcp_session_id=existing_mcp_session_id)

# Use async lock to prevent race conditions
async with self._session_lock:
# Check if we have an existing session
# Check if we have an existing session with this exact key
if session_key in self._sessions:
session, exit_stack, stored_loop = self._sessions[session_key]
session, exit_stack, stored_loop, mcp_id = self._sessions[session_key]

# Check if the existing session is still connected and bound to the current loop
current_loop = asyncio.get_running_loop()
Expand Down Expand Up @@ -468,7 +510,7 @@ async def create_session(
client = self._create_client(merged_headers)
is_stdio = isinstance(self._connection_params, StdioConnectionParams)

session = await asyncio.wait_for(
session_context = await asyncio.wait_for(
exit_stack.enter_async_context(
SessionContext(
client=client,
Expand All @@ -479,12 +521,28 @@ async def create_session(
),
timeout=timeout_in_seconds,
)

# Store session, exit stack, and loop in the pool

session = session_context.session

# Extract mcp-session-id from the session context if available
new_mcp_session_id = session_context.mcp_session_id

# If we got a new mcp-session-id, update the session key and re-store
if new_mcp_session_id and new_mcp_session_id != existing_mcp_session_id:
logger.info('Received new mcp-session-id from server: %s', new_mcp_session_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

The mcp_session_id is logged at the INFO level. Since INFO logs are frequently enabled in production environments, this poses a significant risk of leaking session identifiers to log management systems.

Suggested change
logger.info('Received new mcp-session-id from server: %s', new_mcp_session_id)
logger.info('Received new mcp-session-id from server')

# Remove old session entry if it exists
if session_key in self._sessions:
del self._sessions[session_key]
# Generate new key with the mcp-session-id
merged_headers_with_id = self._merge_headers(headers, mcp_session_id=new_mcp_session_id)
session_key = self._generate_session_key(merged_headers_with_id, mcp_session_id=new_mcp_session_id)

# Store session, exit stack, loop, and mcp_session_id in the pool
self._sessions[session_key] = (
session,
exit_stack,
asyncio.get_running_loop(),
new_mcp_session_id or existing_mcp_session_id,
)
logger.debug('Created new session: %s', session_key)
return session
Expand All @@ -504,7 +562,7 @@ async def close(self):
"""Closes all sessions and cleans up resources."""
async with self._session_lock:
for session_key in list(self._sessions.keys()):
_, exit_stack, stored_loop = self._sessions[session_key]
_, exit_stack, stored_loop, _ = self._sessions[session_key]
await self._cleanup_session(session_key, exit_stack, stored_loop)


Expand Down
17 changes: 17 additions & 0 deletions src/google/adk/tools/mcp_tool/mcp_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,23 @@ async def close(self) -> None:
# Log the error but don't re-raise to avoid blocking shutdown
print(f"Warning: Error during McpToolset cleanup: {e}", file=self._errlog)

def get_session_info(self) -> dict[str, Any]:
"""Returns information about the current MCP session state.

This is useful for debugging and understanding session reuse.

Returns:
Dictionary with session information including active session keys
and mcp-session-id values.
"""
session_info = {}
for key, (_, _, _, mcp_id) in self._mcp_session_manager._sessions.items():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Accessing the private member _sessions of _mcp_session_manager breaks encapsulation. This makes McpToolset dependent on the internal implementation of MCPSessionManager and more brittle to future changes. It would be better to add a public method to MCPSessionManager to provide this debugging information.

session_info[key] = {
'mcp_session_id': mcp_id,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

The get_session_info method returns the raw mcp_session_id for all active sessions. As a public method intended for debugging, it should avoid exposing full session identifiers which could be leaked if the method's output is displayed in a UI or exposed via an API. Consider masking the identifier.

Suggested change
'mcp_session_id': mcp_id,
'mcp_session_id': mcp_id[:8] + '...' if mcp_id else None,

'active': True
}
return session_info

@override
def get_auth_config(self) -> Optional[AuthConfig]:
"""Returns the auth config for this toolset.
Expand Down
21 changes: 20 additions & 1 deletion src/google/adk/tools/mcp_tool/session_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(
self._sse_read_timeout = sse_read_timeout
self._is_stdio = is_stdio
self._session: Optional[ClientSession] = None
self._mcp_session_id: Optional[str] = None
self._ready_event = asyncio.Event()
self._close_event = asyncio.Event()
self._task: Optional[asyncio.Task] = None
Expand All @@ -79,6 +80,11 @@ def session(self) -> Optional[ClientSession]:
"""Get the managed ClientSession, if available."""
return self._session

@property
def mcp_session_id(self) -> Optional[str]:
"""Get the MCP session ID returned by the server during initialization."""
return self._mcp_session_id

async def start(self) -> ClientSession:
"""Start the runner and wait for the session to be ready.

Expand Down Expand Up @@ -178,8 +184,21 @@ async def _run(self):
else None,
)
)
await asyncio.wait_for(session.initialize(), timeout=self._timeout)

# Initialize the session and capture the response
init_result = await asyncio.wait_for(
session.initialize(), timeout=self._timeout
)
logger.debug('Session has been successfully initialized')

# Extract mcp-session-id from initialization result if present
# The MCP protocol returns this in the InitializeResult's meta field
if hasattr(init_result, '_meta') and init_result._meta:
self._mcp_session_id = init_result._meta.get('mcp-session-id')
if self._mcp_session_id:
logger.debug(
'Captured mcp-session-id from server: %s', self._mcp_session_id
)
Comment on lines +199 to +201
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

Logging the mcp_session_id returned by the server during initialization can expose sensitive session state identifiers in debug logs.

Suggested change
logger.debug(
'Captured mcp-session-id from server: %s', self._mcp_session_id
)
logger.debug(
'Captured mcp-session-id from server'
)


self._session = session
self._ready_event.set()
Expand Down