diff --git a/lib/py/src/ext/protocol.tcc b/lib/py/src/ext/protocol.tcc
index 448fc6f105..4b0d493f4a 100644
--- a/lib/py/src/ext/protocol.tcc
+++ b/lib/py/src/ext/protocol.tcc
@@ -29,6 +29,7 @@
 #include <cStringIO.h>
 #else
 #include <algorithm>
+#include <cstring>
 #endif
 
 namespace apache {
@@ -120,8 +121,13 @@ inline bool input_check(PyObject* input) {
 
 inline EncodeBuffer* new_encode_buffer(size_t size) {
   EncodeBuffer* buffer = new EncodeBuffer;
-  buffer->buf.reserve(size);
-  buffer->pos = 0;
+  if (!buffer->init(size)) {
+    // Raise MemoryError here so a nullptr result always carries a Python
+    // exception, as writeBuffer() does for its own allocation failure.
+    delete buffer;
+    PyErr_NoMemory();
+    return nullptr;
+  }
   return buffer;
 }
 
@@ -165,21 +171,22 @@ inline bool ProtocolBase<Impl>::isUtf8(PyObject* typeargs) {
 
 template <typename Impl>
 PyObject* ProtocolBase<Impl>::getEncodedValue() {
-  return PyBytes_FromStringAndSize(output_->buf.data(), output_->buf.size());
+  return PyBytes_FromStringAndSize(output_->data, output_->size);
 }
 
 template <typename Impl>
 inline bool ProtocolBase<Impl>::writeBuffer(char* data, size_t size) {
-  size_t need = size + output_->pos;
-  if (output_->buf.capacity() < need) {
-    try {
-      output_->buf.reserve(need);
-    } catch (std::bad_alloc&) {
-      PyErr_SetString(PyExc_MemoryError, "Failed to allocate write buffer");
-      return false;
-    }
+  if (!output_->ensure(size)) {
+    PyErr_SetString(PyExc_MemoryError, "Failed to allocate write buffer");
+    return false;
   }
-  std::copy(data, data + size, std::back_inserter(output_->buf));
+
+  // Skip empty writes: memcpy needs valid pointers even for zero bytes, and
+  // output_->data is nullptr while the buffer has zero capacity.
+  if (size > 0) {
+    std::memcpy(output_->data + output_->size, data, size);
+    output_->size += size;
+  }
   return true;
 }
 
diff --git a/lib/py/src/ext/types.h b/lib/py/src/ext/types.h
index 2848b28f0b..01e6a38190 100644
--- a/lib/py/src/ext/types.h
+++ b/lib/py/src/ext/types.h
@@ -31,6 +31,8 @@
 #if PY_MAJOR_VERSION >= 3
 
 #include <vector>
+#include <cstdlib>
+#include <limits>
 
 // TODO: better macros
 #define PyInt_AsLong(v) PyLong_AsLong(v)
@@ -131,8 +133,74 @@ typedef PyObject EncodeBuffer;
 #else
 extern const char* refill_signature;
 struct EncodeBuffer {
-  std::vector<char> buf;
-  size_t pos;
+  // Growable byte buffer backed by malloc/realloc. `data` owns `capacity`
+  // bytes (or is nullptr when capacity is 0); the first `size` bytes are valid.
+  char* data;
+  size_t size;
+  size_t capacity;
+
+  EncodeBuffer() : data(nullptr), size(0), capacity(0) {}
+  EncodeBuffer(const EncodeBuffer&) = delete;
+  EncodeBuffer& operator=(const EncodeBuffer&) = delete;
+
+  ~EncodeBuffer() {
+    if (data) {
+      std::free(data);
+    }
+  }
+
+  // Allocates `initial_capacity` bytes and resets the buffer to empty.
+  // Returns false if the allocation fails; a capacity of 0 allocates nothing.
+  bool init(size_t initial_capacity) {
+    if (initial_capacity == 0) {
+      data = nullptr;
+      size = 0;
+      capacity = 0;
+      return true;
+    }
+
+    data = static_cast<char*>(std::malloc(initial_capacity));
+    if (!data) {
+      return false;
+    }
+    size = 0;
+    capacity = initial_capacity;
+    return true;
+  }
+
+  // Guarantees room for `additional` more bytes past `size`, growing the
+  // allocation geometrically. Returns false on size_t overflow or allocation
+  // failure, in which case the buffer is left untouched.
+  bool ensure(size_t additional) {
+    if (additional > (std::numeric_limits<size_t>::max)() - size) {
+      return false;
+    }
+
+    size_t needed = size + additional;
+    if (needed <= capacity) {
+      return true;
+    }
+
+    // Double until large enough; fall back to the exact size if doubling
+    // would overflow size_t.
+    size_t new_capacity = capacity == 0 ? needed : capacity;
+    while (new_capacity < needed) {
+      if (new_capacity > (std::numeric_limits<size_t>::max)() / 2) {
+        new_capacity = needed;
+        break;
+      }
+      new_capacity *= 2;
+    }
+
+    char* new_data = static_cast<char*>(std::realloc(data, new_capacity));
+    if (!new_data) {
+      return false;
+    }
+
+    data = new_data;
+    capacity = new_capacity;
+    return true;
+  }
 };
 
 #endif
diff --git a/lib/py/test/thrift_TBinaryProtocol.py b/lib/py/test/thrift_TBinaryProtocol.py
index d4269eb617..a94371433f 100644
--- a/lib/py/test/thrift_TBinaryProtocol.py
+++ b/lib/py/test/thrift_TBinaryProtocol.py
@@ -22,7 +22,9 @@
 import uuid
 
 import _import_local_thrift  # noqa
+from thrift.Thrift import TApplicationException
 from thrift.protocol.TBinaryProtocol import TBinaryProtocol
+from thrift.protocol.TBinaryProtocol import TBinaryProtocolAcceleratedFactory
 from thrift.protocol.TProtocol import TProtocolException
 from thrift.transport import TTransport
 
@@ -167,6 +169,15 @@ def testField(type, data):
     protocol.readStructEnd()
 
 
+# Hand-written spec handed to the accelerated codec for TApplicationException:
+# field 1 is the UTF-8 "message", field 2 is the "type" code.
+APPLICATION_EXCEPTION_THRIFT_SPEC = (
+    None,
+    (1, 11, "message", "UTF8", None),
+    (2, 8, "type", None, None),
+)
+
+
 def testMessage(data, strict=True):
     message = {}
     message['name'] = data[0]
@@ -196,6 +207,14 @@ def testMessage(data, strict=True):
 
 class TestTBinaryProtocol(unittest.TestCase):
 
+    def setUp(self):
+        # Record whether the C accelerator module can be imported.
+        try:
+            from thrift.protocol import fastbinary  # noqa: F401
+            self._has_fastbinary = True
+        except ImportError:
+            self._has_fastbinary = False
+
     def test_TBinaryProtocol_write_read(self):
         try:
             testNaked('Byte', 123)
@@ -280,6 +299,34 @@ def test_TBinaryProtocol_write_read(self):
             print("Assertion fail")
             raise e
 
+    def test_accelerated_large_message_roundtrip(self):
+        """An 8192-character message survives an accelerated encode/decode."""
+        if not self._has_fastbinary:
+            self.skipTest("C extension not available")
+
+        original = TApplicationException(
+            type=TApplicationException.INTERNAL_ERROR,
+            message="x" * 8192,
+        )
+
+        otrans = TTransport.TMemoryBuffer()
+        oproto = TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(otrans)
+        oproto.trans.write(oproto._fast_encode(
+            original,
+            [TApplicationException, APPLICATION_EXCEPTION_THRIFT_SPEC],
+        ))
+
+        itrans = TTransport.TMemoryBuffer(otrans.getvalue())
+        iproto = TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(itrans)
+        decoded = iproto._fast_decode(
+            None,
+            iproto,
+            [TApplicationException, APPLICATION_EXCEPTION_THRIFT_SPEC],
+        )
+
+        self.assertEqual(decoded.message, original.message)
+        self.assertEqual(decoded.type, original.type)
+
     def test_TBinaryProtocol_no_strict_write_read(self):
         TMessageType = {"T_CALL": 1, "T_REPLY": 2, "T_EXCEPTION": 3, "T_ONEWAY": 4}
         test_data = [("short message name", TMessageType['T_CALL'], 0),