diff --git a/include/xtensor/core/xexpression.hpp b/include/xtensor/core/xexpression.hpp index 9b133315d..5b9d5c610 100644 --- a/include/xtensor/core/xexpression.hpp +++ b/include/xtensor/core/xexpression.hpp @@ -10,6 +10,7 @@ #ifndef XTENSOR_EXPRESSION_HPP #define XTENSOR_EXPRESSION_HPP +#include #include #include @@ -18,6 +19,7 @@ #include "../core/xlayout.hpp" #include "../core/xshape.hpp" +#include "../core/xiterator.hpp" #include "../core/xtensor_forward.hpp" #include "../utils/xutils.hpp" @@ -536,8 +538,10 @@ namespace xt using stepper = typename E::stepper; using const_stepper = typename E::const_stepper; - using linear_iterator = typename E::linear_iterator; - using const_linear_iterator = typename E::const_linear_iterator; + using linear_iterator = decltype(xt::linear_begin(std::declval())); + using const_linear_iterator = decltype(xt::linear_begin(std::declval())); + using reverse_linear_iterator = std::reverse_iterator; + using const_reverse_linear_iterator = std::reverse_iterator; using bool_load_type = typename E::bool_load_type; @@ -573,19 +577,65 @@ namespace xt XTENSOR_FORWARD_CONST_ITERATOR_METHOD(crbegin) XTENSOR_FORWARD_CONST_ITERATOR_METHOD(crend) - XTENSOR_FORWARD_METHOD(linear_begin) - XTENSOR_FORWARD_METHOD(linear_end) - XTENSOR_FORWARD_CONST_METHOD(linear_begin) - XTENSOR_FORWARD_CONST_METHOD(linear_end) - XTENSOR_FORWARD_CONST_METHOD(linear_cbegin) - XTENSOR_FORWARD_CONST_METHOD(linear_cend) - - XTENSOR_FORWARD_METHOD(linear_rbegin) - XTENSOR_FORWARD_METHOD(linear_rend) - XTENSOR_FORWARD_CONST_METHOD(linear_rbegin) - XTENSOR_FORWARD_CONST_METHOD(linear_rend) - XTENSOR_FORWARD_CONST_METHOD(linear_crbegin) - XTENSOR_FORWARD_CONST_METHOD(linear_crend) + linear_iterator linear_begin() noexcept + { + return xt::linear_begin(*m_ptr); + } + + linear_iterator linear_end() noexcept + { + return xt::linear_end(*m_ptr); + } + + const_linear_iterator linear_begin() const noexcept + { + return xt::linear_begin(*m_ptr); + } + + const_linear_iterator linear_end() const noexcept + { + return xt::linear_end(*m_ptr); + } + + const_linear_iterator linear_cbegin() const noexcept + { + return xt::linear_begin(*m_ptr); + } + + const_linear_iterator linear_cend() const noexcept + { + return xt::linear_end(*m_ptr); + } + + reverse_linear_iterator linear_rbegin() noexcept + { + return reverse_linear_iterator(linear_end()); + } + + reverse_linear_iterator linear_rend() noexcept + { + return reverse_linear_iterator(linear_begin()); + } + + const_reverse_linear_iterator linear_rbegin() const noexcept + { + return const_reverse_linear_iterator(linear_end()); + } + + const_reverse_linear_iterator linear_rend() const noexcept + { + return const_reverse_linear_iterator(linear_begin()); + } + + const_reverse_linear_iterator linear_crbegin() const noexcept + { + return const_reverse_linear_iterator(linear_cend()); + } + + const_reverse_linear_iterator linear_crend() const noexcept + { + return const_reverse_linear_iterator(linear_cbegin()); + } template std::enable_if_t::value, const inner_strides_type&> strides() const diff --git a/test/test_xexpression.cpp b/test/test_xexpression.cpp index c1ad4d14f..75eba3edb 100644 --- a/test/test_xexpression.cpp +++ b/test/test_xexpression.cpp @@ -10,8 +10,10 @@ #include #include "xtensor/containers/xarray.hpp" +#include "xtensor/containers/xtensor.hpp" #include "xtensor/core/xexpression.hpp" #include "xtensor/core/xmath.hpp" +#include "xtensor/generators/xrandom.hpp" #include "xtensor/io/xio.hpp" #include "test_common_macros.hpp" @@ -116,6 +118,21 @@ namespace xt EXPECT_EQ(expr, a * a); } + TEST(xexpression, shared_reducer_result) + { + std::size_t n = 1000; + std::size_t m = 9; + std::size_t o = 12; + + xtensor tensor = random::rand({n, m, o}, -20., 20.); + auto result = make_xshared(sum(tensor, {2})); + auto expected = sum(tensor, {2}); + + EXPECT_EQ(result.dimension(), std::size_t(2)); + EXPECT_EQ(result.shape(), expected.shape()); + EXPECT_TRUE(all(equal(result, expected))); + } + TEST(xexpression, temporary_type) { using dyn_shape = xt::svector, true>;