diff --git a/fastapi_cache/backends/redis.py b/fastapi_cache/backends/redis.py index 0518b6d..d7140b0 100644 --- a/fastapi_cache/backends/redis.py +++ b/fastapi_cache/backends/redis.py @@ -1,3 +1,4 @@ +import warnings from typing import Union, Optional, AnyStr import aioredis @@ -6,8 +7,9 @@ from .base import BaseCacheBackend DEFAULT_ENCODING = 'utf-8' -DEFAULT_POOL_MIN_SIZE = 5 CACHE_KEY = 'REDIS' +# a singleton sentinel value for parameter defaults +_sentinel = object() # expected to be of bytearray, bytes, float, int, or str type @@ -19,11 +21,16 @@ class RedisCacheBackend(BaseCacheBackend[RedisKey, RedisValue]): def __init__( self, address: str, - pool_minsize: Optional[int] = DEFAULT_POOL_MIN_SIZE, + pool_minsize: Optional[int] = _sentinel, encoding: Optional[str] = DEFAULT_ENCODING, ) -> None: self._redis_address = address - self._redis_pool_minsize = pool_minsize + if pool_minsize is not _sentinel: + warnings.warn( + "Parameter 'pool_minsize' has been obsolete since aioredis 2.0.0.", + DeprecationWarning, + ) + self._encoding = encoding self._pool: Optional[Redis] = None @@ -36,10 +43,7 @@ async def _client(self) -> Redis: return self._pool async def _create_connection(self) -> Redis: - return await aioredis.create_redis_pool( - self._redis_address, - minsize=self._redis_pool_minsize, - ) + return aioredis.from_url(self._redis_address) async def add( self, @@ -75,10 +79,12 @@ async def get( default: RedisValue = None, **kwargs, ) -> AnyStr: - kwargs.setdefault('encoding', self._encoding) + encoding = kwargs.pop("encoding", self._encoding) client = await self._client cached_value = await client.get(key, **kwargs) + if encoding is not None and isinstance(cached_value, bytes): + cached_value = cached_value.decode(encoding) return cached_value if cached_value is not None else default @@ -118,5 +124,6 @@ async def expire( async def close(self) -> None: client = await self._client - client.close() - await client.wait_closed() + # Redis.close() only close currrent connection, but not the pool + await client.connection_pool.disconnect() + await client.close() diff --git a/setup.py b/setup.py index 23efc44..b1aed25 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,6 @@ 'redis', 'aioredis', 'asyncio', 'fastapi', 'starlette', 'cache' ], install_requires=[ - 'aioredis==1.3.1', + 'aioredis==2.0.0', ], ) diff --git a/tests/redis_tests.py b/tests/redis_tests.py index cf498be..c03107b 100644 --- a/tests/redis_tests.py +++ b/tests/redis_tests.py @@ -32,6 +32,7 @@ async def test_should_add_n_get_data_no_encoding( ) -> None: NO_ENCODING_KEY = 'bytes' NO_ENCODING_VALUE = b'test' + await f_backend.expire(NO_ENCODING_KEY, 0) is_added = await f_backend.add(NO_ENCODING_KEY, NO_ENCODING_VALUE) assert is_added is True @@ -165,8 +166,8 @@ async def test_close_should_close_connection( f_backend: RedisCacheBackend ) -> None: await f_backend.close() - with pytest.raises(aioredis.errors.PoolClosedError): - await f_backend.add(TEST_KEY, TEST_VALUE) + assert len(f_backend._pool.connection_pool._available_connections) == 0 + assert len(f_backend._pool.connection_pool._in_use_connections) == 0 @pytest.mark.asyncio