diff --git a/benchmarks/micro/bench_checksumming_inline.py b/benchmarks/micro/bench_checksumming_inline.py new file mode 100644 index 0000000000..3bd7ba1804 --- /dev/null +++ b/benchmarks/micro/bench_checksumming_inline.py @@ -0,0 +1,57 @@ +# Copyright ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Micro-benchmark: inline checksumming check vs classmethod call. + +Measures the overhead of ProtocolVersion.has_checksumming_support() +classmethod call versus an inline integer comparison on the +encode/decode hot path. + +Run: + python benchmarks/bench_checksumming_inline.py +""" + +import sys +import timeit + +from cassandra import ProtocolVersion +from cassandra.protocol import _CHECKSUMMING_MIN_VERSION, _CHECKSUMMING_MAX_VERSION + + +def bench(): + protocol_version = ProtocolVersion.V4 + + def via_classmethod(): + return ProtocolVersion.has_checksumming_support(protocol_version) + + def via_inline(): + return _CHECKSUMMING_MIN_VERSION <= protocol_version < _CHECKSUMMING_MAX_VERSION + + n = 5_000_000 + t_classmethod = timeit.timeit(via_classmethod, number=n) + t_inline = timeit.timeit(via_inline, number=n) + + saving_ns = (t_classmethod - t_inline) / n * 1e9 + speedup = t_classmethod / t_inline if t_inline > 0 else float('inf') + + print(f"=== has_checksumming_support ({n:,} iters) ===") + print(f" classmethod call: {t_classmethod / n * 1e9:.1f} ns") + print(f" inline compare: {t_inline / n * 1e9:.1f} ns") + print(f" saving: {saving_ns:.1f} ns/call ({speedup:.1f}x)") + + +if __name__ == "__main__": + print(f"Python {sys.version}") + bench() diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 4628c7ee0e..43e701f448 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -69,6 +69,12 @@ class InternalError(Exception): _UNSET_VALUE = object() +# Inline constants for has_checksumming_support check, avoiding +# ProtocolVersion.has_checksumming_support() classmethod call overhead +# (~94 ns per call) on the encode/decode hot path. +_CHECKSUMMING_MIN_VERSION = ProtocolVersion.V5 +_CHECKSUMMING_MAX_VERSION = ProtocolVersion.DSE_V1 + def register_class(cls): _message_types_by_opcode[cls.opcode] = cls @@ -1098,32 +1104,33 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta flags |= USE_BETA_FLAG buff = io.BytesIO() - buff.seek(9) # With checksumming, the compression is done at the segment frame encoding - if (compressor and not ProtocolVersion.has_checksumming_support(protocol_version)): - body = io.BytesIO() + if (compressor and not (_CHECKSUMMING_MIN_VERSION <= protocol_version < _CHECKSUMMING_MAX_VERSION)): if msg.custom_payload: - write_bytesmap(body, msg.custom_payload) - msg.send_body(body, protocol_version) - body = body.getvalue() + write_bytesmap(buff, msg.custom_payload) + msg.send_body(buff, protocol_version) + body = buff.getvalue() if len(body) > 0: body = compressor(body) flags |= COMPRESSED_FLAG - buff.write(body) length = len(body) + header = v3_header_pack(protocol_version, flags, stream_id, msg.opcode) + int32_pack(length) + return header + body else: + buff.seek(9) + if msg.custom_payload: write_bytesmap(buff, msg.custom_payload) msg.send_body(buff, protocol_version) length = buff.tell() - 9 - buff.seek(0) - cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, length) - return buff.getvalue() + buff.seek(0) + cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, length) + return buff.getvalue() @staticmethod def _write_header(f, version, flags, stream_id, opcode, length): @@ -1148,7 +1155,7 @@ def decode_message(cls, protocol_version, protocol_features, user_type_map, stre :param decompressor: optional decompression function to inflate the body :return: a message decoded from the body and frame attributes """ - if (not ProtocolVersion.has_checksumming_support(protocol_version) and + if (not (_CHECKSUMMING_MIN_VERSION <= protocol_version < _CHECKSUMMING_MAX_VERSION) and flags & COMPRESSED_FLAG): if decompressor is None: raise RuntimeError("No de-compressor available for compressed frame!")