Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dpnp/backend/extensions/indexing/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ set_target_properties(

target_include_directories(
${python_module_name}
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../ ${CMAKE_CURRENT_SOURCE_DIR}/../common
)

# treat below headers as system to suppress the warnings there during the build
Expand Down
118 changes: 87 additions & 31 deletions dpnp/backend/extensions/indexing/choose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,41 +30,116 @@
#include <cstddef>
#include <cstdint>
#include <memory>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <sycl/sycl.hpp>
#include <type_traits>
#include <utility>
#include <vector>

#include "choose_kernel.hpp"
#include <sycl/sycl.hpp>

#include "dpctl4pybind11.hpp"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

// utils extension header
#include "ext/common.hpp"
#include "kernels/indexing/choose.hpp"

// dpctl tensor headers
#include "utils/indexing_utils.hpp"
#include "utils/memory_overlap.hpp"
#include "utils/offset_utils.hpp" //
#include "utils/output_validation.hpp"
#include "utils/sycl_alloc_utils.hpp"
#include "utils/type_dispatch.hpp"

namespace dpnp::extensions::indexing
{

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

static kernels::choose_fn_ptr_t choose_clip_dispatch_table[td_ns::num_types]
[td_ns::num_types];
static kernels::choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types]
[td_ns::num_types];
using dpctl::tensor::ssize_t;

typedef sycl::event (*choose_fn_ptr_t)(sycl::queue &,
size_t,
ssize_t,
int,
const ssize_t *,
const char *,
char *,
char **,
ssize_t,
ssize_t,
const ssize_t *,
const std::vector<sycl::event> &);

static choose_fn_ptr_t choose_clip_dispatch_table[td_ns::num_types]
[td_ns::num_types];
static choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types]
[td_ns::num_types];

template <typename ProjectorT, typename indTy, typename Ty>
sycl::event choose_impl(sycl::queue &q,
size_t nelems,
ssize_t n_chcs,
int nd,
const ssize_t *shape_and_strides,
const char *ind_cp,
char *dst_cp,
char **chcs_cp,
ssize_t ind_offset,
ssize_t dst_offset,
const ssize_t *chc_offsets,
const std::vector<sycl::event> &depends)
{
dpctl::tensor::type_utils::validate_type_for_device<Ty>(q);

const indTy *ind_tp = reinterpret_cast<const indTy *>(ind_cp);
Ty *dst_tp = reinterpret_cast<Ty *>(dst_cp);

namespace py = pybind11;
sycl::event choose_ev = q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

namespace detail
using InOutIndexerT =
dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
const InOutIndexerT ind_out_indexer{nd, ind_offset, dst_offset,
shape_and_strides};

using NthChoiceIndexerT =
dpnp::kernels::choose::strides::NthStrideOffsetUnpacked;
const NthChoiceIndexerT choices_indexer{
nd, chc_offsets, shape_and_strides, shape_and_strides + 3 * nd};

using ChooseFunc =
dpnp::kernels::choose::ChooseFunctor<ProjectorT, InOutIndexerT,
NthChoiceIndexerT, indTy, Ty>;

cgh.parallel_for<ChooseFunc>(sycl::range<1>(nelems),
ChooseFunc(ind_tp, dst_tp, chcs_cp, n_chcs,
ind_out_indexer,
choices_indexer));
});

return choose_ev;
}

template <typename fnT, typename IndT, typename T, typename Index>
struct ChooseFactory
{
fnT get()
{
if constexpr (std::is_integral<IndT>::value &&
!std::is_same<IndT, bool>::value) {
fnT fn = choose_impl<Index, IndT, T>;
return fn;
}
else {
fnT fn = nullptr;
return fn;
}
}
};

namespace detail
{
using host_ptrs_allocator_t =
dpctl::tensor::alloc_utils::usm_host_allocator<char *>;
using ptrs_t = std::vector<char *, host_ptrs_allocator_t>;
Expand Down Expand Up @@ -191,7 +266,6 @@ std::vector<dpctl::tensor::usm_ndarray> parse_py_chcs(const sycl::queue &q,

return res;
}

} // namespace detail

std::pair<sycl::event, sycl::event>
Expand Down Expand Up @@ -412,23 +486,6 @@ std::pair<sycl::event, sycl::event>
return std::make_pair(arg_cleanup_ev, choose_generic_ev);
}

template <typename fnT, typename IndT, typename T, typename Index>
struct ChooseFactory
{
fnT get()
{
if constexpr (std::is_integral<IndT>::value &&
!std::is_same<IndT, bool>::value) {
fnT fn = kernels::choose_impl<Index, IndT, T>;
return fn;
}
else {
fnT fn = nullptr;
return fn;
}
}
};

using dpctl::tensor::indexing_utils::ClipIndex;
using dpctl::tensor::indexing_utils::WrapIndex;

Expand All @@ -441,7 +498,6 @@ using ChooseClipFactory = ChooseFactory<fnT, IndT, T, ClipIndex<IndT>>;
void init_choose_dispatch_tables(void)
{
using ext::common::init_dispatch_table;
using kernels::choose_fn_ptr_t;

init_dispatch_table<choose_fn_ptr_t, ChooseClipFactory>(
choose_clip_dispatch_table);
Expand Down
191 changes: 0 additions & 191 deletions dpnp/backend/extensions/indexing/choose_kernel.hpp

This file was deleted.

Loading
Loading