Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions lib/py/src/ext/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,36 @@ static PyObject* decode_compact(PyObject*, PyObject* args) {
return decode_impl<CompactProtocol>(args);
}

/**
 * Python entry point taking (data, typeargs).
 *
 * Decodes a binary-protocol struct straight from a bytes object instead of
 * pulling the data through a transport. `typeargs` is passed to
 * parse_struct_args to obtain the class and spec used for decoding.
 *
 * Returns whatever readStruct returns, or nullptr with a Python exception set:
 * TypeError when the first argument is not a bytes object, otherwise whichever
 * error the failing helper raised.
 */
static PyObject* decode_binary_from_bytes(PyObject*, PyObject* args) {
  PyObject* payload = nullptr;
  PyObject* struct_args = nullptr;
  if (!PyArg_ParseTuple(args, "OO", &payload, &struct_args)) {
    return nullptr;
  }

  // Reject anything that is not a real bytes object before touching its buffer.
  if (!PyBytes_Check(payload)) {
    PyErr_SetString(PyExc_TypeError, "first argument must be bytes");
    return nullptr;
  }

  StructTypeArgs struct_info;
  if (!parse_struct_args(&struct_info, struct_args)) {
    return nullptr;
  }

  BinaryProtocol decoder;
  // No existing output object is supplied, so Py_None is passed as the output
  // argument; readStruct decides how the result object is produced.
  return decoder.prepareDecodeBufferFromBytes(payload)
             ? decoder.readStruct(Py_None, struct_info.klass, struct_info.spec)
             : nullptr;
}

// Method table for the extension module: each row maps a Python-visible name
// to its C entry point. Every entry is METH_VARARGS (called with a positional
// argument tuple) and carries an empty docstring. The all-null row terminates
// the table.
static PyMethodDef ThriftFastBinaryMethods[] = {
{"encode_binary", encode_binary, METH_VARARGS, ""},
{"decode_binary", decode_binary, METH_VARARGS, ""},
{"encode_compact", encode_compact, METH_VARARGS, ""},
{"decode_compact", decode_compact, METH_VARARGS, ""},
// Binary-protocol decode that reads from a bytes object rather than a transport.
{"decode_binary_from_bytes", decode_binary_from_bytes, METH_VARARGS, ""},
{nullptr, nullptr, 0, nullptr} /* Sentinel */
};
Comment on lines 166 to 173

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following what's already here.


Expand Down
1 change: 1 addition & 0 deletions lib/py/src/ext/protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class ProtocolBase {
inline virtual ~ProtocolBase();

// Initializes the decode buffer from a transport object. Returns false with a
// Python exception set on failure, e.g. ValueError when a decode buffer has
// already been initialized on this protocol instance.
bool prepareDecodeBufferFromTransport(PyObject* trans);
// Initializes the decode buffer to read directly from the contents of a Python
// bytes object, holding a reference to it. Returns false with a Python
// exception set if a decode buffer is already initialized or the object's
// buffer cannot be obtained.
bool prepareDecodeBufferFromBytes(PyObject* bytes_obj);

PyObject* readStruct(PyObject* output, PyObject* klass, PyObject* spec_seq);

Expand Down
38 changes: 34 additions & 4 deletions lib/py/src/ext/protocol.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,18 @@ bool ProtocolBase<Impl>::readBytes(char** output, int len) {
PyErr_Format(PyExc_ValueError, "attempted to read negative length: %d", len);
return false;
}
// TODO(dreiss): Don't fear the malloc. Think about taking a copy of
// the partial read instead of forcing the transport
// to prepend it to its buffer.

if (input_.direct_buf) {
size_t requested = static_cast<size_t>(len);
if (input_.direct_pos > input_.direct_size || requested > (input_.direct_size - input_.direct_pos)) {
PyErr_SetString(PyExc_EOFError, "read past end of buffer");
return false;
}

*output = const_cast<char*>(input_.direct_buf + input_.direct_pos);
input_.direct_pos += requested;
return true;
}
Comment on lines +305 to +315

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, Mr. Bot. Fixed.


int rlen = detail::read_buffer(input_.stringiobuf.get(), output, len);

Expand Down Expand Up @@ -338,7 +347,7 @@ bool ProtocolBase<Impl>::readBytes(char** output, int len) {

template <typename Impl>
bool ProtocolBase<Impl>::prepareDecodeBufferFromTransport(PyObject* trans) {
if (input_.stringiobuf) {
if (input_.stringiobuf || input_.direct_buf) {
PyErr_SetString(PyExc_ValueError, "decode buffer is already initialized");
return false;
}
Expand Down Expand Up @@ -366,6 +375,27 @@ bool ProtocolBase<Impl>::prepareDecodeBufferFromTransport(PyObject* trans) {
return true;
}

/**
 * Points the decoder at the raw contents of a Python bytes object.
 *
 * On success the buffer pointer, its length and a zeroed read cursor are
 * stored in input_, together with a reference to bytes_obj. Returns false
 * with a Python exception set when a decode buffer (transport-backed or
 * bytes-backed) is already in place, or when the bytes buffer cannot be
 * obtained.
 */
template <typename Impl>
bool ProtocolBase<Impl>::prepareDecodeBufferFromBytes(PyObject* bytes_obj) {
  // Either input mode being active means this instance is already set up.
  if (input_.direct_buf || input_.stringiobuf) {
    PyErr_SetString(PyExc_ValueError, "decode buffer is already initialized");
    return false;
  }

  char* data = nullptr;
  Py_ssize_t size = 0;
  const int status = PyBytes_AsStringAndSize(bytes_obj, &data, &size);
  if (status < 0) {
    // The C API call has already set the exception.
    return false;
  }

  input_.direct_buf = data;
  input_.direct_size = static_cast<size_t>(size);
  input_.direct_pos = 0;

  // Take a new reference and hand it to direct_source so the memory behind
  // direct_buf stays alive (assumes ScopedPyObject::reset adopts the
  // reference it is given; confirm against types.h).
  Py_INCREF(bytes_obj);
  input_.direct_source.reset(bytes_obj);
  return true;
}

template <typename Impl>
bool ProtocolBase<Impl>::prepareEncodeBuffer() {
output_ = detail::new_encode_buffer(INIT_OUTBUF_SIZE);
Expand Down
7 changes: 7 additions & 0 deletions lib/py/src/ext/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,17 @@ class ScopedPyObject {
/**
 * Input-side state used while decoding.
 *
 * Caches the two key attributes of a CReadableTransport so that
 * PyObject_GetAttr does not have to be called repeatedly, and additionally
 * carries the bookkeeping needed to read directly from a bytes object.
 */
struct DecodeBuffer {
  // Transport-backed input.
  ScopedPyObject stringiobuf;
  ScopedPyObject refill_callable;

  // Bytes-backed input.
  ScopedPyObject direct_source;      // reference to the bytes object being read
  const char* direct_buf = nullptr;  // start of its data; non-null selects direct reads
  size_t direct_size = 0;            // number of bytes available at direct_buf
  size_t direct_pos = 0;             // offset of the next unread byte

  DecodeBuffer() {}
};

#if PY_MAJOR_VERSION < 3
Expand Down
105 changes: 105 additions & 0 deletions lib/py/test/thrift_TBinaryProtocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import _import_local_thrift # noqa
from thrift.protocol.TBinaryProtocol import TBinaryProtocol
from thrift.protocol.TBinaryProtocol import TBinaryProtocolAcceleratedFactory
from thrift.protocol.TProtocol import TProtocolException
from thrift.transport import TTransport

Expand Down Expand Up @@ -194,8 +195,60 @@ def testMessage(data, strict=True):
return result


class SimpleStruct(object):
    """Small hand-written struct (string, i32, bool) used by the decode tests."""

    # Slot 0 is unused. Each entry starts with (field id, type id, attribute
    # name); the remaining two elements are passed through to the protocol
    # layer unchanged.
    thrift_spec = (
        None,
        (1, 11, "name", "UTF8", None),
        (2, 8, "value", None, None),
        (3, 2, "flag", None, None),
    )

    def __init__(self, name=None, value=None, flag=None):
        self.name, self.value, self.flag = name, value, flag

    def write(self, oprot):
        """Serialize to ``oprot``, using its fast encoder when it has one."""
        if oprot._fast_encode is not None and self.thrift_spec is not None:
            emit = oprot.trans.write
            emit(oprot._fast_encode(self, [self.__class__, self.thrift_spec]))
            return

        # Field-by-field fallback: unset (None) fields are left out entirely.
        oprot.writeStructBegin("SimpleStruct")
        for attr, type_id, field_id, writer_name in (
            ("name", 11, 1, "writeString"),
            ("value", 8, 2, "writeI32"),
            ("flag", 2, 3, "writeBool"),
        ):
            field_value = getattr(self, attr)
            if field_value is None:
                continue
            oprot.writeFieldBegin(attr, type_id, field_id)
            getattr(oprot, writer_name)(field_value)
            oprot.writeFieldEnd()
        oprot.writeFieldStop()
        oprot.writeStructEnd()

    @classmethod
    def read(cls, iprot):
        """Deserialize from ``iprot``, using its fast decoder when applicable."""
        fast_path_ok = (
            iprot._fast_decode is not None
            and isinstance(iprot.trans, TTransport.CReadableTransport)
            and cls.thrift_spec is not None
        )
        if not fast_path_ok:
            return iprot.readStruct(cls, cls.thrift_spec, False)
        return iprot._fast_decode(None, iprot, [cls, cls.thrift_spec])
Comment on lines +232 to +240


class TestTBinaryProtocol(unittest.TestCase):

def setUp(self):
    """Store the fastbinary C extension on the test case, or None if it cannot be imported."""
    self._fastbinary = None
    try:
        from thrift.protocol import fastbinary
    except ImportError:
        return
    self._fastbinary = fastbinary

def test_TBinaryProtocol_write_read(self):
try:
testNaked('Byte', 123)
Expand Down Expand Up @@ -280,6 +333,58 @@ def test_TBinaryProtocol_write_read(self):
print("Assertion fail")
raise e

def _encode_accelerated_struct(self, value):
    """Serialize ``value`` through an accelerated binary protocol (no fallback) and return the bytes."""
    sink = TTransport.TMemoryBuffer()
    factory = TBinaryProtocolAcceleratedFactory(fallback=False)
    value.write(factory.getProtocol(sink))
    return sink.getvalue()

def _decode_accelerated_struct(self, encoded):
    """Decode ``encoded`` into a SimpleStruct via the transport-based accelerated protocol."""
    source = TTransport.TMemoryBuffer(encoded)
    factory = TBinaryProtocolAcceleratedFactory(fallback=False)
    return SimpleStruct.read(factory.getProtocol(source))

def test_decode_binary_from_bytes_matches_transport(self):
    """Decoding from bytes must yield the same field values as decoding via a transport."""
    if self._fastbinary is None:
        self.skipTest("C extension not available")

    source = SimpleStruct(name="transport-free", value=42, flag=True)
    wire = self._encode_accelerated_struct(source)

    via_transport = self._decode_accelerated_struct(wire)
    via_bytes = self._fastbinary.decode_binary_from_bytes(
        wire, [SimpleStruct, SimpleStruct.thrift_spec]
    )

    for field in ("name", "value", "flag"):
        self.assertEqual(getattr(via_bytes, field), getattr(via_transport, field))

def test_decode_binary_from_bytes_rejects_non_bytes(self):
    """A first argument that is not a bytes object must raise TypeError."""
    if self._fastbinary is None:
        self.skipTest("C extension not available")

    # A plain "..." literal is a bytes object on Python 2, so it would pass the
    # extension's bytes check there. The u"" literal is text on both major
    # versions; None and an int cover non-string inputs as well.
    for not_bytes in (u"not-bytes", None, 42):
        with self.assertRaises(TypeError):
            self._fastbinary.decode_binary_from_bytes(
                not_bytes,
                [SimpleStruct, SimpleStruct.thrift_spec],
            )

def test_decode_binary_from_bytes_rejects_truncated_input(self):
    """Dropping the final byte of a valid encoding must surface as EOFError."""
    if self._fastbinary is None:
        self.skipTest("C extension not available")

    complete = self._encode_accelerated_struct(
        SimpleStruct(name="trim me", value=7, flag=False)
    )
    clipped = complete[:-1]
    type_args = [SimpleStruct, SimpleStruct.thrift_spec]

    self.assertRaises(
        EOFError,
        self._fastbinary.decode_binary_from_bytes,
        clipped,
        type_args,
    )

def test_TBinaryProtocol_no_strict_write_read(self):
TMessageType = {"T_CALL": 1, "T_REPLY": 2, "T_EXCEPTION": 3, "T_ONEWAY": 4}
test_data = [("short message name", TMessageType['T_CALL'], 0),
Expand Down
Loading