diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 8974c9581f7..8b2363051c1 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -179,6 +179,7 @@ set(ARROW_FLIGHT_SRCS client_cookie_middleware.cc client_tracing_middleware.cc cookie_internal.cc + flight_data_decoder.cc middleware.cc serialization_internal.cc server.cc @@ -207,6 +208,21 @@ if(CMAKE_UNITY_BUILD AND WIN32) PROPERTIES SKIP_UNITY_BUILD_INCLUSION TRUE) endif() +# Suppress warnings from Abseil headers using deprecated in C++20. +# GCC 15+ with C++20 emits #warning which -Werror turns into an error. +if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_STANDARD GREATER_EQUAL 20) + set(ARROW_FLIGHT_CXX20_WARNING_FLAGS "-Wno-cpp") + set_source_files_properties(server_tracing_middleware.cc + client_tracing_middleware.cc + transport/grpc/grpc_client.cc + transport/grpc/grpc_server.cc + transport/grpc/serialization_internal.cc + transport/grpc/protocol_grpc_internal.cc + transport/grpc/util_internal.cc + PROPERTIES COMPILE_OPTIONS + "${ARROW_FLIGHT_CXX20_WARNING_FLAGS}") +endif() + if(ARROW_WITH_OPENTELEMETRY) list(APPEND ARROW_FLIGHT_SRCS otel_logging.cc) endif() @@ -320,6 +336,13 @@ if(ARROW_TESTING) foreach(LIB_TARGET ${ARROW_FLIGHT_TESTING_LIBRARIES}) target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_FLIGHT_EXPORTING) endforeach() + + # Suppress Abseil warnings in testing library (GCC 15+ with C++20) + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_STANDARD GREATER_EQUAL 20) + set_source_files_properties(test_auth_handlers.cc test_definitions.cc + test_flight_server.cc test_util.cc + PROPERTIES COMPILE_OPTIONS "-Wno-cpp") + endif() endif() add_arrow_test(flight_internals_test @@ -334,11 +357,39 @@ add_arrow_test(flight_test LABELS "arrow_flight") +# PoC: Async Flight server using gRPC generic callback API +if(ARROW_BUILD_TESTS) + add_arrow_test(async_grpc_poc_test + SOURCES + transport/grpc/async_grpc_poc_test.cc 
+ STATIC_LINK_LIBS + ${ARROW_FLIGHT_TEST_LINK_LIBS} + LABELS + "arrow_flight") +endif() + +# Suppress Abseil warnings in test files (GCC 15+ with C++20) +if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_STANDARD GREATER_EQUAL 20) + if(TARGET arrow-flight-internals-test) + target_compile_options(arrow-flight-internals-test PRIVATE "-Wno-cpp") + endif() + if(TARGET arrow-flight-test) + target_compile_options(arrow-flight-test PRIVATE "-Wno-cpp") + endif() + if(TARGET arrow-async-grpc-poc-test) + target_compile_options(arrow-async-grpc-poc-test PRIVATE "-Wno-cpp") + endif() +endif() + # Build test server for unit tests or benchmarks if(ARROW_BUILD_TESTS OR ARROW_BUILD_BENCHMARKS) add_executable(flight-test-server test_server.cc) target_link_libraries(flight-test-server ${ARROW_FLIGHT_TEST_LINK_LIBS} ${GFLAGS_LIBRARIES}) + # Suppress Abseil warnings (GCC 15+ with C++20) + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_STANDARD GREATER_EQUAL 20) + target_compile_options(flight-test-server PRIVATE "-Wno-cpp") + endif() if(ARROW_BUILD_TESTS) add_dependencies(arrow-flight-test flight-test-server) @@ -365,6 +416,12 @@ if(ARROW_BUILD_BENCHMARKS) target_link_libraries(arrow-flight-benchmark ${ARROW_FLIGHT_TEST_LINK_LIBS} ${GFLAGS_LIBRARIES}) + # Suppress Abseil warnings (GCC 15+ with C++20) + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_STANDARD GREATER_EQUAL 20) + target_compile_options(arrow-flight-perf-server PRIVATE "-Wno-cpp") + target_compile_options(arrow-flight-benchmark PRIVATE "-Wno-cpp") + endif() + add_dependencies(arrow-flight-benchmark arrow-flight-perf-server) add_dependencies(arrow_flight arrow-flight-benchmark) diff --git a/cpp/src/arrow/flight/flight_data_decoder.cc b/cpp/src/arrow/flight/flight_data_decoder.cc new file mode 100644 index 00000000000..0d7aa37a287 --- /dev/null +++ b/cpp/src/arrow/flight/flight_data_decoder.cc @@ -0,0 +1,135 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license 
agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/flight_data_decoder.h" + +#include "arrow/flight/serialization_internal.h" +#include "arrow/flight/transport.h" +#include "arrow/ipc/message.h" +#include "arrow/ipc/reader.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/status.h" + +namespace arrow { +namespace flight { + +namespace { + +// FlightDataMessageReader is an ipc::MessageReader that accepts one at a time +// FlightData messages. Analogous to MessageReader::Open(InputStream*) but for +// individual FlightData messages directly read from the received buffers. 
+class FlightDataMessageReader : public ipc::MessageReader { + public: + void Push(internal::FlightData data) { data_ = std::move(data); } + + ::arrow::Result> ReadNextMessage() override { + if (!data_.metadata) return nullptr; + return data_.OpenMessage(); + } + + std::shared_ptr ReadAppMetadata() { return data_.app_metadata; } + + private: + internal::FlightData data_; +}; + +} // namespace + +class FlightMessageDecoder::FlightMessageDecoderImpl { + public: + FlightMessageDecoderImpl(std::shared_ptr listener, + ipc::IpcReadOptions options) + : listener_(std::move(listener)), + options_(std::move(options)), + message_reader_(new FlightDataMessageReader()) {} + + Status Consume(std::shared_ptr buffer) { + ARROW_ASSIGN_OR_RAISE(auto data, internal::DeserializeFlightData(buffer)); + + if (!data.metadata) { + // Metadata-only message: no IPC content, just Flight app_metadata. + if (data.app_metadata && data.app_metadata->size() > 0) { + FlightStreamChunk chunk; + chunk.app_metadata = std::move(data.app_metadata); + RETURN_NOT_OK(listener_->OnNext(std::move(chunk))); + } + return Status::OK(); + } + + message_reader_->Push(std::move(data)); + + if (!batch_reader_) { + // Initialize RecordBatchStreamReader and read the first IPC message. + // It must be a schema. + // RecordBatchStreamReader requiring unique_ptr is slightly awkward + // since we want to keep a reference to the message reader. + ARROW_ASSIGN_OR_RAISE( + batch_reader_, + ipc::RecordBatchStreamReader::Open( + std::unique_ptr(message_reader_), options_)); + return listener_->OnSchemaDecoded(batch_reader_->schema()); + } + + std::shared_ptr batch; + RETURN_NOT_OK(batch_reader_->ReadNext(&batch)); + auto app_metadata = message_reader_->ReadAppMetadata(); + + if (batch) { + FlightStreamChunk chunk; + chunk.data = std::move(batch); + chunk.app_metadata = std::move(app_metadata); + return listener_->OnNext(std::move(chunk)); + } + // This has to be a Dictionary batch. 
+ // TODO: Add unit test validating assumption. + if (app_metadata && app_metadata->size() > 0) { + FlightStreamChunk chunk; + chunk.app_metadata = std::move(app_metadata); + return listener_->OnNext(std::move(chunk)); + } + return Status::OK(); + } + + std::shared_ptr schema() const { + return batch_reader_ ? batch_reader_->schema() : nullptr; + } + + private: + std::shared_ptr listener_; + ipc::IpcReadOptions options_; + // This is owned by the RecordBatchStreamReader once it's passed to it. + // We want to keep a reference to it so we can extract the app_metadata. + FlightDataMessageReader* message_reader_; + std::shared_ptr batch_reader_; +}; + +FlightMessageDecoder::FlightMessageDecoder(std::shared_ptr listener, + ipc::IpcReadOptions options) + : impl_(std::make_unique(std::move(listener), + std::move(options))) {} + +FlightMessageDecoder::~FlightMessageDecoder() = default; + +Status FlightMessageDecoder::Consume(std::shared_ptr buffer) { + return impl_->Consume(std::move(buffer)); +} + +std::shared_ptr FlightMessageDecoder::schema() const { return impl_->schema(); } + +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/flight_data_decoder.h b/cpp/src/arrow/flight/flight_data_decoder.h new file mode 100644 index 00000000000..9f949e6bb75 --- /dev/null +++ b/cpp/src/arrow/flight/flight_data_decoder.h @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/flight/types.h" +#include "arrow/flight/visibility.h" +#include "arrow/ipc/options.h" +#include "arrow/ipc/reader.h" +#include "arrow/result.h" +#include "arrow/status.h" + +namespace arrow { +namespace flight { + +/// \brief A general listener class to receive events from FlightMessageDecoder +/// +/// User must implement callback methods for interested events. +class ARROW_FLIGHT_EXPORT FlightDataListener : public ipc::Listener { + public: + /// \brief Called for each decoded FlightStreamChunk. + /// + /// chunk.data is the decoded RecordBatch, or nullptr for metadata-only + /// messages. + virtual Status OnNext(FlightStreamChunk chunk) = 0; +}; + +/// \brief Push style stream decoder that turns raw arrow Buffers into +/// FlightStreamChunks. +/// +/// This class decodes Apache Arrow Flight data format from arrow::Buffer +/// and fires events on the provided FlightDataListener. +class ARROW_FLIGHT_EXPORT FlightMessageDecoder { + public: + explicit FlightMessageDecoder( + std::shared_ptr listener, + ipc::IpcReadOptions options = ipc::IpcReadOptions::Defaults()); + ~FlightMessageDecoder(); + + /// \brief Decode one FlightData message directly from abuffer. + /// + /// Fires listener->OnSchemaDecoded() on the first message containing + /// a schema, listener->OnNext() for each subsequent record batch, + /// metadata-only message or dictionary batch. + /// + /// \param[in] buffer a raw buffer directly from the transport. Example + /// the arrow::Buffer extracted from the grpc::ByteBuffer from the gRPC transport. 
+ /// \return Status + Status Consume(std::shared_ptr buffer); + + /// \brief The decoded schema. + /// + /// Available after the first Consume() call that contains a schema message. + /// Returns nullptr if no schema has been decoded yet. + std::shared_ptr schema() const; + + private: + class FlightMessageDecoderImpl; + std::unique_ptr impl_; +}; + +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/serialization_internal.cc b/cpp/src/arrow/flight/serialization_internal.cc index 604375311d3..5ad4995ec06 100644 --- a/cpp/src/arrow/flight/serialization_internal.cc +++ b/cpp/src/arrow/flight/serialization_internal.cc @@ -17,18 +17,25 @@ #include "arrow/flight/serialization_internal.h" +#include #include #include #include +#include +#include +#include #include "arrow/buffer.h" #include "arrow/flight/protocol_internal.h" #include "arrow/io/memory.h" +#include "arrow/ipc/message.h" #include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" #include "arrow/result.h" #include "arrow/status.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/logging_internal.h" // Lambda helper & CTAD template @@ -612,6 +619,220 @@ Status ToProto(const CloseSessionResult& result, pb::CloseSessionResult* pb_resu return Status::OK(); } +namespace { + +using google::protobuf::internal::WireFormatLite; +using google::protobuf::io::ArrayOutputStream; +using google::protobuf::io::CodedInputStream; +using google::protobuf::io::CodedOutputStream; + +static constexpr int64_t kInt32Max = std::numeric_limits::max(); +static const uint8_t kSerializePaddingBytes[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + +arrow::Status IpcMessageHeaderSize(const arrow::ipc::IpcPayload& ipc_msg, bool has_body, + size_t* header_size, int32_t* metadata_size) { + DCHECK_LE(ipc_msg.metadata->size(), kInt32Max); + *metadata_size = static_cast(ipc_msg.metadata->size()); + + // 1 byte for metadata tag + *header_size += 1 + WireFormatLite::LengthDelimitedSize(*metadata_size); + + // 2 bytes for body tag 
+ if (has_body) { + // We write the body tag in the header but not the actual body data + *header_size += 2 + WireFormatLite::LengthDelimitedSize(ipc_msg.body_length) - + ipc_msg.body_length; + } + + return Status::OK(); +} + +bool ReadBytesZeroCopy(const std::shared_ptr& source_data, + CodedInputStream* input, std::shared_ptr* out) { + uint32_t length; + if (!input->ReadVarint32(&length)) { + return false; + } + auto buf = + SliceBuffer(source_data, input->CurrentPosition(), static_cast(length)); + *out = buf; + return input->Skip(static_cast(length)); +} + +} // namespace + +arrow::Result SerializePayloadToBuffers( + const arrow::flight::FlightPayload& msg) { + namespace pb = arrow::flight::protocol; + // Size of the IPC body (protobuf: data_body) + size_t body_size = 0; + // Size of the Protobuf "header" (everything except for the body) + size_t header_size = 0; + // Size of IPC header metadata (protobuf: data_header) + int32_t metadata_size = 0; + + // Write the descriptor if present + int32_t descriptor_size = 0; + if (msg.descriptor != nullptr) { + DCHECK_LE(msg.descriptor->size(), kInt32Max); + descriptor_size = static_cast(msg.descriptor->size()); + header_size += 1 + WireFormatLite::LengthDelimitedSize(descriptor_size); + } + + // App metadata tag if appropriate + int32_t app_metadata_size = 0; + if (msg.app_metadata && msg.app_metadata->size() > 0) { + DCHECK_LE(msg.app_metadata->size(), kInt32Max); + app_metadata_size = static_cast(msg.app_metadata->size()); + header_size += 1 + WireFormatLite::LengthDelimitedSize(app_metadata_size); + } + + const arrow::ipc::IpcPayload& ipc_msg = msg.ipc_message; + // No data in this payload (metadata-only). + bool has_ipc = ipc_msg.type != ipc::MessageType::NONE; + bool has_body = has_ipc ? 
ipc::Message::HasBody(ipc_msg.type) : false; + + if (has_ipc) { + DCHECK(has_body || ipc_msg.body_length == 0); + ARROW_RETURN_NOT_OK( + IpcMessageHeaderSize(ipc_msg, has_body, &header_size, &metadata_size)); + body_size = static_cast(ipc_msg.body_length); + } + + arrow::BufferVector buffers; + ARROW_ASSIGN_OR_RAISE(auto header_buf, arrow::AllocateBuffer(header_size)); + // Force the header_stream to be destructed, which actually flushes + // the data into the buffer. + { + ArrayOutputStream header_writer(header_buf->mutable_data(), + static_cast(header_size)); + CodedOutputStream header_stream(&header_writer); + + // Write descriptor + if (msg.descriptor != nullptr) { + WireFormatLite::WriteTag(pb::FlightData::kFlightDescriptorFieldNumber, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); + header_stream.WriteVarint32(descriptor_size); + header_stream.WriteRawMaybeAliased(msg.descriptor->data(), + static_cast(msg.descriptor->size())); + } + + // Write header + if (has_ipc) { + WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); + header_stream.WriteVarint32(metadata_size); + header_stream.WriteRawMaybeAliased(ipc_msg.metadata->data(), + static_cast(ipc_msg.metadata->size())); + } + + // Write app metadata + if (app_metadata_size > 0) { + WireFormatLite::WriteTag(pb::FlightData::kAppMetadataFieldNumber, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); + header_stream.WriteVarint32(app_metadata_size); + header_stream.WriteRawMaybeAliased(msg.app_metadata->data(), + static_cast(msg.app_metadata->size())); + } + if (has_body) { + // Write body tag + WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); + header_stream.WriteVarint32(static_cast(body_size)); + } + DCHECK_EQ(static_cast(header_size), header_stream.ByteCount()); + } + // Once header is written we just add the referenced buffers to the 
output BufferVector. + buffers.push_back(std::move(header_buf)); + + if (has_body) { + for (const auto& buffer : ipc_msg.body_buffers) { + if (!buffer || buffer->size() == 0) continue; + buffers.push_back(buffer); + const auto remainder = static_cast( + bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size()); + if (remainder) { + buffers.push_back( + std::make_shared(kSerializePaddingBytes, remainder)); + } + } + } + + return buffers; +} + +arrow::Result DeserializeFlightData( + const std::shared_ptr& in_buffer) { + arrow::flight::internal::FlightData out; + + if (!in_buffer) { + return {Status::Invalid("No payload")}; + } + + auto buffer_length = static_cast(in_buffer->size()); + CodedInputStream pb_stream(in_buffer->data(), buffer_length); + + pb_stream.SetTotalBytesLimit(buffer_length); + + // This is the bytes remaining when using CodedInputStream like this + while (pb_stream.BytesUntilTotalBytesLimit()) { + const uint32_t tag = pb_stream.ReadTag(); + const int field_number = WireFormatLite::GetTagFieldNumber(tag); + switch (field_number) { + case pb::FlightData::kFlightDescriptorFieldNumber: { + pb::FlightDescriptor pb_descriptor; + uint32_t length; + if (!pb_stream.ReadVarint32(&length)) { + return {Status::Invalid("Unable to parse length of FlightDescriptor")}; + } + // Can't use ParseFromCodedStream as this reads the entire + // rest of the stream into the descriptor command field. 
+ std::string buffer; + pb_stream.ReadString(&buffer, length); + if (!pb_descriptor.ParseFromString(buffer)) { + return {Status::Invalid("Unable to parse FlightDescriptor")}; + } + arrow::flight::FlightDescriptor descriptor; + ARROW_RETURN_NOT_OK( + arrow::flight::internal::FromProto(pb_descriptor, &descriptor)); + out.descriptor = std::make_unique(descriptor); + } break; + case pb::FlightData::kDataHeaderFieldNumber: { + if (!ReadBytesZeroCopy(in_buffer, &pb_stream, &out.metadata)) { + return {Status::Invalid("Unable to read FlightData metadata")}; + } + } break; + case pb::FlightData::kAppMetadataFieldNumber: { + if (!ReadBytesZeroCopy(in_buffer, &pb_stream, &out.app_metadata)) { + return {Status::Invalid("Unable to read FlightData application metadata")}; + } + } break; + case pb::FlightData::kDataBodyFieldNumber: { + if (!ReadBytesZeroCopy(in_buffer, &pb_stream, &out.body)) { + return {Status::Invalid("Unable to read FlightData body")}; + } + } break; + default: { + // Unknown field. We should skip it for compatibility. + if (!WireFormatLite::SkipField(&pb_stream, tag)) { + return {Status::Invalid("Could not skip unknown field tag in FlightData")}; + } + break; + } + } + } + // TODO(wesm): Where and when should we verify that the FlightData is not + // malformed? + + // Set the default value for an unspecified FlightData body. The other + // fields can be null if they're unspecified. 
+ if (out.body == nullptr) { + out.body = std::make_shared(nullptr, 0); + } + + return out; +} + } // namespace internal } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/serialization_internal.h b/cpp/src/arrow/flight/serialization_internal.h index 4d07efad815..ba5938e92ad 100644 --- a/cpp/src/arrow/flight/serialization_internal.h +++ b/cpp/src/arrow/flight/serialization_internal.h @@ -182,6 +182,21 @@ ARROW_FLIGHT_EXPORT Status ToProto(const CloseSessionResult& result, Status ToPayload(const FlightDescriptor& descr, std::shared_ptr* out); +/// \brief Serialize a FlightPayload to a vector of buffers. +/// +/// The first buffer contains the protobuf wire format header. Subsequent +/// buffers are zero-copy references to the IPC body buffers, with padding +/// buffers inserted as needed for 8-byte alignment. +ARROW_FLIGHT_EXPORT +arrow::Result SerializePayloadToBuffers( + const FlightPayload& payload); + +// TODO: we shouldn't be exporting FlightData. +/// \brief Deserialize FlightData from a contiguous buffer. +ARROW_FLIGHT_EXPORT +arrow::Result DeserializeFlightData( + const std::shared_ptr& buffer); + // We want to reuse RecordBatchStreamReader's implementation while // (1) Adapting it to the Flight message format // (2) Allowing pure-metadata messages before data is sent diff --git a/cpp/src/arrow/flight/transport.cc b/cpp/src/arrow/flight/transport.cc index fd74b6d95a5..3d65f5de24c 100644 --- a/cpp/src/arrow/flight/transport.cc +++ b/cpp/src/arrow/flight/transport.cc @@ -22,6 +22,8 @@ #include #include "arrow/flight/client_auth.h" +// TODO: We shouldn't be exposing this directly. +#include "arrow/flight/serialization_internal.h" #include "arrow/flight/transport_server.h" #include "arrow/flight/types.h" #include "arrow/flight/types_async.h" @@ -37,6 +39,12 @@ ::arrow::Result> FlightData::OpenMessage() { return ipc::Message::Open(metadata, body); } +// TODO: We shouldn't be exposing this directly. 
+::arrow::Result FlightData::DeserializeFrom( + const std::shared_ptr& buffer) { + return internal::DeserializeFlightData(buffer); +} + bool TransportDataStream::ReadData(internal::FlightData*) { return false; } arrow::Result TransportDataStream::WriteData(const FlightPayload&) { return Status::NotImplemented("Writing data for this stream"); diff --git a/cpp/src/arrow/flight/transport.h b/cpp/src/arrow/flight/transport.h index 4ce50534023..233635ec16c 100644 --- a/cpp/src/arrow/flight/transport.h +++ b/cpp/src/arrow/flight/transport.h @@ -91,6 +91,10 @@ struct FlightData { /// Open IPC message from the metadata and body ::arrow::Result> OpenMessage(); + + /// \brief Deserialize a FlightData from a contiguous buffer. The definition in transport.cc only forwards to internal::DeserializeFlightData() and returns its result unchanged. + static ::arrow::Result DeserializeFrom( + const std::shared_ptr& buffer); }; /// \brief A transport-specific interface for reading/writing Arrow data. diff --git a/cpp/src/arrow/flight/transport/grpc/async_grpc_poc_test.cc b/cpp/src/arrow/flight/transport/grpc/async_grpc_poc_test.cc new file mode 100644 index 00000000000..401fbf32377 --- /dev/null +++ b/cpp/src/arrow/flight/transport/grpc/async_grpc_poc_test.cc @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include + +#include "arrow/flight/client.h" +#include "arrow/flight/transport/grpc/async_grpc_server.h" +#include "arrow/testing/gtest_util.h" + +namespace arrow::flight::transport::grpc { + +// Simple DoGet test. This creates a FlightServer and calls DoGet +// from a FlightClient. The server uses the async callback API and +// serializes FlightData directly from the ByteBuffer using FlightDataSerialize. +TEST(AsyncGrpcTest, BasicDoGet) { + FlightCallbackService service; + int port = 0; + ::grpc::ServerBuilder builder; + builder.AddListeningPort("localhost:0", ::grpc::InsecureServerCredentials(), &port); + builder.RegisterCallbackGenericService(&service); + auto server = builder.BuildAndStart(); + ASSERT_NE(server, nullptr); + + // Connect existing arrow::flight::FlightClient implementation + std::string uri = "grpc://localhost:" + std::to_string(port); + ASSERT_OK_AND_ASSIGN(auto location, arrow::flight::Location::Parse(uri)); + ASSERT_OK_AND_ASSIGN(auto client, arrow::flight::FlightClient::Connect(location)); + + // Call DoGet + arrow::flight::Ticket ticket{"test"}; + ASSERT_OK_AND_ASSIGN(auto reader, client->DoGet(ticket)); + + // Read schema + ASSERT_OK_AND_ASSIGN(auto schema, reader->GetSchema()); + ASSERT_EQ(schema->num_fields(), 2); + ASSERT_EQ(schema->field(0)->name(), "a"); + ASSERT_EQ(schema->field(1)->name(), "b"); + + // Read batches + int batch_count = 0; + while (true) { + ASSERT_OK_AND_ASSIGN(auto chunk, reader->Next()); + if (!chunk.data) break; + ASSERT_EQ(chunk.data->num_rows(), 5); + batch_count++; + } + ASSERT_EQ(batch_count, 5); + + // Cleanup + ASSERT_OK(client->Close()); + server->Shutdown(); + server->Wait(); +} + +// Simple DoPut test. This creates a FlightServer and calls DoPut +// from a FlightClient. The server uses the async callback API and +// deserializes FlightData directly from the ByteBuffer using FlightDataDeserialize. 
+TEST(AsyncGrpcTest, BasicDoPut) { + FlightCallbackService service; + int port = 0; + ::grpc::ServerBuilder builder; + builder.AddListeningPort("localhost:0", ::grpc::InsecureServerCredentials(), &port); + builder.RegisterCallbackGenericService(&service); + auto server = builder.BuildAndStart(); + ASSERT_NE(server, nullptr); + + // Connect existing arrow::flight::FlightClient implementation + std::string uri = "grpc://localhost:" + std::to_string(port); + ASSERT_OK_AND_ASSIGN(auto location, arrow::flight::Location::Parse(uri)); + ASSERT_OK_AND_ASSIGN(auto client, arrow::flight::FlightClient::Connect(location)); + + // Create test data + auto schema = arrow::schema( + {arrow::field("a", arrow::int64()), arrow::field("b", arrow::int64())}); + auto batch = + arrow::RecordBatch::Make(schema, 3, + {arrow::ArrayFromJSON(arrow::int64(), "[1, 2, 3]"), + arrow::ArrayFromJSON(arrow::int64(), "[10, 20, 30]")}); + + // Call DoPut + arrow::flight::FlightDescriptor descriptor = + arrow::flight::FlightDescriptor::Path({"test"}); + ASSERT_OK_AND_ASSIGN(auto result, client->DoPut(descriptor, schema)); + + // Send batches + ASSERT_OK(result.writer->WriteRecordBatch(*batch)); + ASSERT_OK(result.writer->WriteRecordBatch(*batch)); + ASSERT_OK(result.writer->Close()); + + // Cleanup + ASSERT_OK(client->Close()); + server->Shutdown(); + server->Wait(); +} +} // namespace arrow::flight::transport::grpc diff --git a/cpp/src/arrow/flight/transport/grpc/async_grpc_server.h b/cpp/src/arrow/flight/transport/grpc/async_grpc_server.h new file mode 100644 index 00000000000..cb669217a21 --- /dev/null +++ b/cpp/src/arrow/flight/transport/grpc/async_grpc_server.h @@ -0,0 +1,222 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Async gRPC-based. This is a PoC. + +#pragma once + +#include + +#include "arrow/array.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/flight/flight_data_decoder.h" +#include "arrow/flight/server.h" +#include "arrow/flight/transport/grpc/customize_grpc.h" +// Currently used for `SliceFromBuffer` and `WrapGrpcBuffer`. +#include "arrow/flight/transport/grpc/serialization_internal.h" +#include "arrow/record_batch.h" + +namespace arrow::flight::transport::grpc { + +namespace pb = arrow::flight::protocol;

// DoGet using gRPC's generic callback API with ServerGenericBidiReactor.
//
// The reactor reads the request once, then writes one payload per
// OnWriteDone() callback: first the schema, then the record batches of a
// hard-coded demo stream. The RPC finishes with OK once a payload without IPC
// metadata is produced. Any Arrow error along the way finishes the RPC with
// INTERNAL instead of aborting the server process.
class DoGetReactor : public ::grpc::ServerGenericBidiReactor {
 public:
  DoGetReactor() { StartRead(&request_buf_); }

  void OnReadDone(bool ok) override {
    // Request has been read.
    if (!ok) {
      Finish(::grpc::Status(::grpc::StatusCode::INTERNAL, "Failed to read request"));
      return;
    }

    // DoGet request must contain the Ticket.
    // TODO Parse ticket, we do not care about it in this PoC.
    WriteNextPayload();
  }

  void OnWriteDone(bool ok) override {
    // We have finished writing. We can write the next payload or finish the stream.
    if (!ok) {
      Finish(::grpc::Status(::grpc::StatusCode::INTERNAL, "Write failed"));
      return;
    }
    WriteNextPayload();
  }

  void OnCancel() override {
    // Client cancelled the RPC. Not handled in this PoC; this needs a real
    // implementation beyond the PoC.
  }

  void OnDone() override { delete this; }

 private:
  // Finish the RPC with INTERNAL, carrying the Arrow error text.
  void FinishWithError(const arrow::Status& status) {
    Finish(::grpc::Status(::grpc::StatusCode::INTERNAL, status.ToString()));
  }

  // Produce the next payload of the demo stream. The first call builds the
  // stream (five identical batches of two int64 columns with five rows) and
  // returns its schema payload; later calls return the stream's next payload.
  arrow::Result<FlightPayload> NextPayload() {
    if (data_stream_ != nullptr) {
      return data_stream_->Next();
    }

    auto schema = arrow::schema(
        {arrow::field("a", arrow::int64()), arrow::field("b", arrow::int64())});
    arrow::Int64Builder builder_a, builder_b;
    ARROW_RETURN_NOT_OK(builder_a.AppendValues({1, 2, 3, 4, 5}));
    ARROW_RETURN_NOT_OK(builder_b.AppendValues({10, 20, 30, 40, 50}));
    ARROW_ASSIGN_OR_RAISE(auto arr_a, builder_a.Finish());
    ARROW_ASSIGN_OR_RAISE(auto arr_b, builder_b.Finish());
    auto batch = arrow::RecordBatch::Make(schema, 5, {arr_a, arr_b});
    ARROW_ASSIGN_OR_RAISE(
        auto reader, RecordBatchReader::Make({batch, batch, batch, batch, batch}));
    data_stream_ = std::make_unique<RecordBatchStream>(std::move(reader));
    return data_stream_->GetSchemaPayload();
  }

  // Serialize the next payload into write_buf_ and start writing it, or
  // finish the RPC when the stream is exhausted or an error occurred.
  void WriteNextPayload() {
    auto maybe_payload = NextPayload();
    if (!maybe_payload.ok()) {
      FinishWithError(maybe_payload.status());
      return;
    }
    FlightPayload payload = maybe_payload.MoveValueUnsafe();

    if (payload.ipc_message.metadata == nullptr) {
      Finish(::grpc::Status::OK);
      return;
    }

    auto maybe_buffers = payload.SerializeToBuffers();
    if (!maybe_buffers.ok()) {
      FinishWithError(maybe_buffers.status());
      return;
    }
    const auto buffers = maybe_buffers.MoveValueUnsafe();

    std::vector<::grpc::Slice> slices;
    slices.reserve(buffers.size());
    for (const auto& buf : buffers) {
      // Should we move this out of the internal files and expose as
      // utility to the users?
      auto maybe_slice = SliceFromBuffer(buf);
      if (!maybe_slice.ok()) {
        FinishWithError(maybe_slice.status());
        return;
      }
      slices.push_back(maybe_slice.MoveValueUnsafe());
    }
    // write_buf_ is a member because the buffer passed to StartWrite() is
    // still in use after this function returns.
    write_buf_ = ::grpc::ByteBuffer(slices.data(), slices.size());

    StartWrite(&write_buf_);
  }

  ::grpc::ByteBuffer request_buf_;
  ::grpc::ByteBuffer write_buf_;
  std::unique_ptr<RecordBatchStream> data_stream_;
};

// A listener that counts received batches, only for PoC purposes.
+class DoPutListener : public arrow::flight::FlightDataListener { + public: + arrow::Status OnSchemaDecoded(std::shared_ptr schema) override { + schema_ = std::move(schema); + return arrow::Status::OK(); + } + + arrow::Status OnNext(arrow::flight::FlightStreamChunk chunk) override { + if (chunk.data) { + ++batches_received_; + } + return arrow::Status::OK(); + } + + int batches_received() const { return batches_received_; } + + private: + std::shared_ptr schema_; + int batches_received_ = 0; +}; + +class DoPutReactor : public ::grpc::ServerGenericBidiReactor { + public: + DoPutReactor() : decoder_(std::make_shared()) { + StartRead(&request_buf_); + } + + void OnReadDone(bool ok) override { + // A read has completed. ok == false is treated as the end of the client's stream. + if (!ok) { + // End of stream, ack completion with an empty PutResult. + pb::PutResult pb_result; + ::grpc::Slice slice(pb_result.SerializeAsString()); + // Do not use a local variable for the ByteBuffer, as + // StartWrite requires the buffer to remain valid until OnWriteDone. + write_buf_ = ::grpc::ByteBuffer(&slice, 1); + StartWrite(&write_buf_); + return; + } + + // Extract Arrow buffers from the gRPC ByteBuffer, then feed it to + // the FlightMessageDecoder which fires Listener callbacks + // (OnSchemaDecoded / OnNext). + std::shared_ptr arrow_buf; + // TODO: What do we do with the WrapGrpcBuffer? This is internal and + // we don't want to expose it but it's quite complex to leave it to the + // user to implement. Similar to SliceFromBuffer, should we move this out + // of the serialization_internal file and expose as a utility to the user? + auto wrap_status = WrapGrpcBuffer(&request_buf_, &arrow_buf); + if (!wrap_status.ok()) { + Finish(::grpc::Status(::grpc::StatusCode::INTERNAL, + "Failed to wrap gRPC buffer: " + wrap_status.message())); + return; + } + + // Push the buffer into the decoder, which will fire the appropriate calls to the + // listener. A decoding error finishes the RPC with INTERNAL. 
+ auto decode_status = decoder_.Consume(std::move(arrow_buf)); + if (!decode_status.ok()) { + Finish(::grpc::Status(::grpc::StatusCode::INTERNAL, + "Failed to decode Arrow buffer: " + decode_status.message())); + return; + } + // Read next FlightData + StartRead(&request_buf_); + } + + void OnWriteDone(bool ok) override { + // The only write on DoPut is the PutResult ack started in OnReadDone at end of stream; finish the RPC once it completes. + if (!ok) { + Finish(::grpc::Status(::grpc::StatusCode::INTERNAL, "Write failed")); + return; + } + Finish(::grpc::Status::OK); + } + + void OnCancel() override { + // Client cancelled the RPC. Not handled here; this must be implemented beyond the PoC. + } + + void OnDone() override { delete this; } + + private: + arrow::flight::FlightMessageDecoder decoder_; + ::grpc::ByteBuffer request_buf_; + ::grpc::ByteBuffer write_buf_; +}; + +class FlightCallbackService : public ::grpc::CallbackGenericService { + public: + ::grpc::ServerGenericBidiReactor* CreateReactor( + ::grpc::GenericCallbackServerContext* ctx) override { + const std::string& method = ctx->method(); + if (method == "/arrow.flight.protocol.FlightService/DoGet") { + return new DoGetReactor(); + } + if (method == "/arrow.flight.protocol.FlightService/DoPut") { + return new DoPutReactor(); + } + // Reject unknown methods with a reactor that finishes immediately with UNIMPLEMENTED and deletes itself in OnDone(). + class Unimplemented : public ::grpc::ServerGenericBidiReactor { + public: + Unimplemented() { Finish(::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "")); } + void OnDone() override { delete this; } + }; + return new Unimplemented(); + } +}; + +} // namespace arrow::flight::transport::grpc diff --git a/cpp/src/arrow/flight/transport/grpc/serialization_internal.cc b/cpp/src/arrow/flight/transport/grpc/serialization_internal.cc index 0b8c90a08eb..42cd734f0d4 100644 --- a/cpp/src/arrow/flight/transport/grpc/serialization_internal.cc +++ b/cpp/src/arrow/flight/transport/grpc/serialization_internal.cc @@ -32,10 +32,6 @@ # pragma warning(disable : 4267) #endif -#include -#include -#include - #include #include #include @@ -49,9 +45,7 @@ #include 
"arrow/flight/serialization_internal.h" #include "arrow/flight/transport.h" #include "arrow/flight/transport/grpc/util_internal.h" -#include "arrow/ipc/message.h" #include "arrow/ipc/writer.h" -#include "arrow/util/bit_util.h" #include "arrow/util/logging_internal.h" namespace arrow { @@ -62,27 +56,11 @@ namespace grpc { namespace pb = arrow::flight::protocol; static constexpr int64_t kInt32Max = std::numeric_limits::max(); -using google::protobuf::internal::WireFormatLite; -using google::protobuf::io::ArrayOutputStream; -using google::protobuf::io::CodedInputStream; -using google::protobuf::io::CodedOutputStream; using ::grpc::ByteBuffer; namespace { -bool ReadBytesZeroCopy(const std::shared_ptr& source_data, - CodedInputStream* input, std::shared_ptr* out) { - uint32_t length; - if (!input->ReadVarint32(&length)) { - return false; - } - auto buf = - SliceBuffer(source_data, input->CurrentPosition(), static_cast(length)); - *out = buf; - return input->Skip(static_cast(length)); -} - // Internal wrapper for gRPC ByteBuffer so its memory can be exposed to Arrow // consumers with zero-copy class GrpcBuffer : public MutableBuffer { @@ -157,6 +135,8 @@ void ReleaseBuffer(void* buf_ptr) { delete reinterpret_cast*>(buf_ptr); } +} // namespace + // Initialize gRPC Slice from arrow Buffer arrow::Result<::grpc::Slice> SliceFromBuffer(const std::shared_ptr& buf) { // Allocate persistent shared_ptr to control Buffer lifetime @@ -176,146 +156,42 @@ arrow::Result<::grpc::Slice> SliceFromBuffer(const std::shared_ptr& buf) return slice; } -const uint8_t kPaddingBytes[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - -// Update the sizes of our Protobuf fields based on the given IPC payload. 
-::grpc::Status IpcMessageHeaderSize(const arrow::ipc::IpcPayload& ipc_msg, bool has_body, - size_t* header_size, int32_t* metadata_size) { - DCHECK_LE(ipc_msg.metadata->size(), kInt32Max); - *metadata_size = static_cast(ipc_msg.metadata->size()); - - // 1 byte for metadata tag - *header_size += 1 + WireFormatLite::LengthDelimitedSize(*metadata_size); - - // 2 bytes for body tag - if (has_body) { - // We write the body tag in the header but not the actual body data - *header_size += 2 + WireFormatLite::LengthDelimitedSize(ipc_msg.body_length) - - ipc_msg.body_length; - } - - return ::grpc::Status::OK; -} - -} // namespace - ::grpc::Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out, bool* own_buffer) { - // Size of the IPC body (protobuf: data_body) - size_t body_size = 0; - // Size of the Protobuf "header" (everything except for the body) - size_t header_size = 0; - // Size of IPC header metadata (protobuf: data_header) - int32_t metadata_size = 0; - - // Write the descriptor if present - int32_t descriptor_size = 0; - if (msg.descriptor != nullptr) { - DCHECK_LE(msg.descriptor->size(), kInt32Max); - descriptor_size = static_cast(msg.descriptor->size()); - header_size += 1 + WireFormatLite::LengthDelimitedSize(descriptor_size); - } - - // App metadata tag if appropriate - int32_t app_metadata_size = 0; - if (msg.app_metadata && msg.app_metadata->size() > 0) { - DCHECK_LE(msg.app_metadata->size(), kInt32Max); - app_metadata_size = static_cast(msg.app_metadata->size()); - header_size += 1 + WireFormatLite::LengthDelimitedSize(app_metadata_size); - } - - const arrow::ipc::IpcPayload& ipc_msg = msg.ipc_message; - // No data in this payload (metadata-only). - bool has_ipc = ipc_msg.type != ipc::MessageType::NONE; - bool has_body = has_ipc ? 
ipc::Message::HasBody(ipc_msg.type) : false; - - if (has_ipc) { - DCHECK(has_body || ipc_msg.body_length == 0); - GRPC_RETURN_NOT_GRPC_OK( - IpcMessageHeaderSize(ipc_msg, has_body, &header_size, &metadata_size)); - body_size = static_cast(ipc_msg.body_length); - } - // TODO(wesm): messages over 2GB unlikely to be yet supported // Validated in WritePayload since returning error here causes gRPC to fail an assertion - DCHECK_LE(body_size, kInt32Max); - - // Allocate and initialize slices - std::vector<::grpc::Slice> slices; - slices.emplace_back(header_size); - - // Force the header_stream to be destructed, which actually flushes - // the data into the slice. - { - ArrayOutputStream header_writer(const_cast(slices[0].begin()), - static_cast(slices[0].size())); - CodedOutputStream header_stream(&header_writer); - - // Write descriptor - if (msg.descriptor != nullptr) { - WireFormatLite::WriteTag(pb::FlightData::kFlightDescriptorFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); - header_stream.WriteVarint32(descriptor_size); - header_stream.WriteRawMaybeAliased(msg.descriptor->data(), - static_cast(msg.descriptor->size())); - } + DCHECK_LE(msg.ipc_message.body_length, kInt32Max); - // Write header - if (has_ipc) { - WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); - header_stream.WriteVarint32(metadata_size); - header_stream.WriteRawMaybeAliased(ipc_msg.metadata->data(), - static_cast(ipc_msg.metadata->size())); - } - - // Write app metadata - if (app_metadata_size > 0) { - WireFormatLite::WriteTag(pb::FlightData::kAppMetadataFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); - header_stream.WriteVarint32(app_metadata_size); - header_stream.WriteRawMaybeAliased(msg.app_metadata->data(), - static_cast(msg.app_metadata->size())); - } + // Retrieve BufferVector from the FlightPayload's IPC message. 
+ auto buffers_result = msg.SerializeToBuffers(); + if (!buffers_result.ok()) { + return ToGrpcStatus(buffers_result.status()); + } - if (has_body) { - // Write body tag - WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); - header_stream.WriteVarint32(static_cast(body_size)); - - // Enqueue body buffers for writing, without copying - for (const auto& buffer : ipc_msg.body_buffers) { - // Buffer may be null when the row length is zero, or when all - // entries are invalid. - if (!buffer || buffer->size() == 0) continue; - - ::grpc::Slice slice; - auto status = SliceFromBuffer(buffer).Value(&slice); - if (ARROW_PREDICT_FALSE(!status.ok())) { - // This will likely lead to abort as gRPC cannot recover from an error here - return ToGrpcStatus(status); - } - slices.push_back(std::move(slice)); - - // Write padding if not multiple of 8 - const auto remainder = static_cast( - bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size()); - if (remainder) { - slices.emplace_back(kPaddingBytes, remainder); - } - } + std::vector<::grpc::Slice> slices; + slices.reserve(buffers_result->size()); + for (const auto& buffer : *buffers_result) { + ::grpc::Slice slice; + auto status = SliceFromBuffer(buffer).Value(&slice); + if (ARROW_PREDICT_FALSE(!status.ok())) { + // This will likely lead to abort as gRPC cannot recover from an error here + return ToGrpcStatus(status); } - - DCHECK_EQ(static_cast(header_size), header_stream.ByteCount()); + slices.push_back(std::move(slice)); } - // Hand off the slices to the returned ByteBuffer - *out = ::grpc::ByteBuffer(slices.data(), slices.size()); + *out = ByteBuffer(slices.data(), slices.size()); *own_buffer = true; return ::grpc::Status::OK; } +arrow::Status WrapGrpcBuffer(::grpc::ByteBuffer* grpc_buf, + std::shared_ptr* out) { + ARROW_RETURN_NOT_OK(GrpcBuffer::Wrap(grpc_buf, out)); + grpc_buf->Clear(); + return arrow::Status::OK(); +} + // Read 
internal::FlightData from grpc::ByteBuffer containing FlightData // protobuf without copying ::grpc::Status FlightDataDeserialize(ByteBuffer* buffer, @@ -324,81 +200,16 @@ ::grpc::Status FlightDataDeserialize(ByteBuffer* buffer, return {::grpc::StatusCode::INTERNAL, "No payload"}; } - // Reset fields in case the caller reuses a single allocation - out->descriptor = nullptr; - out->app_metadata = nullptr; - out->metadata = nullptr; - out->body = nullptr; - std::shared_ptr wrapped_buffer; GRPC_RETURN_NOT_OK(GrpcBuffer::Wrap(buffer, &wrapped_buffer)); - - auto buffer_length = static_cast(wrapped_buffer->size()); - CodedInputStream pb_stream(wrapped_buffer->data(), buffer_length); - - pb_stream.SetTotalBytesLimit(buffer_length); - - // This is the bytes remaining when using CodedInputStream like this - while (pb_stream.BytesUntilTotalBytesLimit()) { - const uint32_t tag = pb_stream.ReadTag(); - const int field_number = WireFormatLite::GetTagFieldNumber(tag); - switch (field_number) { - case pb::FlightData::kFlightDescriptorFieldNumber: { - pb::FlightDescriptor pb_descriptor; - uint32_t length; - if (!pb_stream.ReadVarint32(&length)) { - return {::grpc::StatusCode::INTERNAL, - "Unable to parse length of FlightDescriptor"}; - } - // Can't use ParseFromCodedStream as this reads the entire - // rest of the stream into the descriptor command field. 
- std::string buffer; - pb_stream.ReadString(&buffer, length); - if (!pb_descriptor.ParseFromString(buffer)) { - return {::grpc::StatusCode::INTERNAL, "Unable to parse FlightDescriptor"}; - } - arrow::flight::FlightDescriptor descriptor; - GRPC_RETURN_NOT_OK( - arrow::flight::internal::FromProto(pb_descriptor, &descriptor)); - out->descriptor = std::make_unique(descriptor); - } break; - case pb::FlightData::kDataHeaderFieldNumber: { - if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) { - return {::grpc::StatusCode::INTERNAL, "Unable to read FlightData metadata"}; - } - } break; - case pb::FlightData::kAppMetadataFieldNumber: { - if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->app_metadata)) { - return {::grpc::StatusCode::INTERNAL, - "Unable to read FlightData application metadata"}; - } - } break; - case pb::FlightData::kDataBodyFieldNumber: { - if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) { - return {::grpc::StatusCode::INTERNAL, "Unable to read FlightData body"}; - } - } break; - default: { - // Unknown field. We should skip it for compatibility. - if (!WireFormatLite::SkipField(&pb_stream, tag)) { - return {::grpc::StatusCode::INTERNAL, - "Could not skip unknown field tag in FlightData"}; - } - break; - } - } - } + // Release gRPC memory now that Arrow Buffer holds its own reference. buffer->Clear(); - // TODO(wesm): Where and when should we verify that the FlightData is not - // malformed? - - // Set the default value for an unspecified FlightData body. The other - // fields can be null if they're unspecified. 
- if (out->body == nullptr) { - out->body = std::make_shared(nullptr, 0); + auto result = arrow::flight::internal::DeserializeFlightData(wrapped_buffer); + if (!result.ok()) { + return ToGrpcStatus(result.status()); } - + *out = result.MoveValueUnsafe(); return ::grpc::Status::OK; } diff --git a/cpp/src/arrow/flight/transport/grpc/serialization_internal.h b/cpp/src/arrow/flight/transport/grpc/serialization_internal.h index 5c347fd4f81..0c7ca5867e6 100644 --- a/cpp/src/arrow/flight/transport/grpc/serialization_internal.h +++ b/cpp/src/arrow/flight/transport/grpc/serialization_internal.h @@ -22,6 +22,7 @@ #include +#include "arrow/buffer.h" #include "arrow/flight/protocol_internal.h" #include "arrow/flight/transport/grpc/protocol_grpc_internal.h" #include "arrow/flight/type_fwd.h" @@ -34,6 +35,13 @@ namespace grpc { namespace pb = arrow::flight::protocol; +/// Convert an Arrow Buffer to a gRPC Slice. +arrow::Result<::grpc::Slice> SliceFromBuffer(const std::shared_ptr& buf); + +/// Wrap a gRPC ByteBuffer as a zero-copy Arrow Buffer (and clear the ByteBuffer). +arrow::Status WrapGrpcBuffer(::grpc::ByteBuffer* grpc_buf, + std::shared_ptr* out); + /// Write Flight message on gRPC stream with zero-copy optimizations. 
// Returns Invalid if the payload is ill-formed // Returns true if the payload was written, false if it was not diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index 8166513d4e3..a597c9d194e 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -886,6 +886,10 @@ Status FlightPayload::Validate() const { return Status::OK(); } +arrow::Result FlightPayload::SerializeToBuffers() const { + return internal::SerializePayloadToBuffers(*this); +} + std::string ActionType::ToString() const { return arrow::internal::JoinToString(""); diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index d498ac67f7a..e4ac975b16b 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -904,6 +904,13 @@ struct ARROW_FLIGHT_EXPORT FlightPayload { /// \brief Check that the payload can be written to the wire. Status Validate() const; + + /// \brief Serialize this payload to a vector of buffers. + /// + /// The first buffer contains the protobuf wire format header. Subsequent + /// buffers are zero-copy references to the IPC body buffers, with padding + /// buffers inserted as needed for 8-byte alignment. + arrow::Result SerializeToBuffers() const; }; // A wrapper around arrow.flight.protocol.PutResult is not defined