diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py index b57f639da2..f7b941d594 100644 --- a/python/pyfory/_fory.py +++ b/python/pyfory/_fory.py @@ -160,6 +160,8 @@ class Fory: "_output_stream", "field_nullable", "policy", + "max_collection_size", + "max_binary_size", ) def __init__( @@ -172,6 +174,8 @@ def __init__( policy: DeserializationPolicy = None, field_nullable: bool = False, meta_compressor=None, + max_collection_size: int = 1_000_000, + max_binary_size: int = 64 * 1024 * 1024, ): """ Initialize a Fory serialization instance. @@ -210,6 +214,17 @@ def __init__( field_nullable: Treat all dataclass fields as nullable regardless of Optional annotation. + max_collection_size: Maximum allowed size for collections (lists, sets, tuples) + and maps (dicts) during deserialization. This limit is used to prevent + out-of-memory attacks from malicious payloads that claim extremely large + collection sizes, as collections preallocate memory based on the declared + size. Raises an exception if exceeded. Default is 1,000,000. + + max_binary_size: Maximum allowed size in bytes for binary data reads during + deserialization (default: 64 MB). Raises an exception if a single binary + read exceeds this limit, preventing out-of-memory attacks from malicious + payloads that claim extremely large binary sizes. + Example: >>> # Python-native mode with reference tracking >>> fory = Fory(ref=True) @@ -235,7 +250,8 @@ def __init__( self.serialization_context = SerializationContext(fory=self, scoped_meta_share_enabled=compatible) self.type_resolver.initialize() - self.buffer = Buffer.allocate(32) + self.max_binary_size = max_binary_size + self.buffer = Buffer.allocate(32, max_binary_size=max_binary_size) self.buffer_callback = None self._buffers = None self._unsupported_callback = None @@ -243,6 +259,7 @@ def __init__( self.is_peer_out_of_band_enabled = False self.max_depth = max_depth self.depth = 0 + self.max_collection_size = max_collection_size self._output_stream = None def register( @@ -621,7 +638,7 @@ def _deserialize( assert self.depth == 0, "Nested deserialization should use read_ref/read_no_ref." self.depth += 1 if isinstance(buffer, bytes): - buffer = Buffer(buffer) + buffer = Buffer(buffer, max_binary_size=self.max_binary_size) if unsupported_objects is not None: self._unsupported_objects = iter(unsupported_objects) reader_index = buffer.get_reader_index() @@ -666,6 +683,7 @@ def _read_no_ref_internal(self, buffer, serializer): """Internal method to read without modifying read_ref_ids.""" if serializer is None: serializer = self.type_resolver.read_type_info(buffer).serializer + self.inc_depth() o = serializer.read(buffer) self.dec_depth() @@ -812,6 +830,10 @@ class ThreadSafeFory: strict (bool): Whether to require type registration. Defaults to True. compatible (bool): Whether to enable compatible mode. Defaults to False. max_depth (int): Maximum depth for deserialization. Defaults to 50. + max_collection_size (int): Maximum allowed size for collections and maps during + deserialization. Defaults to 1,000,000. + max_binary_size (int): Maximum allowed size in bytes for binary data reads during + deserialization. Defaults to 64 MB. Example: >>> import pyfury diff --git a/python/pyfory/buffer.pxi b/python/pyfory/buffer.pxi index 2b98353f48..9424189bf7 100644 --- a/python/pyfory/buffer.pxi +++ b/python/pyfory/buffer.pxi @@ -124,9 +124,11 @@ cdef class Buffer: object output_stream Py_ssize_t shape[1] Py_ssize_t stride[1] + int32_t max_binary_size - def __init__(self, data not None, int32_t offset=0, length=None): + def __init__(self, data not None, int32_t offset=0, length=None, int32_t max_binary_size= 64 * 1024 * 1024): self.data = data + self.max_binary_size = max_binary_size cdef int32_t buffer_len = len(data) cdef int length_ if length is None: @@ -146,7 +148,7 @@ cdef class Buffer: self.output_stream = None @classmethod - def from_stream(cls, stream not None, uint32_t buffer_size=4096): + def from_stream(cls, stream not None, uint32_t buffer_size=4096, int32_t max_binary_size=64 * 1024 * 1024): cdef CBuffer* stream_buffer cdef c_string stream_error if Fory_PyCreateBufferFromStream( @@ -156,6 +158,7 @@ cdef class Buffer: if stream_buffer == NULL: raise ValueError("failed to create stream buffer") cdef Buffer buffer = Buffer.__new__(Buffer) + buffer.max_binary_size = max_binary_size buffer.c_buffer = move(deref(stream_buffer)) del stream_buffer buffer.data = stream @@ -167,6 +170,7 @@ cdef class Buffer: @staticmethod cdef Buffer wrap(shared_ptr[CBuffer] c_buffer): cdef Buffer buffer = Buffer.__new__(Buffer) + buffer.max_binary_size = 64 * 1024 * 1024 cdef CBuffer* ptr = c_buffer.get() buffer.c_buffer = CBuffer(ptr.data(), ptr.size(), False) cdef _SharedBufferOwner owner = _SharedBufferOwner.__new__(_SharedBufferOwner) @@ -178,11 +182,12 @@ cdef class Buffer: return buffer @classmethod - def allocate(cls, int32_t size): + def allocate(cls, int32_t size, int32_t max_binary_size=64 * 1024 * 1024): cdef CBuffer* buf = allocate_buffer(size) if buf == NULL: raise MemoryError("out of memory") cdef Buffer buffer = Buffer.__new__(Buffer) + buffer.max_binary_size = max_binary_size buffer.c_buffer = move(deref(buf)) del buf buffer.data = None @@ -407,6 +412,10 @@ cdef class Buffer: cpdef inline bytes read_bytes(self, int32_t length): if length == 0: return b"" + + if length > self.max_binary_size: + raise ValueError(f"Binary size {length} exceeds the configured limit of {self.max_binary_size}") + cdef bytes py_bytes = PyBytes_FromStringAndSize(NULL, length) if py_bytes is None: raise MemoryError("out of memory") diff --git a/python/pyfory/collection.pxi b/python/pyfory/collection.pxi index 150f354e68..ecca3bff39 100644 --- a/python/pyfory/collection.pxi +++ b/python/pyfory/collection.pxi @@ -393,6 +393,9 @@ cdef class ListSerializer(CollectionSerializer): cdef MapRefResolver ref_resolver = self.fory.ref_resolver cdef TypeResolver type_resolver = self.fory.type_resolver cdef int32_t len_ = buffer.read_var_uint32() + # Check size limit before PyList_New preallocation to prevent OOM attacks + if len_ > self.fory.max_collection_size: + raise ValueError(f"List size {len_} exceeds the configured limit of {self.fory.max_collection_size}") cdef list list_ = PyList_New(len_) if len_ == 0: return list_ @@ -493,6 +496,9 @@ cdef class TupleSerializer(CollectionSerializer): cdef MapRefResolver ref_resolver = self.fory.ref_resolver cdef TypeResolver type_resolver = self.fory.type_resolver cdef int32_t len_ = buffer.read_var_uint32() + # Check size limit before PyTuple_New preallocation to prevent OOM attacks + if len_ > self.fory.max_collection_size: + raise ValueError(f"Tuple size {len_} exceeds the configured limit of {self.fory.max_collection_size}") cdef tuple tuple_ = PyTuple_New(len_) if len_ == 0: return tuple_ @@ -575,6 +581,9 @@ cdef class SetSerializer(CollectionSerializer): cdef set instance = set() ref_resolver.reference(instance) cdef int32_t len_ = buffer.read_var_uint32() + # Check size limit to prevent OOM attacks from malicious payloads + if len_ > self.fory.max_collection_size: + raise ValueError(f"Set size {len_} exceeds the configured limit of {self.fory.max_collection_size}") if len_ == 0: return instance cdef int8_t collect_flag = buffer.read_int8() @@ -897,6 +906,9 @@ cdef class MapSerializer(Serializer): cdef MapRefResolver ref_resolver = self.ref_resolver cdef TypeResolver type_resolver = self.type_resolver cdef int32_t size = buffer.read_var_uint32() + # Check size limit before _PyDict_NewPresized preallocation to prevent OOM attacks + if size > self.fory.max_collection_size: + raise ValueError(f"Map size {size} exceeds the configured limit of {self.fory.max_collection_size}") cdef dict map_ = _PyDict_NewPresized(size) ref_resolver.reference(map_) cdef int32_t ref_id diff --git a/python/pyfory/collection.py b/python/pyfory/collection.py index 5bb88cbaa2..b44fec0261 100644 --- a/python/pyfory/collection.py +++ b/python/pyfory/collection.py @@ -167,6 +167,9 @@ def _write_different_types(self, buffer, value, collect_flag=0): def read(self, buffer): len_ = buffer.read_var_uint32() + # Check size limit before collection preallocation to prevent OOM attacks + if len_ > self.fory.max_collection_size: + raise ValueError(f"Collection size {len_} exceeds the configured limit of {self.fory.max_collection_size}") collection_ = self.new_instance(self.type_) if len_ == 0: return collection_ @@ -481,6 +484,9 @@ def read(self, buffer): ref_resolver = self.ref_resolver type_resolver = self.type_resolver size = buffer.read_var_uint32() + # Check size limit to prevent OOM attacks from malicious payloads + if size > fory.max_collection_size: + raise ValueError(f"Map size {size} exceeds the configured limit of {fory.max_collection_size}") map_ = {} ref_resolver.reference(map_) chunk_header = 0 diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx index dea890fbc0..d443635648 100644 --- a/python/pyfory/serialization.pyx +++ b/python/pyfory/serialization.pyx @@ -1078,6 +1078,8 @@ cdef class Fory: cdef public bint is_peer_out_of_band_enabled cdef int32_t max_depth cdef int32_t depth + cdef public int32_t max_collection_size + cdef public int32_t max_binary_size cdef object _output_stream def __init__( @@ -1090,6 +1092,8 @@ cdef class Fory: max_depth: int = 50, field_nullable: bool = False, meta_compressor=None, + max_collection_size: int = 1_000_000, + max_binary_size: int = 64 * 1024 * 1024, ): """ Initialize a Fory serialization instance. @@ -1128,6 +1132,17 @@ cdef class Fory: field_nullable: Treat all dataclass fields as nullable regardless of Optional annotation. + max_collection_size: Maximum allowed size for collections (lists, sets, tuples) + and maps (dicts) during deserialization. This limit is used to prevent + out-of-memory attacks from malicious payloads that claim extremely large + collection sizes, as collections preallocate memory based on the declared + size. Raises an exception if exceeded. Default is 1,000,000. + + max_binary_size: Maximum allowed size in bytes for binary data reads during + deserialization (default: 64 MB). Raises an exception if a single binary + read exceeds this limit, preventing out-of-memory attacks from malicious + payloads that claim extremely large binary sizes. + Example: >>> # Python-native mode with reference tracking >>> fory = Fory(ref=True) @@ -1149,7 +1164,8 @@ cdef class Fory: self.type_resolver = TypeResolver(self, meta_share=compatible, meta_compressor=meta_compressor) self.serialization_context = SerializationContext(fory=self, scoped_meta_share_enabled=compatible) self.type_resolver.initialize() - self.buffer = Buffer.allocate(32) + self.max_binary_size = max_binary_size + self.buffer = Buffer.allocate(32, max_binary_size=max_binary_size) self.buffer_callback = None self._buffers = None self._unsupported_callback = None @@ -1157,6 +1173,7 @@ cdef class Fory: self.is_peer_out_of_band_enabled = False self.depth = 0 self.max_depth = max_depth + self.max_collection_size = max_collection_size self._output_stream = None def register_serializer(self, cls: Union[type, TypeVar], Serializer serializer): @@ -1508,7 +1525,7 @@ cdef class Fory: """ try: if type(buffer) == bytes: - buffer = Buffer(buffer) + buffer = Buffer(buffer, max_binary_size=self.max_binary_size) return self._deserialize(buffer, buffers, unsupported_objects) finally: self.reset_read() diff --git a/python/pyfory/tests/test_size_guardrails.py b/python/pyfory/tests/test_size_guardrails.py new file mode 100644 index 0000000000..700b2baa2e --- /dev/null +++ b/python/pyfory/tests/test_size_guardrails.py @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Test max_collection_size and max_binary_size guardrails to prevent OOM attacks +from malicious payloads. + +Collections preallocate memory based on declared size, so they need guardrails. +Binary reads are guarded by max_binary_size on the Buffer. +""" + +from dataclasses import dataclass +from typing import List + +import pytest + +import pyfory +from pyfory import Fory +from pyfory.serialization import Buffer + + +def roundtrip(data, limit, xlang=False, ref=False): + """Serialize and deserialize with given collection size limit.""" + writer = Fory(xlang=xlang, ref=ref) + reader = Fory(xlang=xlang, ref=ref, max_collection_size=limit) + return reader.deserialize(writer.serialize(data)) + + +def roundtrip_binary(data, max_binary_size, xlang=False, ref=False): + """Serialize and deserialize with given binary size limit.""" + writer = Fory(xlang=xlang, ref=ref) + reader = Fory(xlang=xlang, ref=ref, max_binary_size=max_binary_size) + return reader.deserialize(writer.serialize(data)) + + +class TestCollectionSizeLimit: + """Collections (list/set/dict) preallocate memory, so need size limits.""" + + @pytest.mark.parametrize("xlang", [False, True]) + @pytest.mark.parametrize( + "data,limit", + [ + ([1, 2, 3], 10), # list within limit + ({1, 2, 3}, 10), # set within limit + ({"a": 1}, 10), # dict within limit + ([], 0), # empty list ok + (set(), 0), # empty set ok + ({}, 0), # empty dict ok + ], + ) + def test_within_limit_succeeds(self, xlang, data, limit): + assert roundtrip(data, limit, xlang=xlang) == data + + @pytest.mark.parametrize("xlang", [False, True]) + @pytest.mark.parametrize( + "data,limit", + [ + (list(range(10)), 5), # list exceeds + (set(range(10)), 5), # set exceeds + ({str(i): i for i in range(10)}, 5), # dict exceeds + ([[1], list(range(10))], 5), # nested inner exceeds + ], + ) + def test_exceeds_limit_fails(self, xlang, data, limit): + with pytest.raises(ValueError, match="exceeds the configured limit"): + roundtrip(data, limit, xlang=xlang) + + @pytest.mark.parametrize("ref", [False, True]) + @pytest.mark.parametrize( + "data,limit,should_fail", + [ + ((1, 2, 3), 10, False), + (tuple(range(10)), 5, True), + ], + ) + def test_tuple_limit(self, ref, data, limit, should_fail): + """Tuple only works in xlang=False mode.""" + if should_fail: + with pytest.raises(ValueError, match="exceeds the configured limit"): + roundtrip(data, limit, xlang=False, ref=ref) + else: + assert roundtrip(data, limit, xlang=False, ref=ref) == data + + def test_default_limit_is_one_million(self): + assert Fory().max_collection_size == 1_000_000 + + def test_dataclass_list_field_exceeds_limit(self): + @dataclass + class Container: + items: List[pyfory.int32] + + writer = Fory(xlang=True) + writer.register(Container) + reader = Fory(xlang=True, max_collection_size=5) + reader.register(Container) + + with pytest.raises(ValueError, match="exceeds the configured limit"): + reader.deserialize(writer.serialize(Container(items=list(range(10))))) + + +class TestBinarySizeLimit: + """Binary reads are guarded by max_binary_size on the Buffer.""" + + def test_default_limit_is_64mib(self): + assert Fory().max_binary_size == 64 * 1024 * 1024 + + @pytest.mark.parametrize("xlang", [False, True]) + def test_within_limit_succeeds(self, xlang): + assert roundtrip_binary(b"x" * 100, max_binary_size=1024, xlang=xlang) == b"x" * 100 + + @pytest.mark.parametrize("xlang", [False, True]) + def test_exceeds_limit_fails(self, xlang): + with pytest.raises(ValueError, match="exceeds the configured limit"): + roundtrip_binary(b"x" * 200, max_binary_size=100, xlang=xlang) + + def test_from_stream_respects_limit(self): + import io + + payload = Fory().serialize(b"x" * 200) + buf = Buffer.from_stream(io.BytesIO(payload), max_binary_size=100) + with pytest.raises(ValueError, match="exceeds the configured limit"): + Fory(max_binary_size=100).deserialize(buf)