Skip to content

Commit 2fcc3ec

Browse files
rok, jorisvandenbossche, pitrou
authored
GH-38007: [C++] Add VariableShapeTensor implementation (#38008)
### Rationale for this change

We want to add a VariableShapeTensor extension type definition for arrays containing tensors with variable shapes.

### What changes are included in this PR?

This adds a C++ implementation.

### Are these changes tested?

Yes.

### Are there any user-facing changes?

This adds a new extension type to C++.

* Closes: #38007
* GitHub Issue: #38007

Lead-authored-by: Rok Mihevc <rok@mihevc.org>
Co-authored-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
Co-authored-by: Antoine Pitrou <pitrou@free.fr>
Signed-off-by: Rok Mihevc <rok@mihevc.org>
1 parent f3f1eb0 commit 2fcc3ec

13 files changed

Lines changed: 944 additions & 141 deletions

cpp/src/arrow/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,6 +995,8 @@ if(ARROW_JSON)
995995
arrow_add_object_library(ARROW_JSON
996996
extension/fixed_shape_tensor.cc
997997
extension/opaque.cc
998+
extension/tensor_internal.cc
999+
extension/variable_shape_tensor.cc
9981000
json/options.cc
9991001
json/chunked_builder.cc
10001002
json/chunker.cc

cpp/src/arrow/extension/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
set(CANONICAL_EXTENSION_TESTS bool8_test.cc json_test.cc uuid_test.cc)
1919

2020
if(ARROW_JSON)
21-
list(APPEND CANONICAL_EXTENSION_TESTS fixed_shape_tensor_test.cc opaque_test.cc)
21+
list(APPEND CANONICAL_EXTENSION_TESTS tensor_extension_array_test.cc opaque_test.cc)
2222
endif()
2323

2424
add_arrow_test(test

cpp/src/arrow/extension/fixed_shape_tensor.cc

Lines changed: 19 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
#include "arrow/array/array_primitive.h"
2727
#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep
2828
#include "arrow/tensor.h"
29-
#include "arrow/util/int_util_overflow.h"
3029
#include "arrow/util/logging_internal.h"
3130
#include "arrow/util/print_internal.h"
3231
#include "arrow/util/sort_internal.h"
@@ -37,71 +36,18 @@
3736

3837
namespace rj = arrow::rapidjson;
3938

40-
namespace arrow {
41-
42-
namespace extension {
43-
44-
namespace {
45-
46-
Status ComputeStrides(const FixedWidthType& type, const std::vector<int64_t>& shape,
47-
const std::vector<int64_t>& permutation,
48-
std::vector<int64_t>* strides) {
49-
if (permutation.empty()) {
50-
return internal::ComputeRowMajorStrides(type, shape, strides);
51-
}
52-
53-
const int byte_width = type.byte_width();
54-
55-
int64_t remaining = 0;
56-
if (!shape.empty() && shape.front() > 0) {
57-
remaining = byte_width;
58-
for (auto i : permutation) {
59-
if (i > 0) {
60-
if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) {
61-
return Status::Invalid(
62-
"Strides computed from shape would not fit in 64-bit integer");
63-
}
64-
}
65-
}
66-
}
67-
68-
if (remaining == 0) {
69-
strides->assign(shape.size(), byte_width);
70-
return Status::OK();
71-
}
72-
73-
strides->push_back(remaining);
74-
for (auto i : permutation) {
75-
if (i > 0) {
76-
remaining /= shape[i];
77-
strides->push_back(remaining);
78-
}
79-
}
80-
internal::Permute(permutation, strides);
81-
82-
return Status::OK();
83-
}
84-
85-
} // namespace
39+
namespace arrow::extension {
8640

8741
bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const {
8842
if (extension_name() != other.extension_name()) {
8943
return false;
9044
}
9145
const auto& other_ext = internal::checked_cast<const FixedShapeTensorType&>(other);
9246

93-
auto is_permutation_trivial = [](const std::vector<int64_t>& permutation) {
94-
for (size_t i = 1; i < permutation.size(); ++i) {
95-
if (permutation[i - 1] + 1 != permutation[i]) {
96-
return false;
97-
}
98-
}
99-
return true;
100-
};
10147
const bool permutation_equivalent =
102-
((permutation_ == other_ext.permutation()) ||
103-
(permutation_.empty() && is_permutation_trivial(other_ext.permutation())) ||
104-
(is_permutation_trivial(permutation_) && other_ext.permutation().empty()));
48+
(permutation_ == other_ext.permutation()) ||
49+
(internal::IsPermutationTrivial(permutation_) &&
50+
internal::IsPermutationTrivial(other_ext.permutation()));
10551

10652
return (storage_type()->Equals(other_ext.storage_type())) &&
10753
(this->shape() == other_ext.shape()) && (dim_names_ == other_ext.dim_names()) &&
@@ -167,7 +113,8 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Deserialize(
167113
internal::checked_pointer_cast<FixedSizeListType>(storage_type)->value_type();
168114
rj::Document document;
169115
if (document.Parse(serialized_data.data(), serialized_data.length()).HasParseError() ||
170-
!document.HasMember("shape") || !document["shape"].IsArray()) {
116+
!document.IsObject() || !document.HasMember("shape") ||
117+
!document["shape"].IsArray()) {
171118
return Status::Invalid("Invalid serialized JSON data: ", serialized_data);
172119
}
173120

@@ -218,10 +165,6 @@ Result<std::shared_ptr<Tensor>> FixedShapeTensorType::MakeTensor(
218165
if (array->null_count() > 0) {
219166
return Status::Invalid("Cannot convert data with nulls to Tensor.");
220167
}
221-
const auto& value_type =
222-
internal::checked_cast<const FixedWidthType&>(*ext_type.value_type());
223-
const auto byte_width = value_type.byte_width();
224-
225168
std::vector<int64_t> permutation = ext_type.permutation();
226169
if (permutation.empty()) {
227170
permutation.resize(ext_type.ndim());
@@ -236,13 +179,10 @@ Result<std::shared_ptr<Tensor>> FixedShapeTensorType::MakeTensor(
236179
internal::Permute<std::string>(permutation, &dim_names);
237180
}
238181

239-
std::vector<int64_t> strides;
240-
RETURN_NOT_OK(ComputeStrides(value_type, shape, permutation, &strides));
241-
const auto start_position = array->offset() * byte_width;
242-
const auto size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
243-
std::multiplies<>());
244-
const auto buffer =
245-
SliceBuffer(array->data()->buffers[1], start_position, size * byte_width);
182+
ARROW_ASSIGN_OR_RAISE(
183+
auto strides, internal::ComputeStrides(ext_type.value_type(), shape, permutation));
184+
ARROW_ASSIGN_OR_RAISE(const auto buffer, internal::SliceTensorBuffer(
185+
*array, *ext_type.value_type(), shape));
246186

247187
return Tensor::Make(ext_type.value_type(), buffer, shape, strides, dim_names);
248188
}
@@ -304,7 +244,7 @@ Result<std::shared_ptr<FixedShapeTensorArray>> FixedShapeTensorArray::FromTensor
304244
break;
305245
}
306246
case Type::UINT64: {
307-
value_array = std::make_shared<Int64Array>(tensor->size(), tensor->data());
247+
value_array = std::make_shared<UInt64Array>(tensor->size(), tensor->data());
308248
break;
309249
}
310250
case Type::INT64: {
@@ -375,10 +315,8 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
375315
shape.insert(shape.begin(), 1, this->length());
376316
internal::Permute<int64_t>(permutation, &shape);
377317

378-
std::vector<int64_t> tensor_strides;
379-
const auto* fw_value_type = internal::checked_cast<FixedWidthType*>(value_type.get());
380-
ARROW_RETURN_NOT_OK(
381-
ComputeStrides(*fw_value_type, shape, permutation, &tensor_strides));
318+
ARROW_ASSIGN_OR_RAISE(auto tensor_strides,
319+
internal::ComputeStrides(value_type, shape, permutation));
382320

383321
const auto& raw_buffer = this->storage()->data()->child_data[0]->buffers[1];
384322
ARROW_ASSIGN_OR_RAISE(
@@ -412,11 +350,10 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
412350

413351
const std::vector<int64_t>& FixedShapeTensorType::strides() {
414352
if (strides_.empty()) {
415-
auto value_type = internal::checked_cast<FixedWidthType*>(this->value_type_.get());
416-
std::vector<int64_t> tensor_strides;
417-
ARROW_CHECK_OK(
418-
ComputeStrides(*value_type, this->shape(), this->permutation(), &tensor_strides));
419-
strides_ = tensor_strides;
353+
auto maybe_strides =
354+
internal::ComputeStrides(this->value_type_, this->shape(), this->permutation());
355+
ARROW_CHECK_OK(maybe_strides.status());
356+
strides_ = std::move(maybe_strides).MoveValueUnsafe();
420357
}
421358
return strides_;
422359
}
@@ -426,9 +363,8 @@ std::shared_ptr<DataType> fixed_shape_tensor(const std::shared_ptr<DataType>& va
426363
const std::vector<int64_t>& permutation,
427364
const std::vector<std::string>& dim_names) {
428365
auto maybe_type = FixedShapeTensorType::Make(value_type, shape, permutation, dim_names);
429-
ARROW_DCHECK_OK(maybe_type.status());
366+
ARROW_CHECK_OK(maybe_type.status());
430367
return maybe_type.MoveValueUnsafe();
431368
}
432369

433-
} // namespace extension
434-
} // namespace arrow
370+
} // namespace arrow::extension

cpp/src/arrow/extension/fixed_shape_tensor.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919

2020
#include "arrow/extension_type.h"
2121

22-
namespace arrow {
23-
namespace extension {
22+
namespace arrow::extension {
2423

2524
class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray {
2625
public:
@@ -112,7 +111,6 @@ class ARROW_EXPORT FixedShapeTensorType : public ExtensionType {
112111
const std::vector<std::string>& dim_names = {});
113112

114113
private:
115-
std::shared_ptr<DataType> storage_type_;
116114
std::shared_ptr<DataType> value_type_;
117115
std::vector<int64_t> shape_;
118116
std::vector<int64_t> strides_;
@@ -126,5 +124,4 @@ ARROW_EXPORT std::shared_ptr<DataType> fixed_shape_tensor(
126124
const std::vector<int64_t>& permutation = {},
127125
const std::vector<std::string>& dim_names = {});
128126

129-
} // namespace extension
130-
} // namespace arrow
127+
} // namespace arrow::extension

0 commit comments

Comments
 (0)