diff --git a/cpp/src/arrow/dataset/file_orc.cc b/cpp/src/arrow/dataset/file_orc.cc
index 1393df57f9d..b1a435da6f6 100644
--- a/cpp/src/arrow/dataset/file_orc.cc
+++ b/cpp/src/arrow/dataset/file_orc.cc
@@ -69,13 +69,18 @@ class OrcScanTask {
     struct Impl {
+      /// \brief Open `source` and build an iterator over its record batches.
+      ///
+      /// An empty `stripe_ids` reads the whole file; otherwise only the listed
+      /// stripes are read, in the order given.
       static Result<RecordBatchIterator> Make(const FileSource& source,
                                               const FileFormat& format,
-                                              const ScanOptions& scan_options) {
+                                              const ScanOptions& scan_options,
+                                              const std::vector<int>& stripe_ids) {
         ARROW_ASSIGN_OR_RAISE(
             auto reader,
             OpenORCReader(source, std::make_shared<ScanOptions>(scan_options)));
 
         auto materialized_fields = scan_options.MaterializedFields();
         // filter out virtual columns
         std::vector<std::string> included_fields;
         ARROW_ASSIGN_OR_RAISE(auto schema, reader->ReadSchema());
         for (const auto& ref : materialized_fields) {
@@ -85,26 +90,112 @@ class OrcScanTask {
           included_fields.push_back(schema->field(match.indices()[0])->name());
         }
 
-        std::shared_ptr<RecordBatchReader> record_batch_reader;
-        ARROW_ASSIGN_OR_RAISE(
-            record_batch_reader,
-            reader->GetRecordBatchReader(scan_options.batch_size, included_fields));
+        if (stripe_ids.empty()) {
+          // No stripe selection: stream the whole file through one reader.
+          std::shared_ptr<RecordBatchReader> record_batch_reader;
+          ARROW_ASSIGN_OR_RAISE(
+              record_batch_reader,
+              reader->GetRecordBatchReader(scan_options.batch_size, included_fields));
+
+          return RecordBatchIterator(Impl{std::move(record_batch_reader)});
+        }
 
-        return RecordBatchIterator(Impl{std::move(record_batch_reader)});
+        // The per-stripe reader selects columns by index rather than by name.
+        // NOTE(review): confirm that NextStripeReader interprets these as
+        // top-level schema field indices.
+        std::vector<int> included_indices;
+        included_indices.reserve(included_fields.size());
+        for (const auto& field_name : included_fields) {
+          const int idx = schema->GetFieldIndex(field_name);
+          if (idx >= 0) {
+            included_indices.push_back(idx);
+          }
+        }
+
+        return RecordBatchIterator(
+            Impl{std::move(reader), stripe_ids, std::move(included_indices),
+                 scan_options.batch_size});
       }
 
+      /// Whole-file mode: every batch comes from `reader`.
+      explicit Impl(std::shared_ptr<RecordBatchReader> reader)
+          : record_batch_reader_(std::move(reader)) {}
+
+      /// Stripe mode: batches are read from `stripe_ids`, one stripe at a time.
+      Impl(std::unique_ptr<arrow::adapters::orc::ORCFileReader> reader,
+           std::vector<int> stripe_ids, std::vector<int> included_indices,
+           int64_t batch_size)
+          : orc_reader_(std::move(reader)),
+            stripe_ids_(std::move(stripe_ids)),
+            included_indices_(std::move(included_indices)),
+            batch_size_(batch_size),
+            stripe_mode_(true) {}
+
+      /// Return the next batch, or nullptr once the scan is exhausted.
       Result<std::shared_ptr<RecordBatch>> Next() {
-        std::shared_ptr<RecordBatch> batch;
-        RETURN_NOT_OK(record_batch_reader_->ReadNext(&batch));
-        return batch;
+        if (!stripe_mode_) {
+          std::shared_ptr<RecordBatch> batch;
+          RETURN_NOT_OK(record_batch_reader_->ReadNext(&batch));
+          return batch;
+        }
+
+        while (true) {
+          if (record_batch_reader_) {
+            std::shared_ptr<RecordBatch> batch;
+            RETURN_NOT_OK(record_batch_reader_->ReadNext(&batch));
+            if (batch) {
+              return batch;
+            }
+            // The current stripe is drained; move on to the next selected one.
+            record_batch_reader_.reset();
+            ++current_stripe_idx_;
+          }
+
+          if (current_stripe_idx_ >= stripe_ids_.size()) {
+            return nullptr;
+          }
+
+          // Re-check the ID against the file actually opened before using it
+          // to look up stripe metadata.
+          const int64_t stripe_id = stripe_ids_[current_stripe_idx_];
+          const int64_t num_stripes = orc_reader_->NumberOfStripes();
+          if (stripe_id < 0 || stripe_id >= num_stripes) {
+            return Status::IndexError("Stripe ID ", stripe_id,
+                                      " is out of range. File has ", num_stripes,
+                                      " stripe(s)");
+          }
+          const auto stripe_info = orc_reader_->GetStripeInformation(stripe_id);
+          RETURN_NOT_OK(orc_reader_->Seek(stripe_info.first_row_id));
+          ARROW_ASSIGN_OR_RAISE(
+              record_batch_reader_,
+              orc_reader_->NextStripeReader(batch_size_, included_indices_));
+          if (!record_batch_reader_) {
+            // No reader for this stripe: skip it, otherwise the loop would
+            // retry the same stripe forever.
+            ++current_stripe_idx_;
+          }
+        }
       }
 
       std::shared_ptr<RecordBatchReader> record_batch_reader_;
+      std::unique_ptr<arrow::adapters::orc::ORCFileReader> orc_reader_;
+      std::vector<int> stripe_ids_;
+      std::vector<int> included_indices_;
+      int64_t batch_size_ = 0;
+      size_t current_stripe_idx_ = 0;
+      bool stripe_mode_ = false;
     };
 
+    // Only an OrcFileFragment can carry a stripe selection.
+    std::vector<int> stripe_ids;
+    const auto* orc_fragment = dynamic_cast<const OrcFileFragment*>(fragment_.get());
+    if (orc_fragment != nullptr) {
+      stripe_ids = orc_fragment->stripe_ids();
+    }
+
     return Impl::Make(fragment_->source(),
                       *checked_pointer_cast<FileFragment>(fragment_)->format(),
-                      *options_);
+                      *options_, stripe_ids);
   }
 
  private:
@@ -208,6 +299,22 @@ Future<std::optional<int64_t>> OrcFileFormat::CountRows(
   return DeferNotOk(options->io_context.executor()->Submit(
       [self, file]() -> Result<std::optional<int64_t>> {
         ARROW_ASSIGN_OR_RAISE(auto reader, OpenORCReader(file->source()));
+        const auto* orc_fragment = dynamic_cast<const OrcFileFragment*>(file.get());
+        if (orc_fragment != nullptr && !orc_fragment->stripe_ids().empty()) {
+          // Only the selected stripes contribute to the count.
+          const int64_t num_stripes = reader->NumberOfStripes();
+          int64_t count = 0;
+          for (const int stripe_id : orc_fragment->stripe_ids()) {
+            // Guard the lookup: the IDs are not guaranteed to match this file.
+            if (stripe_id < 0 || stripe_id >= num_stripes) {
+              return Status::IndexError("Stripe ID ", stripe_id,
+                                        " is out of range. File has ", num_stripes,
+                                        " stripe(s)");
+            }
+            count += reader->GetStripeInformation(stripe_id).num_rows;
+          }
+          return count;
+        }
         return reader->NumberOfRows();
       }));
 }
@@ -229,5 +336,56 @@ Result<std::shared_ptr<FileWriter>> OrcFileFormat::MakeWriter(
   return Status::NotImplemented("ORC writer not yet implemented.");
 }
 
+// OrcFileFragment
+
+OrcFileFragment::OrcFileFragment(FileSource source, std::shared_ptr<FileFormat> format,
+                                 compute::Expression partition_expression,
+                                 std::shared_ptr<Schema> physical_schema,
+                                 std::optional<std::vector<int>> stripe_ids)
+    : FileFragment(std::move(source), std::move(format),
+                   std::move(partition_expression), std::move(physical_schema)),
+      stripe_ids_(std::move(stripe_ids)) {}
+
+Result<std::shared_ptr<Fragment>> OrcFileFragment::Subset(std::vector<int> stripe_ids) {
+  // Go through the validating factory so that out-of-range IDs are rejected
+  // here instead of surfacing later, in the middle of a scan.
+  ARROW_ASSIGN_OR_RAISE(
+      auto subset, checked_pointer_cast<OrcFileFormat>(format_)->MakeFragment(
+                       source_, partition_expression(), physical_schema_,
+                       std::move(stripe_ids)));
+  return std::shared_ptr<Fragment>(std::move(subset));
+}
+
+Result<std::shared_ptr<FileFragment>> OrcFileFormat::MakeFragment(
+    FileSource source, compute::Expression partition_expression,
+    std::shared_ptr<Schema> physical_schema) {
+  return std::shared_ptr<FileFragment>(
+      new OrcFileFragment(std::move(source), shared_from_this(),
+                          std::move(partition_expression),
+                          std::move(physical_schema), std::nullopt));
+}
+
+Result<std::shared_ptr<OrcFileFragment>> OrcFileFormat::MakeFragment(
+    FileSource source, compute::Expression partition_expression,
+    std::shared_ptr<Schema> physical_schema, std::vector<int> stripe_ids) {
+  if (!stripe_ids.empty()) {
+    // Fail early on IDs the file cannot satisfy.
+    ARROW_ASSIGN_OR_RAISE(auto reader, OpenORCReader(source));
+    const int64_t num_stripes = reader->NumberOfStripes();
+    for (const int stripe_id : stripe_ids) {
+      if (stripe_id < 0 || stripe_id >= num_stripes) {
+        return Status::IndexError("Stripe ID ", stripe_id,
+                                  " is out of range. File has ", num_stripes,
+                                  " stripe(s), valid range is [0, ",
+                                  num_stripes - 1, "]");
+      }
+    }
+  }
+  return std::shared_ptr<OrcFileFragment>(
+      new OrcFileFragment(std::move(source), shared_from_this(),
+                          std::move(partition_expression),
+                          std::move(physical_schema), std::move(stripe_ids)));
+}
+
 }  // namespace dataset
 }  // namespace arrow
diff --git a/cpp/src/arrow/dataset/file_orc.h b/cpp/src/arrow/dataset/file_orc.h
index 5bfefd1e02b..2105da6a738 100644
--- a/cpp/src/arrow/dataset/file_orc.h
+++ b/cpp/src/arrow/dataset/file_orc.h
@@ -20,7 +20,9 @@
 #pragma once
 
 #include <memory>
+#include <optional>
 #include <string>
+#include <vector>
 
 #include "arrow/dataset/file_base.h"
 #include "arrow/dataset/type_fwd.h"
@@ -37,6 +39,8 @@ namespace dataset {
 
 constexpr char kOrcTypeName[] = "orc";
 
+class OrcFileFragment;
+
 /// \brief A FileFormat implementation that reads from and writes to ORC files
 class ARROW_DS_EXPORT OrcFileFormat : public FileFormat {
  public:
@@ -67,6 +71,52 @@ class ARROW_DS_EXPORT OrcFileFormat : public FileFormat {
       fs::FileLocator destination_locator) const override;
 
   std::shared_ptr<FileWriteOptions> DefaultWriteOptions() override;
+
+  using FileFormat::MakeFragment;
+
+  /// \brief Create a fragment that scans every stripe of `source`.
+  Result<std::shared_ptr<FileFragment>> MakeFragment(
+      FileSource source, compute::Expression partition_expression,
+      std::shared_ptr<Schema> physical_schema) override;
+
+  /// \brief Create a fragment restricted to the given stripes of `source`.
+  ///
+  /// An empty `stripe_ids` selects all stripes. Returns IndexError if any ID is
+  /// outside [0, number of stripes in the file).
+  Result<std::shared_ptr<OrcFileFragment>> MakeFragment(
+      FileSource source, compute::Expression partition_expression,
+      std::shared_ptr<Schema> physical_schema, std::vector<int> stripe_ids);
+};
+
+/// \brief A FileFragment with ORC-specific logic for stripe-level subsetting.
+///
+/// OrcFileFragment provides the ability to scan ORC files at stripe granularity,
+/// enabling parallel processing of sub-file splits. The caller can provide an
+/// optional list of selected stripe IDs to limit the scan to specific stripes.
+class ARROW_DS_EXPORT OrcFileFragment : public FileFragment {
+ public:
+  /// \brief Return the stripe IDs selected by this fragment.
+  /// Empty vector means all stripes.
+  const std::vector<int>& stripe_ids() const {
+    static const std::vector<int> kAllStripes;
+    return stripe_ids_.has_value() ? *stripe_ids_ : kAllStripes;
+  }
+
+  /// \brief Return fragment which selects a subset of this fragment's stripes.
+  ///
+  /// An empty `stripe_ids` selects all stripes. Returns IndexError if any ID is
+  /// out of range for the file.
+  Result<std::shared_ptr<Fragment>> Subset(std::vector<int> stripe_ids);
+
+ private:
+  OrcFileFragment(FileSource source, std::shared_ptr<FileFormat> format,
+                  compute::Expression partition_expression,
+                  std::shared_ptr<Schema> physical_schema,
+                  std::optional<std::vector<int>> stripe_ids);
+
+  std::optional<std::vector<int>> stripe_ids_;
+
+  friend class OrcFileFormat;
 };
 
 /// @}
diff --git a/cpp/src/arrow/dataset/file_orc_test.cc b/cpp/src/arrow/dataset/file_orc_test.cc
index 17be015de51..0725bd84710 100644
--- a/cpp/src/arrow/dataset/file_orc_test.cc
+++ b/cpp/src/arrow/dataset/file_orc_test.cc
@@ -25,6 +25,7 @@
 #include "arrow/dataset/discovery.h"
 #include "arrow/dataset/file_base.h"
 #include "arrow/dataset/partition.h"
+#include "arrow/dataset/scanner.h"
 #include "arrow/dataset/test_util_internal.h"
 #include "arrow/io/memory.h"
 #include "arrow/record_batch.h"
@@ -54,8 +55,6 @@ class OrcFormatHelper {
 
 class TestOrcFileFormat : public FileFormatFixtureMixin<OrcFormatHelper> {};
 
-// TEST_F(TestOrcFileFormat, WriteRecordBatchReader) { TestWrite(); }
-
 TEST_F(TestOrcFileFormat, InspectFailureWithRelevantError) {
   TestInspectFailureWithRelevantError(StatusCode::IOError, "ORC");
 }
@@ -92,5 +91,182 @@ INSTANTIATE_TEST_SUITE_P(TestScan, TestOrcFileFormatScan,
                          ::testing::ValuesIn(TestFormatParams::Values()),
                          TestFormatParams::ToTestNameString);
 
+class TestOrcFileFragment : public ::testing::Test {
+ public:
+  // The fixture file holds kNumBatches batches of kBatchSize rows each.
+  static constexpr int kNumBatches = 4;
+  static constexpr int kBatchSize = 512;
+  static constexpr int64_t kTotalRows = int64_t{kNumBatches} * kBatchSize;
+
+  void SetUp() override {
+    format_ = std::make_shared<OrcFileFormat>();
+    opts_ = std::make_shared<ScanOptions>();
+    SetSchema({field("f64", float64())});
+  }
+
+  /// Set the dataset schema and regenerate the in-memory ORC file to match it.
+  void SetSchema(std::vector<std::shared_ptr<Field>> fields) {
+    opts_->dataset_schema = schema(std::move(fields));
+    ASSERT_OK_AND_ASSIGN(input_, WriteMultiStripeBuffer(kNumBatches, kBatchSize));
+  }
+
+  /// Write `num_batches` zero-filled batches of `batch_size` rows with a small
+  /// stripe size, intended to yield several stripes (the tests assert this).
+  Result<std::shared_ptr<Buffer>> WriteMultiStripeBuffer(int num_batches,
+                                                         int batch_size) {
+    adapters::orc::WriteOptions write_opts;
+    write_opts.stripe_size = 1024;
+
+    ARROW_ASSIGN_OR_RAISE(auto sink, io::BufferOutputStream::Create());
+    ARROW_ASSIGN_OR_RAISE(auto writer,
+                          adapters::orc::ORCFileWriter::Open(sink.get(), write_opts));
+    for (int i = 0; i < num_batches; i++) {
+      auto batch =
+          ConstantArrayGenerator::Zeroes(batch_size, opts_->dataset_schema);
+      RETURN_NOT_OK(writer->Write(*batch));
+    }
+    RETURN_NOT_OK(writer->Close());
+    return sink->Finish();
+  }
+
+  Result<std::shared_ptr<OrcFileFragment>> MakeFragment(FileSource source) {
+    ARROW_ASSIGN_OR_RAISE(auto fragment,
+                          format_->MakeFragment(std::move(source), literal(true),
+                                                opts_->dataset_schema));
+    return std::dynamic_pointer_cast<OrcFileFragment>(fragment);
+  }
+
+  Result<std::shared_ptr<OrcFileFragment>> MakeFragment(
+      FileSource source, std::vector<int> stripe_ids) {
+    return format_->MakeFragment(std::move(source), literal(true),
+                                 opts_->dataset_schema, std::move(stripe_ids));
+  }
+
+  /// Scan `fragment` into a table and check the number of rows it produced.
+  void AssertScanEquals(std::shared_ptr<Fragment> fragment, int64_t expected_rows) {
+    auto dataset = std::make_shared<FragmentDataset>(
+        opts_->dataset_schema, FragmentVector{std::move(fragment)});
+    ScannerBuilder builder(dataset, opts_);
+    ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());
+    ASSERT_OK_AND_ASSIGN(auto table, scanner->ToTable());
+    ASSERT_EQ(table->num_rows(), expected_rows);
+  }
+
+ protected:
+  std::shared_ptr<Buffer> input_;
+  std::shared_ptr<ScanOptions> opts_;
+  std::shared_ptr<OrcFileFormat> format_;
+};
+
+TEST_F(TestOrcFileFragment, Basics) {
+  auto source = FileSource(input_);
+  ASSERT_OK_AND_ASSIGN(auto fragment, MakeFragment(source));
+  ASSERT_NE(fragment, nullptr);
+  ASSERT_TRUE(fragment->stripe_ids().empty());
+}
+
+TEST_F(TestOrcFileFragment, MakeFragmentWithStripeIds) {
+  auto source = FileSource(input_);
+  std::vector<int> stripe_ids = {0, 1};
+  ASSERT_OK_AND_ASSIGN(auto fragment, MakeFragment(source, stripe_ids));
+  ASSERT_NE(fragment, nullptr);
+  ASSERT_EQ(fragment->stripe_ids(), stripe_ids);
+}
+
+TEST_F(TestOrcFileFragment, Subset) {
+  auto source = FileSource(input_);
+  ASSERT_OK_AND_ASSIGN(auto fragment, MakeFragment(source));
+
+  std::vector<int> stripe_ids = {0};
+  ASSERT_OK_AND_ASSIGN(auto subset_fragment, fragment->Subset(stripe_ids));
+  ASSERT_NE(subset_fragment, nullptr);
+
+  auto* orc_subset = dynamic_cast<OrcFileFragment*>(subset_fragment.get());
+  ASSERT_NE(orc_subset, nullptr);
+  ASSERT_EQ(orc_subset->stripe_ids(), stripe_ids);
+}
+
+TEST_F(TestOrcFileFragment, SubsetInvalidStripeId) {
+  auto source = FileSource(input_);
+  ASSERT_OK_AND_ASSIGN(auto fragment, MakeFragment(source));
+  ASSERT_RAISES(IndexError, fragment->Subset({-1}));
+  ASSERT_RAISES(IndexError, fragment->Subset({9999}));
+}
+
+TEST_F(TestOrcFileFragment, ScanSubset) {
+  auto source = FileSource(input_);
+
+  ASSERT_OK_AND_ASSIGN(auto reader,
+                       adapters::orc::ORCFileReader::Open(
+                           std::make_shared<io::BufferReader>(input_),
+                           default_memory_pool()));
+  int64_t num_stripes = reader->NumberOfStripes();
+  ASSERT_GT(num_stripes, 1) << "Test file should have multiple stripes";
+
+  ASSERT_OK_AND_ASSIGN(auto full_fragment, MakeFragment(source));
+  AssertScanEquals(full_fragment, kTotalRows);
+
+  std::vector<int> first_stripe = {0};
+  ASSERT_OK_AND_ASSIGN(auto subset_fragment, MakeFragment(source, first_stripe));
+
+  auto stripe_info = reader->GetStripeInformation(0);
+  int64_t expected_rows = stripe_info.num_rows;
+  ASSERT_LT(expected_rows, kTotalRows);
+  AssertScanEquals(subset_fragment, expected_rows);
+
+  ASSERT_OK_AND_ASSIGN(auto subset_via_subset, full_fragment->Subset(first_stripe));
+  AssertScanEquals(subset_via_subset, expected_rows);
+}
+
+TEST_F(TestOrcFileFragment, InvalidStripeIdOutOfRange) {
+  auto source = FileSource(input_);
+  ASSERT_OK_AND_ASSIGN(auto reader,
+                       adapters::orc::ORCFileReader::Open(
+                           std::make_shared<io::BufferReader>(input_),
+                           default_memory_pool()));
+  int64_t num_stripes = reader->NumberOfStripes();
+
+  std::vector<int> invalid_ids = {static_cast<int>(num_stripes)};
+  ASSERT_RAISES(IndexError, MakeFragment(source, invalid_ids));
+
+  std::vector<int> very_invalid_ids = {9999};
+  ASSERT_RAISES(IndexError, MakeFragment(source, very_invalid_ids));
+}
+
+TEST_F(TestOrcFileFragment, InvalidStripeIdNegative) {
+  auto source = FileSource(input_);
+  std::vector<int> negative_ids = {-1};
+  ASSERT_RAISES(IndexError, MakeFragment(source, negative_ids));
+}
+
+TEST_F(TestOrcFileFragment, CountRowsWithStripeSubset) {
+  auto source = FileSource(input_);
+  ASSERT_OK_AND_ASSIGN(auto reader,
+                       adapters::orc::ORCFileReader::Open(
+                           std::make_shared<io::BufferReader>(input_),
+                           default_memory_pool()));
+  int64_t num_stripes = reader->NumberOfStripes();
+  ASSERT_GT(num_stripes, 1) << "Test file should have multiple stripes";
+
+  std::vector<int> first_stripe = {0};
+  ASSERT_OK_AND_ASSIGN(auto fragment, MakeFragment(source, first_stripe));
+
+  auto stripe_info = reader->GetStripeInformation(0);
+  int64_t expected_rows = stripe_info.num_rows;
+
+  auto count_result = format_->CountRows(fragment, literal(true), opts_);
+  ASSERT_OK_AND_ASSIGN(auto count, count_result.result());
+  ASSERT_TRUE(count.has_value());
+  ASSERT_EQ(count.value(), expected_rows);
+
+  ASSERT_OK_AND_ASSIGN(auto full_fragment, MakeFragment(source));
+  auto full_count_result = format_->CountRows(full_fragment, literal(true), opts_);
+  ASSERT_OK_AND_ASSIGN(auto full_count, full_count_result.result());
+  ASSERT_TRUE(full_count.has_value());
+  ASSERT_EQ(full_count.value(), kTotalRows);
+
+  ASSERT_LT(count.value(), full_count.value());
+}
+
 }  // namespace dataset
 }  // namespace arrow