Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions google/genai/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ def __init__(
self._async_httpx_client = AsyncHttpxClient(**async_client_args)

# Initialize the aiohttp client session.
self._aiohttp_session: Optional[aiohttp.ClientSession] = None
self._aiohttp_session: Optional[Union['aiohttp.ClientSession', 'AsyncAuthorizedSession']] = None
if self._use_aiohttp():
try:
import aiohttp # pylint: disable=g-import-not-at-top
Expand Down Expand Up @@ -841,15 +841,15 @@ async def _get_aiohttp_session(
from google.auth.aio.transport.sessions import AsyncAuthorizedSession

async_creds = StaticCredentials(token=self._access_token()) # type: ignore[no-untyped-call]
self._aiohttp_session = AsyncAuthorizedSession(async_creds) # type: ignore[no-untyped-call]
return self._aiohttp_session
self._aiohttp_session = AsyncAuthorizedSession(async_creds) # type: ignore[no-untyped-call,assignment]
return self._aiohttp_session # type: ignore[return-value]
except ImportError:
pass

if not self._use_google_auth_async() and (
self._aiohttp_session is None
or self._aiohttp_session.closed
or self._aiohttp_session._loop.is_closed()
or self._aiohttp_session.closed # type: ignore[union-attr]
or self._aiohttp_session._loop.is_closed() # type: ignore[union-attr]
): # pylint: disable=protected-access
# Initialize the aiohttp client session if it's not set up or closed.
class AiohttpClientSession(aiohttp.ClientSession): # type: ignore[misc]
Expand Down Expand Up @@ -888,7 +888,7 @@ def __del__(self, _warnings: Any = warnings) -> None:
read_bufsize=READ_BUFFER_SIZE,
)

return self._aiohttp_session
return self._aiohttp_session # type: ignore[return-value]

@staticmethod
def _ensure_httpx_ssl_ctx(
Expand Down Expand Up @@ -1353,7 +1353,7 @@ async def _async_request_once(

if stream:
if self._use_aiohttp():
self._aiohttp_session = await self._get_aiohttp_session()
self._aiohttp_session = await self._get_aiohttp_session() # type: ignore[assignment]
url = http_request.url
if self._use_google_auth_async():
client_cert_source = mtls.default_client_cert_source() # type: ignore[no-untyped-call]
Expand All @@ -1368,7 +1368,7 @@ async def _async_request_once(
else:
url = url.replace('googleapis.com', 'mtls.googleapis.com')
try:
response = await self._aiohttp_session.request(
response = await self._aiohttp_session.request( # type: ignore[union-attr]
method=http_request.method,
url=url,
headers=http_request.headers,
Expand All @@ -1389,8 +1389,8 @@ async def _async_request_once(
self._ensure_aiohttp_ssl_ctx(self._http_options)
)
# Instantiate a new session with the updated SSL context.
self._aiohttp_session = await self._get_aiohttp_session()
response = await self._aiohttp_session.request(
self._aiohttp_session = await self._get_aiohttp_session() # type: ignore[assignment]
response = await self._aiohttp_session.request( # type: ignore[union-attr]
method=http_request.method,
url=url,
headers=http_request.headers,
Expand Down Expand Up @@ -1422,7 +1422,7 @@ async def _async_request_once(
return HttpResponse(client_response.headers, client_response)
else:
if self._use_aiohttp():
self._aiohttp_session = await self._get_aiohttp_session()
self._aiohttp_session = await self._get_aiohttp_session() # type: ignore[assignment]
url = http_request.url
if self._use_google_auth_async():
client_cert_source = mtls.default_client_cert_source() # type: ignore[no-untyped-call]
Expand All @@ -1437,7 +1437,7 @@ async def _async_request_once(
else:
url = url.replace('googleapis.com', 'mtls.googleapis.com')
try:
response = await self._aiohttp_session.request(
response = await self._aiohttp_session.request( # type: ignore[union-attr]
method=http_request.method,
url=url,
headers=http_request.headers,
Expand Down Expand Up @@ -1466,8 +1466,8 @@ async def _async_request_once(
self._ensure_aiohttp_ssl_ctx(self._http_options)
)
# Instantiate a new session with the updated SSL context.
self._aiohttp_session = await self._get_aiohttp_session()
response = await self._aiohttp_session.request(
self._aiohttp_session = await self._get_aiohttp_session() # type: ignore[assignment]
response = await self._aiohttp_session.request( # type: ignore[union-attr]
method=http_request.method,
url=url,
headers=http_request.headers,
Expand Down Expand Up @@ -1862,7 +1862,7 @@ async def _async_upload_fd(

# Upload the file in chunks
if self._use_aiohttp(): # pylint: disable=g-import-not-at-top
self._aiohttp_session = await self._get_aiohttp_session()
self._aiohttp_session = await self._get_aiohttp_session() # type: ignore[assignment]
while True:
if isinstance(file, io.IOBase):
file_chunk = file.read(CHUNK_SIZE)
Expand Down Expand Up @@ -1904,7 +1904,7 @@ async def _async_upload_fd(
retry_count = 0
response = None
while retry_count < MAX_RETRY_COUNT:
response = await self._aiohttp_session.request(
response = await self._aiohttp_session.request( # type: ignore[union-attr]
method='POST',
url=upload_url,
data=file_chunk,
Expand Down Expand Up @@ -2054,8 +2054,8 @@ async def async_download_file(
data = http_request.data

if self._use_aiohttp():
self._aiohttp_session = await self._get_aiohttp_session()
response = await self._aiohttp_session.request(
self._aiohttp_session = await self._get_aiohttp_session() # type: ignore[assignment]
response = await self._aiohttp_session.request( # type: ignore[union-attr]
method=http_request.method,
url=http_request.url,
headers=http_request.headers,
Expand Down
Loading