2626#include " arrow/array/array_primitive.h"
2727#include " arrow/json/rapidjson_defs.h" // IWYU pragma: keep
2828#include " arrow/tensor.h"
29- #include " arrow/util/int_util_overflow.h"
3029#include " arrow/util/logging_internal.h"
3130#include " arrow/util/print_internal.h"
3231#include " arrow/util/sort_internal.h"
3736
3837namespace rj = arrow::rapidjson;
3938
40- namespace arrow {
41-
42- namespace extension {
43-
44- namespace {
45-
46- Status ComputeStrides (const FixedWidthType& type, const std::vector<int64_t >& shape,
47- const std::vector<int64_t >& permutation,
48- std::vector<int64_t >* strides) {
49- if (permutation.empty ()) {
50- return internal::ComputeRowMajorStrides (type, shape, strides);
51- }
52-
53- const int byte_width = type.byte_width ();
54-
55- int64_t remaining = 0 ;
56- if (!shape.empty () && shape.front () > 0 ) {
57- remaining = byte_width;
58- for (auto i : permutation) {
59- if (i > 0 ) {
60- if (internal::MultiplyWithOverflow (remaining, shape[i], &remaining)) {
61- return Status::Invalid (
62- " Strides computed from shape would not fit in 64-bit integer" );
63- }
64- }
65- }
66- }
67-
68- if (remaining == 0 ) {
69- strides->assign (shape.size (), byte_width);
70- return Status::OK ();
71- }
72-
73- strides->push_back (remaining);
74- for (auto i : permutation) {
75- if (i > 0 ) {
76- remaining /= shape[i];
77- strides->push_back (remaining);
78- }
79- }
80- internal::Permute (permutation, strides);
81-
82- return Status::OK ();
83- }
84-
85- } // namespace
39+ namespace arrow ::extension {
8640
8741bool FixedShapeTensorType::ExtensionEquals (const ExtensionType& other) const {
8842 if (extension_name () != other.extension_name ()) {
8943 return false ;
9044 }
9145 const auto & other_ext = internal::checked_cast<const FixedShapeTensorType&>(other);
9246
93- auto is_permutation_trivial = [](const std::vector<int64_t >& permutation) {
94- for (size_t i = 1 ; i < permutation.size (); ++i) {
95- if (permutation[i - 1 ] + 1 != permutation[i]) {
96- return false ;
97- }
98- }
99- return true ;
100- };
10147 const bool permutation_equivalent =
102- (( permutation_ == other_ext.permutation ()) ||
103- (permutation_. empty ( ) && is_permutation_trivial (other_ext. permutation ())) ||
104- ( is_permutation_trivial (permutation_) && other_ext.permutation (). empty ()));
48+ (permutation_ == other_ext.permutation ()) ||
49+ ( internal::IsPermutationTrivial (permutation_ ) &&
50+ internal::IsPermutationTrivial ( other_ext.permutation ()));
10551
10652 return (storage_type ()->Equals (other_ext.storage_type ())) &&
10753 (this ->shape () == other_ext.shape ()) && (dim_names_ == other_ext.dim_names ()) &&
@@ -167,7 +113,8 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Deserialize(
167113 internal::checked_pointer_cast<FixedSizeListType>(storage_type)->value_type ();
168114 rj::Document document;
169115 if (document.Parse (serialized_data.data (), serialized_data.length ()).HasParseError () ||
170- !document.HasMember (" shape" ) || !document[" shape" ].IsArray ()) {
116+ !document.IsObject () || !document.HasMember (" shape" ) ||
117+ !document[" shape" ].IsArray ()) {
171118 return Status::Invalid (" Invalid serialized JSON data: " , serialized_data);
172119 }
173120
@@ -218,10 +165,6 @@ Result<std::shared_ptr<Tensor>> FixedShapeTensorType::MakeTensor(
218165 if (array->null_count () > 0 ) {
219166 return Status::Invalid (" Cannot convert data with nulls to Tensor." );
220167 }
221- const auto & value_type =
222- internal::checked_cast<const FixedWidthType&>(*ext_type.value_type ());
223- const auto byte_width = value_type.byte_width ();
224-
225168 std::vector<int64_t > permutation = ext_type.permutation ();
226169 if (permutation.empty ()) {
227170 permutation.resize (ext_type.ndim ());
@@ -236,13 +179,10 @@ Result<std::shared_ptr<Tensor>> FixedShapeTensorType::MakeTensor(
236179 internal::Permute<std::string>(permutation, &dim_names);
237180 }
238181
239- std::vector<int64_t > strides;
240- RETURN_NOT_OK (ComputeStrides (value_type, shape, permutation, &strides));
241- const auto start_position = array->offset () * byte_width;
242- const auto size = std::accumulate (shape.begin (), shape.end (), static_cast <int64_t >(1 ),
243- std::multiplies<>());
244- const auto buffer =
245- SliceBuffer (array->data ()->buffers [1 ], start_position, size * byte_width);
182+ ARROW_ASSIGN_OR_RAISE (
183+ auto strides, internal::ComputeStrides (ext_type.value_type (), shape, permutation));
184+ ARROW_ASSIGN_OR_RAISE (const auto buffer, internal::SliceTensorBuffer (
185+ *array, *ext_type.value_type (), shape));
246186
247187 return Tensor::Make (ext_type.value_type (), buffer, shape, strides, dim_names);
248188}
@@ -304,7 +244,7 @@ Result<std::shared_ptr<FixedShapeTensorArray>> FixedShapeTensorArray::FromTensor
304244 break ;
305245 }
306246 case Type::UINT64: {
307- value_array = std::make_shared<Int64Array >(tensor->size (), tensor->data ());
247+ value_array = std::make_shared<UInt64Array >(tensor->size (), tensor->data ());
308248 break ;
309249 }
310250 case Type::INT64: {
@@ -375,10 +315,8 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
375315 shape.insert (shape.begin (), 1 , this ->length ());
376316 internal::Permute<int64_t >(permutation, &shape);
377317
378- std::vector<int64_t > tensor_strides;
379- const auto * fw_value_type = internal::checked_cast<FixedWidthType*>(value_type.get ());
380- ARROW_RETURN_NOT_OK (
381- ComputeStrides (*fw_value_type, shape, permutation, &tensor_strides));
318+ ARROW_ASSIGN_OR_RAISE (auto tensor_strides,
319+ internal::ComputeStrides (value_type, shape, permutation));
382320
383321 const auto & raw_buffer = this ->storage ()->data ()->child_data [0 ]->buffers [1 ];
384322 ARROW_ASSIGN_OR_RAISE (
@@ -412,11 +350,10 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
412350
413351const std::vector<int64_t >& FixedShapeTensorType::strides () {
414352 if (strides_.empty ()) {
415- auto value_type = internal::checked_cast<FixedWidthType*>(this ->value_type_ .get ());
416- std::vector<int64_t > tensor_strides;
417- ARROW_CHECK_OK (
418- ComputeStrides (*value_type, this ->shape (), this ->permutation (), &tensor_strides));
419- strides_ = tensor_strides;
353+ auto maybe_strides =
354+ internal::ComputeStrides (this ->value_type_ , this ->shape (), this ->permutation ());
355+ ARROW_CHECK_OK (maybe_strides.status ());
356+ strides_ = std::move (maybe_strides).MoveValueUnsafe ();
420357 }
421358 return strides_;
422359}
@@ -426,9 +363,8 @@ std::shared_ptr<DataType> fixed_shape_tensor(const std::shared_ptr<DataType>& va
426363 const std::vector<int64_t >& permutation,
427364 const std::vector<std::string>& dim_names) {
428365 auto maybe_type = FixedShapeTensorType::Make (value_type, shape, permutation, dim_names);
429- ARROW_DCHECK_OK (maybe_type.status ());
366+ ARROW_CHECK_OK (maybe_type.status ());
430367 return maybe_type.MoveValueUnsafe ();
431368}
432369
433- } // namespace extension
434- } // namespace arrow
370+ } // namespace arrow::extension
0 commit comments