Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions cpp/src/arrow/util/byte_stream_split_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,14 +421,28 @@ void ByteStreamSplitEncodeScalar(const uint8_t* raw_values, int width,
DoSplitStreams(raw_values, kNumStreams, num_values, dest_streams.data());
}

// If changing this value, please check that TestByteStreamSplitLargeWidth still
// exercises the slow path.
constexpr inline int kByteStreamSplitMaxTemporaryAlloc = 8192;

inline void ByteStreamSplitEncodeScalarDynamic(const uint8_t* raw_values, int width,
const int64_t num_values, uint8_t* out) {
::arrow::internal::SmallVector<uint8_t*, 16> dest_streams;
dest_streams.resize(width);
for (int stream = 0; stream < width; ++stream) {
dest_streams[stream] = &out[stream * num_values];
if (ARROW_PREDICT_TRUE(width < kByteStreamSplitMaxTemporaryAlloc / 8)) {
::arrow::internal::SmallVector<uint8_t*, 32> dest_streams;
Comment thread
pitrou marked this conversation as resolved.
dest_streams.resize(width);
for (int stream = 0; stream < width; ++stream) {
dest_streams[stream] = &out[stream * num_values];
}
DoSplitStreams(raw_values, width, num_values, dest_streams.data());
} else {
// Slow path to avoid an oversized `dest_streams` container above.
for (int stream = 0; stream < width; ++stream) {
uint8_t* dest_stream = &out[stream * num_values];
for (int64_t i = 0; i < num_values; ++i) {
dest_stream[i] = raw_values[stream + i * width];
}
}
}
DoSplitStreams(raw_values, width, num_values, dest_streams.data());
}

template <int kNumStreams>
Expand All @@ -445,12 +459,22 @@ void ByteStreamSplitDecodeScalar(const uint8_t* data, int width, int64_t num_val
inline void ByteStreamSplitDecodeScalarDynamic(const uint8_t* data, int width,
int64_t num_values, int64_t stride,
uint8_t* out) {
::arrow::internal::SmallVector<const uint8_t*, 16> src_streams;
src_streams.resize(width);
for (int stream = 0; stream < width; ++stream) {
src_streams[stream] = &data[stream * stride];
if (ARROW_PREDICT_TRUE(width < kByteStreamSplitMaxTemporaryAlloc / 8)) {
::arrow::internal::SmallVector<const uint8_t*, 32> src_streams;
src_streams.resize(width);
for (int stream = 0; stream < width; ++stream) {
src_streams[stream] = &data[stream * stride];
}
DoMergeStreams(src_streams.data(), width, num_values, out);
} else {
// Slow path to avoid an oversized `src_streams` container above.
for (int stream = 0; stream < width; ++stream) {
const uint8_t* src_stream = &data[stream * stride];
for (int64_t i = 0; i < num_values; ++i) {
out[stream + i * width] = src_stream[i];
}
}
}
DoMergeStreams(src_streams.data(), width, num_values, out);
}

template <int kNumStreams>
Expand Down
11 changes: 10 additions & 1 deletion cpp/src/arrow/util/byte_stream_split_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class TestByteStreamSplitSpecialized : public ::testing::Test {
TYPED_TEST_SUITE(TestByteStreamSplitSpecialized, ByteStreamSplitTypes);

TYPED_TEST(TestByteStreamSplitSpecialized, RoundtripSmall) {
for (int64_t num_values : {1, 5, 7, 12, 19, 31, 32}) {
for (int64_t num_values : {0, 1, 5, 7, 12, 19, 31, 32}) {
this->TestRoundtrip(num_values);
}
}
Expand All @@ -210,4 +210,13 @@ TYPED_TEST(TestByteStreamSplitSpecialized, PiecewiseDecode) {
this->TestPiecewiseDecode(/*num_values=*/500);
}

class TestByteStreamSplitLargeWidth
: public TestByteStreamSplitSpecialized<std::array<uint8_t, 3000>> {};

TEST_F(TestByteStreamSplitLargeWidth, Roundtrip) {
for (int64_t num_values : {0, 1, 5, 100}) {
this->TestRoundtrip(num_values);
}
}

} // namespace arrow::util::internal
12 changes: 7 additions & 5 deletions cpp/src/parquet/decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2307,11 +2307,13 @@ class ByteStreamSplitDecoderBase : public TypedDecoderImpl<DType> {
protected:
int DecodeRaw(uint8_t* out_buffer, int max_values) {
const int values_to_decode = std::min(this->num_values_, max_values);
::arrow::util::internal::ByteStreamSplitDecode(this->data_, this->type_length_,
values_to_decode, stride_, out_buffer);
this->data_ += values_to_decode;
this->num_values_ -= values_to_decode;
this->len_ -= this->type_length_ * values_to_decode;
if (ARROW_PREDICT_TRUE(values_to_decode > 0)) {
::arrow::util::internal::ByteStreamSplitDecode(
this->data_, this->type_length_, values_to_decode, stride_, out_buffer);
this->data_ += values_to_decode;
this->num_values_ -= values_to_decode;
this->len_ -= this->type_length_ * values_to_decode;
}
return values_to_decode;
}

Expand Down
10 changes: 6 additions & 4 deletions cpp/src/parquet/encoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -875,10 +875,12 @@ class ByteStreamSplitEncoderBase : public EncoderImpl,
return buf;
}
auto output_buffer = AllocateBuffer(this->memory_pool(), EstimatedDataEncodedSize());
uint8_t* output_buffer_raw = output_buffer->mutable_data();
const uint8_t* raw_values = sink_.data();
::arrow::util::internal::ByteStreamSplitEncode(
raw_values, /*width=*/byte_width_, num_values_in_buffer_, output_buffer_raw);
if (num_values_in_buffer_ > 0) {
uint8_t* output_buffer_raw = output_buffer->mutable_data();
const uint8_t* raw_values = sink_.data();
::arrow::util::internal::ByteStreamSplitEncode(
raw_values, /*width=*/byte_width_, num_values_in_buffer_, output_buffer_raw);
}
sink_.Reset();
num_values_in_buffer_ = 0;
return output_buffer;
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/parquet/encoding_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1705,7 +1705,7 @@ TYPED_TEST(TestByteStreamSplitEncoding, RoundTripSpace) {

for (auto null_prob : {0.001, 0.1, 0.5, 0.9, 0.999}) {
// Test with both size and offset up to 3 Simd block
for (auto i = 1; i < kSimdSize * 3; i++) {
for (auto i = 0; i < kSimdSize * 3; i++) {
ASSERT_NO_FATAL_FAILURE(this->ExecuteSpaced(i, 1, 0, null_prob));
ASSERT_NO_FATAL_FAILURE(this->ExecuteSpaced(i, 1, i + 1, null_prob));
}
Expand Down
Loading