Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions arm_compute/runtime/NEON/functions/NEMaxUnpoolingLayer.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2022, 2025 Arm Limited.
* Copyright (c) 2020-2022, 2025-2026 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
Expand Down Expand Up @@ -76,7 +76,7 @@ class NEMaxUnpoolingLayer : public IFunction
*
* @param[in, out] input Source tensor. (Written to only when padding != 0) Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
* @param[out] output Destination tensor. Data types supported: Same as @p input.
* @param[out] indices The indices of the maximal values. Data type supported: U32.
* @param[in] indices The indices of the maximal values. Data type supported: U32.
* @param[in] pool_info Contains pooling operation information described in @ref PoolingLayerInfo.
*/
void configure(ITensor *input, ITensor *indices, ITensor *output, const PoolingLayerInfo &pool_info);
Expand Down
12 changes: 8 additions & 4 deletions src/cpu/kernels/maxunpool/generic/neon/impl.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023 Arm Limited.
* Copyright (c) 2022-2023, 2026 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
Expand Down Expand Up @@ -27,6 +27,9 @@
#include "arm_compute/core/Window.h"

#include "src/core/NEON/wrapper/wrapper.h"

#include <algorithm>

namespace arm_compute
{
namespace cpu
Expand All @@ -38,13 +41,14 @@ void max_unpooling(const ITensor *input, const ITensor *indices, ITensor *output
Iterator indices_itr(indices, window);
auto out_ptr = reinterpret_cast<T *>(output->buffer());
const int out_stride_w = static_cast<int>(output->info()->strides_in_bytes()[3]);
uint32_t slice_size = output->info()->tensor_shape().total_size_lower(3);
execute_window_loop(
window,
[&](const Coordinates &id)
{
auto vindices = reinterpret_cast<uint32_t *>(indices_itr.ptr());
auto vinput = reinterpret_cast<T *>(input_itr.ptr());
out_ptr[id[3] * out_stride_w / sizeof(T) + *vindices] = *vinput;
auto vindices = reinterpret_cast<uint32_t *>(indices_itr.ptr());
auto vinput = reinterpret_cast<T *>(input_itr.ptr());
out_ptr[id[3] * out_stride_w / sizeof(T) + std::min(slice_size - 1, *vindices)] = *vinput;
},
input_itr, indices_itr);
}
Expand Down