diff --git a/CMakeLists.txt b/CMakeLists.txt index 8e88b77..4756b0b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,14 +13,13 @@ if(CCACHE_PROGRAM) endif() list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/utils") +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/dependencies") set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) -# The CUDA standard is still C++14 to enable interopability with -# slightly older and still well-supported versions of CUDA/nvcc -# (e.g. CUDA < 11). This will be bumped to 17 once CUDA 11 is -# required. -set(CMAKE_CUDA_STANDARD 14) +set(CMAKE_CUDA_STANDARD 20) set(CMAKE_CUDA_STANDARD_REQUIRED ON) # no modules in this library diff --git a/CMakePresets.json b/CMakePresets.json index aec9e32..e63a94c 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -82,7 +82,7 @@ "hidden": true, "cacheVariables": { "CMAKE_CUDA_COMPILER": "nvcc", - "FL_USE_CUDNN": false, + "FL_USE_CUDNN": true, "CMAKE_CUDA_ARCHITECTURES": "native", "CMAKE_CUDA_FLAGS": "-allow-unsupported-compiler", "VCPKG_MANIFEST_FEATURES": "cuda" diff --git a/Folder.DotSettings b/Folder.DotSettings new file mode 100644 index 0000000..84f2a07 --- /dev/null +++ b/Folder.DotSettings @@ -0,0 +1,8 @@ + + <NamingElement Priority="6" Title="Parameters"><Descriptor Static="Indeterminate" Constexpr="Indeterminate" Const="Indeterminate" Volatile="Indeterminate" Accessibility="NOT_APPLICABLE"><type Name="function parameter" /><type Name="lambda parameter" /></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="" Suffix="" Style="aaBb"><ExtraRule Prefix="_" Suffix="" Style="aaBb" /></Policy></NamingElement> + <NamingElement Priority="16" Title="Other constants"><Descriptor Static="True" Constexpr="Indeterminate" Const="True" Volatile="Indeterminate" Accessibility="NOT_APPLICABLE"><type Name="class field" /><type 
Name="local variable" /><type Name="struct field" /></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="" Suffix="" Style="AA_BB"><ExtraRule Prefix="" Suffix="" Style="aa_bb" /></Policy></NamingElement> + <NamingElement Priority="15" Title="Enumerators"><Descriptor Static="Indeterminate" Constexpr="Indeterminate" Const="Indeterminate" Volatile="Indeterminate" Accessibility="NOT_APPLICABLE"><type Name="scoped enumerator" /><type Name="unscoped enumerator" /></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="" Suffix="" Style="AA_BB"><ExtraRule Prefix="" Suffix="" Style="aa_bb" /></Policy></NamingElement> + <NamingElement Priority="3" Title="Enums"><Descriptor Static="Indeterminate" Constexpr="Indeterminate" Const="Indeterminate" Volatile="Indeterminate" Accessibility="NOT_APPLICABLE"><type Name="enum" /></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="" Suffix="" Style="AaBb_AaBb"><ExtraRule Prefix="" Suffix="" Style="aa_bb" /></Policy></NamingElement> + True + True + True \ No newline at end of file diff --git a/README.md b/README.md index 94d1e3e..06b85a2 100644 --- a/README.md +++ b/README.md @@ -22,10 +22,13 @@ Please read the [todo list](TODO.md) ### Quirks -`FL_USE_CUDNN`: -- NOT WORKING ATM (v6-7 api from 2017 i gotta fix that first) -- requires `CUDNN_ROOT` to be set in the environment -- windows users: **do not** install CUDNN with the default **windows installer**. It will create: `CUDNN_ROOT///`. Since i cannot anticipate the cuda version you use, i can't traverse this. **FIX:** install as **tarball** instead. +- `ArrayFire` + - Do not install via winget (for now), 3.9 has a critical bug. +- `FL_USE_CUDNN`: + - NOT WORKING ATM (v6-7 api from 2017 i gotta fix that first) + - requires `CUDNN_ROOT` to be set in the environment + - windows users: **do not** install CUDNN with the default **windows installer**. It will create: `CUDNN_ROOT///`. 
Since i cannot anticipate the cuda version you use, i can't traverse this. **FIX:** install as **tarball** instead. + ### Functional changes from Flashlight - backends removed: diff --git a/cmake/dependencies/FindCUDNN.cmake b/cmake/dependencies/FindCUDNN.cmake index a150310..fd77eea 100644 --- a/cmake/dependencies/FindCUDNN.cmake +++ b/cmake/dependencies/FindCUDNN.cmake @@ -76,4 +76,8 @@ if(CUDNN_FOUND) endif() endif() +if (CUDNN_FOUND AND CUDNN_VERSION VERSION_LESS "8.0") + message(FATAL_ERROR "Flashlight requires cuDNN >= 8.0, found ${CUDNN_VERSION}") +endif() + mark_as_advanced(CUDNN_ROOT CUDNN_INCLUDE_DIR CUDNN_LIBRARY CUDNN_VERSION) diff --git a/cmake/dependencies/FindFilesystem.cmake b/cmake/dependencies/FindFilesystem.cmake index a13bb35..1bda6eb 100644 --- a/cmake/dependencies/FindFilesystem.cmake +++ b/cmake/dependencies/FindFilesystem.cmake @@ -126,6 +126,7 @@ set(CMAKE_REQUIRED_QUIET ${Filesystem_FIND_QUIETLY}) # All of our tests required C++20 or later set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_EXTENSIONS OFF) # Normalize and check the component list we were given set(want_components ${Filesystem_FIND_COMPONENTS}) diff --git a/cmake/utils/TestUtils.cmake b/cmake/utils/TestUtils.cmake index 8f9a228..53cd86c 100644 --- a/cmake/utils/TestUtils.cmake +++ b/cmake/utils/TestUtils.cmake @@ -5,26 +5,32 @@ set(GTEST_IMPORTED_TARGETS "") # Get or find Google Test and Google Mock find_package(GTest 1.10.0) -if (NOT GTEST_FOUND) - if (NOT TARGET gtest) + +if(NOT GTEST_FOUND) + if(NOT TARGET gtest) message(STATUS "googletest not found - will download and build from source") + # Download and build googletest include(${PROJECT_SOURCE_DIR}/cmake/BuildGoogleTest.cmake) list(APPEND GTEST_IMPORTED_TARGETS GTest::gtest GTest::gtest_main GTest::gmock GTest::gmock_main) endif() else() message(STATUS "gtest found: (include: ${GTEST_INCLUDE_DIRS}, lib: ${GTEST_BOTH_LIBRARIES}") - if (TARGET GTest::GTest) + + if(TARGET GTest::GTest) # We found the differently-named CMake 
targets from FindGTest - if (NOT TARGET GTest::Main) + if(NOT TARGET GTest::Main) message(FATAL_ERROR "Google Test must be built with main") endif() + list(APPEND GTEST_IMPORTED_TARGETS GTest::GTest GTest::Main) endif() - if (NOT TARGET GTest::gmock) + + if(NOT TARGET GTest::gmock) find_package(GMock REQUIRED) message(STATUS "gmock found: (include: ${GMOCK_INCLUDE_DIRS}, lib: ${GMOCK_BOTH_LIBRARIES})") endif() + list(APPEND GTEST_IMPORTED_TARGETS GTest::gmock GTest::gmock_main) message(STATUS "Found gtest and gmock on system.") endif() @@ -42,28 +48,30 @@ function(build_test) get_filename_component(src_name ${build_test_SRC} NAME_WE) set(target "${src_name}") add_executable(${target} ${build_test_SRC}) - if (TARGET gtest) + + if(TARGET gtest) add_dependencies(${target} gtest) # make sure gtest is built first endif() + target_link_libraries( ${target} PUBLIC ${GTEST_IMPORTED_TARGETS} ${build_test_LIBS} ${CMAKE_THREAD_LIBS_INIT} - ) + ) target_include_directories( ${target} PUBLIC ${PROJECT_SOURCE_DIR} - ) + ) target_compile_definitions( ${target} PUBLIC ${build_test_PREPROC} - ) + ) - if (CMAKE_SYSTEM_NAME STREQUAL "Windows") + if(CMAKE_SYSTEM_NAME STREQUAL "Windows") target_compile_definitions( ${target} PUBLIC @@ -71,7 +79,8 @@ function(build_test) GMOCK_LINKED_AS_SHARED_LIBRARY=$ ) endif() - gtest_add_tests(TARGET ${target}) + + gtest_discover_tests(${target} DISCOVERY_MODE PRE_TEST) if(WIN32) fm_target_copy_dependencies(${target}) diff --git a/cmake/utils/flashlightConfig.cmake.in b/cmake/utils/flashlightConfig.cmake.in index 2bf2550..9d73423 100644 --- a/cmake/utils/flashlightConfig.cmake.in +++ b/cmake/utils/flashlightConfig.cmake.in @@ -49,7 +49,7 @@ if (@FL_BUILD_STANDALONE@) endif() if (@FL_USE_CUDA@) if (@FL_USE_CUDNN@) - find_dependency(CUDNN 7.1) + find_dependency(CUDNN 8) endif() if (@FL_BUILD_DISTRIBUTED@) find_dependency(NCCL) diff --git a/cmake/utils/fm_target_utilities.cmake b/cmake/utils/fm_target_utilities.cmake index c3ad681..f5310d7 100644 --- 
a/cmake/utils/fm_target_utilities.cmake +++ b/cmake/utils/fm_target_utilities.cmake @@ -53,10 +53,17 @@ function(fm_glob OUT_VAR) set(GLOB_PATTERNS ${ARG_PATTERNS}) endif() - if(GLOB_PATTERNS) + # Normalize paths to prevent CONFIGURE_DEPENDS cache mismatch issues on Windows + set(NORMALIZED_PATTERNS "") + foreach(PATTERN IN LISTS GLOB_PATTERNS) + cmake_path(ABSOLUTE_PATH PATTERN NORMALIZE OUTPUT_VARIABLE NORMALIZED) + list(APPEND NORMALIZED_PATTERNS "${NORMALIZED}") + endforeach() + + if(NORMALIZED_PATTERNS) file(GLOB_RECURSE FOUND_FILES CONFIGURE_DEPENDS - ${GLOB_PATTERNS} + ${NORMALIZED_PATTERNS} ) set(${OUT_VAR} ${${OUT_VAR}} ${FOUND_FILES} PARENT_SCOPE) endif() diff --git a/flashlight/fl/autograd/Functions.cpp b/flashlight/fl/autograd/Functions.cpp index 70b7a9e..efae460 100644 --- a/flashlight/fl/autograd/Functions.cpp +++ b/flashlight/fl/autograd/Functions.cpp @@ -1,8 +1,8 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. + * SPDX-License-Identifier: MIT * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. + * Original code: Copyright (c) Meta Platforms, Inc. 
(see FLASHLIGHT_LICENSE) + * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE) */ #include @@ -24,7 +24,7 @@ namespace fl { namespace detail { - Tensor tileAs(const Tensor& input, const Shape& rdims) { + Tensor tileAs(Tensor const& input, Shape const& rdims) { // Scalar tensor if(input.ndim() == 0) return tile(input, rdims); @@ -36,7 +36,7 @@ namespace detail { if(rdims[i] % idimsSize != 0) { std::stringstream ss; ss << "Invalid dims for tileAs for input dims " << idims - << " to output dims " << rdims; + << " to output dims " << rdims; throw std::invalid_argument(ss.str()); } dims[i] = rdims[i] / idimsSize; @@ -44,19 +44,19 @@ namespace detail { return tile(input, dims); } - Tensor sumAs(const Tensor& input, const Shape& rdims) { + Tensor sumAs(Tensor const& input, Shape const& rdims) { Shape idims = input.shape(); auto result = input; for(int i = 0; i < input.ndim(); i++) if(i + 1 > rdims.ndim() || idims[i] != rdims[i]) result = fl::sum(result, {i}, /* keepDims = */ true); - return fl::reshape(result.astype(input.type()), rdims); + return fl::reshape(result.asType(input.type()), rdims); } Shape expandedShapeFromReducedDims( - const Tensor& input, - const std::vector& axes, + Tensor const& input, + std::vector const& axes, bool keepDims /* = false */ ) { // Fast path - tensor already retained its shape @@ -72,7 +72,7 @@ namespace detail { unsigned inputIdx = 0; for(unsigned i = 0; i < preNDims; ++i) { if(i == axes[axesIdx]) - // This dim was reduced over, leave as 1 in the new shape + // This dim was reduced over, leave as 1 in the new shape axesIdx++; else { // Dim wasn't reduced over - add the shape from the new tensor @@ -83,10 +83,10 @@ namespace detail { return newShape; } -// TODO: remove these/use a simple template + // TODO: remove these/use a simple template Variable expandFromReduction( - const Variable& input, - const std::vector& axes, + Variable const& input, + std::vector const& axes, bool keepDims ) { return moddims( @@ -96,8 +96,8 @@ 
namespace detail { } Tensor expandFromReduction( - const Tensor& input, - const std::vector& axes, + Tensor const& input, + std::vector const& axes, bool keepDims ) { auto o = expandedShapeFromReducedDims(input, axes, keepDims); @@ -107,75 +107,87 @@ namespace detail { ); } - bool areVariableTypesEqual(const Variable& a, const Variable& b) { return a.type() == b.type(); } + bool areVariableTypesEqual(Variable const& a, Variable const& b) { return a.type() == b.type(); } } // namespace detail -Variable operator+(const Variable& lhs, const Variable& rhs) { +Variable operator+(Variable const& lhs, Variable const& rhs) { FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); auto result = lhs.tensor() + rhs.tensor(); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad(Variable(gradOutput.tensor(), false)); - inputs[1].addGrad(Variable(gradOutput.tensor(), false)); - }; + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + inputs[0].addGrad(Variable(gradOutput.tensor(), false)); + inputs[1].addGrad(Variable(gradOutput.tensor(), false)); + }; return Variable(result, {lhs.withoutData(), rhs.withoutData()}, gradFunc); } -Variable operator+(const Variable& lhs, const double& rhsVal) { - auto result = (lhs.tensor() + rhsVal).astype(lhs.type()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad(Variable(gradOutput.tensor(), false)); - }; +Variable operator+(Variable const& lhs, double const& rhsVal) { + auto result = (lhs.tensor() + rhsVal).asType(lhs.type()); + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + inputs[0].addGrad(Variable(gradOutput.tensor(), false)); + }; return Variable(result, {lhs.withoutData()}, gradFunc); } -Variable operator+(const double& lhsVal, const Variable& rhs) { return rhs + lhsVal; } +Variable operator+(double const& lhsVal, Variable const& rhs) { return rhs + lhsVal; } -Variable operator-(const Variable& lhs, const 
Variable& rhs) { +Variable operator-(Variable const& lhs, Variable const& rhs) { FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); auto result = lhs.tensor() - rhs.tensor(); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad(Variable(gradOutput.tensor(), false)); - inputs[1].addGrad(Variable(negate(gradOutput).tensor(), false)); - }; + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + inputs[0].addGrad(Variable(gradOutput.tensor(), false)); + inputs[1].addGrad(Variable(negate(gradOutput).tensor(), false)); + }; return Variable(result, {lhs.withoutData(), rhs.withoutData()}, gradFunc); } -Variable operator-(const Variable& lhs, const double& rhsVal) { - auto result = (lhs.tensor() - rhsVal).astype(lhs.type()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad(Variable(gradOutput.tensor(), false)); - }; +Variable operator-(Variable const& lhs, double const& rhsVal) { + auto result = (lhs.tensor() - rhsVal).asType(lhs.type()); + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + inputs[0].addGrad(Variable(gradOutput.tensor(), false)); + }; return Variable(result, {lhs.withoutData()}, gradFunc); } -Variable operator-(const double& lhsVal, const Variable& rhs) { - auto result = (lhsVal - rhs.tensor()).astype(rhs.type()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad(Variable(negate(gradOutput).tensor(), false)); - }; +Variable operator-(double const& lhsVal, Variable const& rhs) { + auto result = (lhsVal - rhs.tensor()).asType(rhs.type()); + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + inputs[0].addGrad(Variable(negate(gradOutput).tensor(), false)); + }; return Variable(result, {rhs.withoutData()}, gradFunc); } -Variable operator*(const Variable& lhs, const Variable& rhs) { +Variable operator*(Variable const& lhs, Variable const& rhs) { 
FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); auto result = lhs.tensor() * rhs.tensor(); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - if(inputs[0].isCalcGrad()) - inputs[0].addGrad( - Variable(gradOutput.tensor() * inputs[1].tensor(), false) - ); - if(inputs[1].isCalcGrad()) - inputs[1].addGrad( - Variable(gradOutput.tensor() * inputs[0].tensor(), false) - ); - }; + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + if(inputs[0].isCalcGrad()) + inputs[0].addGrad( + Variable(gradOutput.tensor() * inputs[1].tensor(), false) + ); + if(inputs[1].isCalcGrad()) + inputs[1].addGrad( + Variable(gradOutput.tensor() * inputs[0].tensor(), false) + ); + }; return Variable( result, { @@ -186,34 +198,35 @@ Variable operator*(const Variable& lhs, const Variable& rhs) { ); } -Variable operator*(const Variable& lhs, const double& rhsVal) { - auto result = (lhs.tensor() * rhsVal).astype(lhs.type()); - auto gradFunc = - [rhsVal](std::vector& inputs, const Variable& gradOutput) { - inputs[0].addGrad(Variable(gradOutput.tensor() * rhsVal, false)); - }; +Variable operator*(Variable const& lhs, double const& rhsVal) { + auto result = (lhs.tensor() * rhsVal).asType(lhs.type()); + auto gradFunc = [rhsVal](std::vector& inputs, Variable const& gradOutput) { + inputs[0].addGrad(Variable(gradOutput.tensor() * rhsVal, false)); + }; return Variable(result, {lhs.withoutData()}, gradFunc); } -Variable operator*(const double& lhsVal, const Variable& rhs) { return rhs * lhsVal; } +Variable operator*(double const& lhsVal, Variable const& rhs) { return rhs * lhsVal; } -Variable operator/(const Variable& lhs, const Variable& rhs) { +Variable operator/(Variable const& lhs, Variable const& rhs) { FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); auto result = lhs.tensor() / rhs.tensor(); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - auto inputs1rec = reciprocal(inputs[1]); - auto gradInput0 = gradOutput * inputs1rec; - 
if(inputs[0].isCalcGrad()) - inputs[0].addGrad(Variable(gradInput0.tensor(), false)); - if(inputs[1].isCalcGrad()) - inputs[1].addGrad( - Variable( - (gradInput0 * negate(inputs[0]) * inputs1rec).tensor(), - false - ) - ); - }; + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + auto inputs1rec = reciprocal(inputs[1]); + auto gradInput0 = gradOutput * inputs1rec; + if(inputs[0].isCalcGrad()) + inputs[0].addGrad(Variable(gradInput0.tensor(), false)); + if(inputs[1].isCalcGrad()) + inputs[1].addGrad( + Variable( + (gradInput0 * negate(inputs[0]) * inputs1rec).tensor(), + false + ) + ); + }; return Variable( result, {rhs.isCalcGrad() ? lhs : lhs.withoutData(), rhs}, @@ -221,368 +234,395 @@ Variable operator/(const Variable& lhs, const Variable& rhs) { ); } -Variable operator/(const Variable& lhs, const double& rhsVal) { - auto result = (lhs.tensor() / rhsVal).astype(lhs.type()); +Variable operator/(Variable const& lhs, double const& rhsVal) { + auto result = (lhs.tensor() / rhsVal).asType(lhs.type()); auto gradFunc = - [rhsVal](std::vector& inputs, const Variable& gradOutput) { - inputs[0].addGrad(Variable((gradOutput / rhsVal).tensor(), false)); - }; + [rhsVal](std::vector& inputs, Variable const& gradOutput) { + inputs[0].addGrad(Variable((gradOutput / rhsVal).tensor(), false)); + }; return Variable(result, {lhs.withoutData()}, gradFunc); } -Variable operator/(const double& lhsVal, const Variable& rhs) { - auto result = (lhsVal / rhs.tensor()).astype(rhs.type()); +Variable operator/(double const& lhsVal, Variable const& rhs) { + auto result = (lhsVal / rhs.tensor()).asType(rhs.type()); auto gradFunc = [lhsVal]( std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad( - Variable( - (gradOutput * (-lhsVal) / (inputs[0] * inputs[0])).tensor(), - false - ) - ); - }; + Variable const& gradOutput + ) { + inputs[0].addGrad( + Variable( + (gradOutput * (-lhsVal) / (inputs[0] * inputs[0])).tensor(), + false + ) + ); + }; return 
Variable(result, {rhs}, gradFunc); } -Variable operator>(const Variable& lhs, const Variable& rhs) { +Variable operator>(Variable const& lhs, Variable const& rhs) { FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); auto result = lhs.tensor() > rhs.tensor(); return Variable(result, false); } -Variable operator>(const Variable& lhs, const double& rhsVal) { - auto result = (lhs.tensor() > rhsVal).astype(lhs.type()); +Variable operator>(Variable const& lhs, double const& rhsVal) { + auto result = (lhs.tensor() > rhsVal).asType(lhs.type()); return Variable(result, false); } -Variable operator>(const double& lhsVal, const Variable& rhs) { - auto result = (lhsVal > rhs.tensor()).astype(rhs.type()); +Variable operator>(double const& lhsVal, Variable const& rhs) { + auto result = (lhsVal > rhs.tensor()).asType(rhs.type()); return Variable(result, false); } -Variable operator<(const Variable& lhs, const Variable& rhs) { +Variable operator<(Variable const& lhs, Variable const& rhs) { FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); auto result = lhs.tensor() < rhs.tensor(); return Variable(result, false); } -Variable operator<(const Variable& lhs, const double& rhsVal) { - auto result = (lhs.tensor() < rhsVal).astype(lhs.type()); +Variable operator<(Variable const& lhs, double const& rhsVal) { + auto result = (lhs.tensor() < rhsVal).asType(lhs.type()); return Variable(result, false); } -Variable operator<(const double& lhsVal, const Variable& rhs) { - auto result = (lhsVal < rhs.tensor()).astype(rhs.type()); +Variable operator<(double const& lhsVal, Variable const& rhs) { + auto result = (lhsVal < rhs.tensor()).asType(rhs.type()); return Variable(result, false); } -Variable operator>=(const Variable& lhs, const Variable& rhs) { +Variable operator>=(Variable const& lhs, Variable const& rhs) { FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); auto result = lhs.tensor() >= rhs.tensor(); return Variable(result, false); } -Variable operator>=(const Variable& lhs, const double& rhsVal) { - auto result = 
(lhs.tensor() >= rhsVal).astype(lhs.type()); +Variable operator>=(Variable const& lhs, double const& rhsVal) { + auto result = (lhs.tensor() >= rhsVal).asType(lhs.type()); return Variable(result, false); } -Variable operator>=(const double& lhsVal, const Variable& rhs) { - auto result = (lhsVal >= rhs.tensor()).astype(rhs.type()); +Variable operator>=(double const& lhsVal, Variable const& rhs) { + auto result = (lhsVal >= rhs.tensor()).asType(rhs.type()); return Variable(result, false); } -Variable operator<=(const Variable& lhs, const Variable& rhs) { +Variable operator<=(Variable const& lhs, Variable const& rhs) { FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); auto result = lhs.tensor() <= rhs.tensor(); return Variable(result, false); } -Variable operator<=(const Variable& lhs, const double& rhsVal) { - auto result = (lhs.tensor() <= rhsVal).astype(lhs.type()); +Variable operator<=(Variable const& lhs, double const& rhsVal) { + auto result = (lhs.tensor() <= rhsVal).asType(lhs.type()); return Variable(result, false); } -Variable operator<=(const double& lhsVal, const Variable& rhs) { - auto result = (lhsVal <= rhs.tensor()).astype(rhs.type()); +Variable operator<=(double const& lhsVal, Variable const& rhs) { + auto result = (lhsVal <= rhs.tensor()).asType(rhs.type()); return Variable(result, false); } -Variable operator&&(const Variable& lhs, const Variable& rhs) { +Variable operator&&(Variable const& lhs, Variable const& rhs) { FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); auto result = lhs.tensor() && rhs.tensor(); return Variable(result, false); } -Variable operator!(const Variable& input) { - auto result = (!input.tensor()).astype(input.type()); +Variable operator!(Variable const& input) { + auto result = (!input.tensor()).asType(input.type()); return Variable(result, false); } -Variable max(const Variable& lhs, const Variable& rhs) { +Variable max(Variable const& lhs, Variable const& rhs) { FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); auto result = 
fl::maximum(lhs.tensor(), rhs.tensor()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - auto mask = Variable( - (inputs[0].tensor() > inputs[1].tensor()).astype(gradOutput.type()), - false - ); - inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false)); - inputs[1].addGrad(Variable((!mask * gradOutput).tensor(), false)); - }; + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + auto mask = Variable( + (inputs[0].tensor() > inputs[1].tensor()).asType(gradOutput.type()), + false + ); + inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false)); + inputs[1].addGrad(Variable((!mask * gradOutput).tensor(), false)); + }; return Variable(result, {lhs, rhs}, gradFunc); } -Variable max(const Variable& lhs, const double& rhsVal) { - auto result = fl::maximum(lhs.tensor(), rhsVal).astype(lhs.type()); +Variable max(Variable const& lhs, double const& rhsVal) { + auto result = fl::maximum(lhs.tensor(), rhsVal).asType(lhs.type()); auto gradFunc = - [rhsVal](std::vector& inputs, const Variable& gradOutput) { - auto mask = Variable( - (inputs[0].tensor() > rhsVal).astype(gradOutput.type()), - false - ); - inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false)); - }; + [rhsVal](std::vector& inputs, Variable const& gradOutput) { + auto mask = Variable( + (inputs[0].tensor() > rhsVal).asType(gradOutput.type()), + false + ); + inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false)); + }; return Variable(result, {lhs}, gradFunc); } -Variable max(const double& lhsVal, const Variable& rhs) { return max(rhs, lhsVal); } +Variable max(double const& lhsVal, Variable const& rhs) { return max(rhs, lhsVal); } -Variable min(const Variable& lhs, const Variable& rhs) { +Variable min(Variable const& lhs, Variable const& rhs) { FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); auto result = fl::minimum(lhs.tensor(), rhs.tensor()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - auto mask = 
Variable( - (inputs[0].tensor() < inputs[1].tensor()).astype(gradOutput.type()), - false - ); - inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false)); - inputs[1].addGrad(Variable((!mask * gradOutput).tensor(), false)); - }; + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + auto mask = Variable( + (inputs[0].tensor() < inputs[1].tensor()).asType(gradOutput.type()), + false + ); + inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false)); + inputs[1].addGrad(Variable((!mask * gradOutput).tensor(), false)); + }; return Variable(result, {lhs, rhs}, gradFunc); } -Variable min(const Variable& lhs, const double& rhsVal) { - auto result = fl::minimum(lhs.tensor(), rhsVal).astype(lhs.type()); +Variable min(Variable const& lhs, double const& rhsVal) { + auto result = fl::minimum(lhs.tensor(), rhsVal).asType(lhs.type()); auto gradFunc = - [rhsVal](std::vector& inputs, const Variable& gradOutput) { - auto mask = Variable( - (inputs[0].tensor() < rhsVal).astype(gradOutput.type()), - false - ); - inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false)); - }; + [rhsVal](std::vector& inputs, Variable const& gradOutput) { + auto mask = Variable( + (inputs[0].tensor() < rhsVal).asType(gradOutput.type()), + false + ); + inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false)); + }; return Variable(result, {lhs}, gradFunc); } -Variable min(const double& lhsVal, const Variable& rhs) { return min(rhs, lhsVal); } +Variable min(double const& lhsVal, Variable const& rhs) { return min(rhs, lhsVal); } -Variable negate(const Variable& input) { - auto result = (0.0 - input.tensor()).astype(input.type()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad(Variable(negate(gradOutput).tensor(), false)); - }; +Variable negate(Variable const& input) { + auto result = (0.0 - input.tensor()).asType(input.type()); + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { 
+ inputs[0].addGrad(Variable(negate(gradOutput).tensor(), false)); + }; return Variable(result, {input.withoutData()}, gradFunc); } -Variable reciprocal(const Variable& input) { +Variable reciprocal(Variable const& input) { auto result = 1.0 / FL_ADJUST_INPUT_TYPE(input.tensor()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - auto res = reciprocal(inputs[0]); - inputs[0].addGrad( - Variable((negate(gradOutput) * res * res).tensor(), false) - ); - }; + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + auto res = reciprocal(inputs[0]); + inputs[0].addGrad( + Variable((negate(gradOutput) * res * res).tensor(), false) + ); + }; return Variable(result, {input}, gradFunc); } -Variable exp(const Variable& input) { +Variable exp(Variable const& input) { auto result = fl::exp(FL_ADJUST_INPUT_TYPE(input.tensor())); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad( - Variable(gradOutput.tensor() * fl::exp(inputs[0].tensor()), false) - ); - }; + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + inputs[0].addGrad( + Variable(gradOutput.tensor() * fl::exp(inputs[0].tensor()), false) + ); + }; return Variable(result, {input}, gradFunc); } -Variable log(const Variable& input) { +Variable log(Variable const& input) { auto result = fl::log(FL_ADJUST_INPUT_TYPE(input.tensor())); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad( - Variable((gradOutput.tensor() / inputs[0].tensor()), false) - ); - }; + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + inputs[0].addGrad( + Variable((gradOutput.tensor() / inputs[0].tensor()), false) + ); + }; return Variable(result, {input}, gradFunc); } -Variable log1p(const Variable& input) { +Variable log1p(Variable const& input) { auto result = fl::log1p(FL_ADJUST_INPUT_TYPE(input.tensor())); - auto gradFunc = [](std::vector& inputs, - const 
Variable& gradOutput) { - inputs[0].addGrad( - Variable((gradOutput.tensor() / (1.0 + inputs[0].tensor())), false) - ); - }; + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + inputs[0].addGrad( + Variable((gradOutput.tensor() / (1.0 + inputs[0].tensor())), false) + ); + }; return Variable(result, {input}, gradFunc); } -Variable pow(const Variable& input, double p) { +Variable pow(Variable const& input, double p) { auto result = fl::power(FL_ADJUST_INPUT_TYPE(input.tensor()), p); - auto gradFunc = [p](std::vector& inputs, - const Variable& gradOutput) { - Tensor grad = - p * fl::power(inputs[0].tensor(), p - 1) * gradOutput.tensor(); - inputs[0].addGrad(Variable(grad, false)); - }; + auto gradFunc = [p]( + std::vector& inputs, + Variable const& gradOutput + ) { + Tensor grad = + p * fl::power(inputs[0].tensor(), p - 1) * gradOutput.tensor(); + inputs[0].addGrad(Variable(grad, false)); + }; return Variable(result, {input}, gradFunc); } -Variable sin(const Variable& input) { +Variable sin(Variable const& input) { auto result = fl::sin(input.tensor()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad( - Variable((gradOutput.tensor() * cos(inputs[0].tensor())), false) - ); - }; + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + inputs[0].addGrad( + Variable((gradOutput.tensor() * cos(inputs[0].tensor())), false) + ); + }; return Variable(result, {input}, gradFunc); } -Variable cos(const Variable& input) { +Variable cos(Variable const& input) { auto result = fl::cos(input.tensor()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad( - Variable( - (gradOutput.tensor() * negative(sin(inputs[0].tensor()))), - false - ) - ); - }; + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + inputs[0].addGrad( + Variable( + (gradOutput.tensor() * negative(sin(inputs[0].tensor()))), + false + ) + ); + }; 
return Variable(result, {input}, gradFunc); } -Variable tanh(const Variable& input) { +Variable tanh(Variable const& input) { auto result = fl::tanh(input.tensor()); auto gradFunc = - [result](std::vector& inputs, const Variable& gradOutput) { - auto grad = - Variable((1.0 - result * result) * gradOutput.tensor(), false); - inputs[0].addGrad(Variable(grad.tensor(), false)); - }; + [result](std::vector& inputs, Variable const& gradOutput) { + auto grad = + Variable((1.0 - result * result) * gradOutput.tensor(), false); + inputs[0].addGrad(Variable(grad.tensor(), false)); + }; return Variable(result, {input.withoutData()}, gradFunc); } -Variable clamp(const Variable& input, const double lo, const double hi) { +Variable clamp(Variable const& input, double const lo, double const hi) { auto result = fl::clip(input.tensor(), lo, hi); auto gradFunc = [lo, hi, result]( std::vector& inputs, - const Variable& gradOutput) { - Tensor gradMask = gradOutput.tensor(); - gradMask = fl::where((result > lo) && (result < hi), gradMask, 0); - inputs[0].addGrad(Variable(gradMask, false)); - }; + Variable const& gradOutput + ) { + Tensor gradMask = gradOutput.tensor(); + gradMask = fl::where((result > lo) && (result < hi), gradMask, 0); + inputs[0].addGrad(Variable(gradMask, false)); + }; return Variable(result, {input.withoutData()}, gradFunc); } -Variable sqrt(const Variable& input) { +Variable sqrt(Variable const& input) { auto result = fl::sqrt(input.tensor()); auto gradFunc = [result]( std::vector& inputs, - const Variable& gradOutput) { - auto output = Variable(result, false); - inputs[0].addGrad(Variable((gradOutput / (2 * output)).tensor(), false)); - }; + Variable const& gradOutput + ) { + auto output = Variable(result, false); + inputs[0].addGrad(Variable((gradOutput / (2 * output)).tensor(), false)); + }; return Variable(result, {input.withoutData()}, gradFunc); } -Variable sigmoid(const Variable& input) { +Variable sigmoid(Variable const& input) { auto result = 
fl::sigmoid(input.tensor()); auto gradFunc = - [result](std::vector& inputs, const Variable& gradOutput) { - auto grad = gradOutput.tensor() * result * (1 - result); - inputs[0].addGrad(Variable(grad, false)); - }; + [result](std::vector& inputs, Variable const& gradOutput) { + auto grad = gradOutput.tensor() * result * (1 - result); + inputs[0].addGrad(Variable(grad, false)); + }; return Variable(result, {input.withoutData()}, gradFunc); } -Variable swish(const Variable& input, double beta) { return input * sigmoid(beta * input); } +Variable swish(Variable const& input, double beta) { return input * sigmoid(beta * input); } -Variable erf(const Variable& input) { +Variable erf(Variable const& input) { auto result = fl::erf(FL_ADJUST_INPUT_TYPE(input.tensor())); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - auto x = inputs[0].tensor(); - auto grad = gradOutput.tensor() * 2 / std::sqrt(M_PI) * fl::exp(-(x * x)); - inputs[0].addGrad(Variable(grad, false)); - }; + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + auto x = inputs[0].tensor(); + auto grad = gradOutput.tensor() * 2 / std::sqrt(M_PI) * fl::exp(-(x * x)); + inputs[0].addGrad(Variable(grad, false)); + }; return Variable(result, {input}, gradFunc); } -Variable transpose(const Variable& input, const Shape& dims /* = {} */) { +Variable transpose(Variable const& input, Shape const& dims /* = {} */) { auto result = fl::transpose(input.tensor(), dims); auto gradFunc = [inputDims = input.shape(), ndim = input.ndim(), dims]( std::vector& inputs, - const Variable& gradOutput) { - Shape reverseShape = dims; - - if(dims.ndim()) { - // Reverse vec if transposing all dims (empty arg) - auto dVec = dims.get(); - std::reverse(dVec.begin(), dVec.end()); - reverseShape = Shape(dVec); - } + Variable const& gradOutput + ) { + Shape reverseShape = dims; - for(unsigned i = 0; i < reverseShape.ndim(); ++i) - reverseShape[dims[i]] = i; + if(dims.ndim()) { + // Reverse vec 
if transposing all dims (empty arg) + auto dVec = dims.get(); + std::reverse(dVec.begin(), dVec.end()); + reverseShape = Shape(dVec); + } - inputs[0].addGrad( - Variable(fl::transpose(gradOutput.tensor(), reverseShape), false) - ); - }; + for(unsigned i = 0; i < reverseShape.ndim(); ++i) + reverseShape[dims[i]] = i; + + inputs[0].addGrad( + Variable(fl::transpose(gradOutput.tensor(), reverseShape), false) + ); + }; return Variable(result, {input.withoutData()}, gradFunc); } -Variable tileAs(const Variable& input, const Shape& rdims) { +Variable tileAs(Variable const& input, Shape const& rdims) { auto result = detail::tileAs(input.tensor(), rdims); Shape inDims = input.shape(); auto gradFunc = [inDims]( std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad( - Variable( - sumAs(gradOutput, inDims).tensor().astype(inputs[0].type()), - false - ) - ); - }; + Variable const& gradOutput + ) { + inputs[0].addGrad( + Variable( + sumAs(gradOutput, inDims).tensor().asType(inputs[0].type()), + false + ) + ); + }; return Variable(result, {input.withoutData()}, gradFunc); } -Variable tileAs(const Variable& input, const Variable& reference) { return tileAs(input, reference.shape()); } +Variable tileAs(Variable const& input, Variable const& reference) { return tileAs(input, reference.shape()); } -Variable sumAs(const Variable& input, const Shape& rdims) { +Variable sumAs(Variable const& input, Shape const& rdims) { auto result = detail::sumAs(FL_ADJUST_INPUT_TYPE(input.tensor()), rdims); auto idims = input.tensor().shape(); auto gradFunc = - [idims](std::vector& inputs, const Variable& gradOutput) { - inputs[0].addGrad(Variable(tileAs(gradOutput, idims).tensor(), false)); - }; + [idims](std::vector& inputs, Variable const& gradOutput) { + inputs[0].addGrad(Variable(tileAs(gradOutput, idims).tensor(), false)); + }; return Variable(result, {input.withoutData()}, gradFunc); } -Variable sumAs(const Variable& input, const Variable& reference) { return sumAs(input, 
reference.shape()); } +Variable sumAs(Variable const& input, Variable const& reference) { return sumAs(input, reference.shape()); } -Variable concatenate(const std::vector& concatInputs, int dim) { +Variable concatenate(std::vector const& concatInputs, int dim) { if(concatInputs.empty()) throw std::invalid_argument("cannot concatenate zero variables"); @@ -620,7 +660,7 @@ Variable concatenate(const std::vector& concatInputs, int dim) { Tensor result(dims, concatInputs[0].type()); std::vector slice(numDims, fl::span); int start = 0; - for(const auto& input : concatInputs) { + for(auto const& input : concatInputs) { slice[dim] = fl::range({start, start + input.dim(dim)}); result(slice) = input.tensor(); start += input.dim(dim); @@ -629,38 +669,39 @@ Variable concatenate(const std::vector& concatInputs, int dim) { std::vector inputsNoData; std::vector inDims; - for(const auto& in : concatInputs) { + for(auto const& in : concatInputs) { inputsNoData.push_back(in.withoutData()); inDims.push_back(in.shape()); } auto gradFunc = [dim, inDims, numDims]( std::vector& inputs, - const Variable& gradOutput) { - std::vector sx(numDims, fl::span); - int s = 0; - for(size_t i = 0; i < inputs.size(); ++i) { - sx[dim] = fl::range(s, s + inDims[i][dim]); - inputs[i].addGrad(Variable(gradOutput.tensor()(sx), false)); - s += inDims[i][dim]; - } - }; + Variable const& gradOutput + ) { + std::vector sx(numDims, fl::span); + int s = 0; + for(size_t i = 0; i < inputs.size(); ++i) { + sx[dim] = fl::range(s, s + inDims[i][dim]); + inputs[i].addGrad(Variable(gradOutput.tensor()(sx), false)); + s += inDims[i][dim]; + } + }; return Variable(result, inputsNoData, gradFunc); } -std::vector split(const Variable& input, long splitSize, int dim) { +std::vector split(Variable const& input, int64_t splitSize, int dim) { if(splitSize <= 0) throw std::invalid_argument("split size must be a positive integer"); auto dimSize = input.dim(dim); - std::vector splitSizes(dimSize / splitSize, splitSize); + 
std::vector splitSizes(dimSize / splitSize, splitSize); if(dimSize % splitSize > 0) splitSizes.push_back(dimSize % splitSize); return split(input, splitSizes, dim); } -std::vector split(const Variable& input, const std::vector& splitSizes, int dim) { +std::vector split(Variable const& input, std::vector const& splitSizes, int dim) { if(dim >= input.ndim()) throw std::invalid_argument( "split: passed dim is larger than the number of dimensions " @@ -685,24 +726,24 @@ std::vector split(const Variable& input, const std::vector& spli return outputs; } -Variable tile(const Variable& input, const Shape& dims) { +Variable tile(Variable const& input, Shape const& dims) { Tensor result = fl::tile(input.tensor(), dims); Shape idims = input.shape(); auto gradFunc = - [idims](std::vector& inputs, const Variable& gradOutput) { - inputs[0].addGrad( - Variable( - sumAs(gradOutput, idims).tensor().astype(inputs[0].type()), - false - ) - ); - }; + [idims](std::vector& inputs, Variable const& gradOutput) { + inputs[0].addGrad( + Variable( + sumAs(gradOutput, idims).tensor().asType(inputs[0].type()), + false + ) + ); + }; return Variable(result, {input.withoutData()}, gradFunc); } Variable sum( - const Variable& input, - const std::vector& axes, + Variable const& input, + std::vector const& axes, bool keepDims /* = false*/ ) { auto result = FL_ADJUST_INPUT_TYPE(input.tensor()); @@ -711,23 +752,24 @@ Variable sum( Shape indims = input.shape(); auto gradFunc = [indims, axes, keepDims]( std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad( - Variable( - detail::tileAs( - detail::expandFromReduction(gradOutput.tensor(), axes, keepDims), - indims - ), - false - ) - ); - }; - return Variable(result.astype(input.type()), {input.withoutData()}, gradFunc); + Variable const& gradOutput + ) { + inputs[0].addGrad( + Variable( + detail::tileAs( + detail::expandFromReduction(gradOutput.tensor(), axes, keepDims), + indims + ), + false + ) + ); + }; + return 
Variable(result.asType(input.type()), {input.withoutData()}, gradFunc); } Variable mean( - const Variable& input, - const std::vector& axes, + Variable const& input, + std::vector const& axes, bool keepDims /* = false*/ ) { auto result = FL_ADJUST_INPUT_TYPE(input.tensor()); @@ -736,38 +778,39 @@ Variable mean( Shape idims = input.shape(); auto gradFunc = [idims, axes, keepDims]( std::vector& inputs, - const Variable& gradOutput) { - Shape odims = gradOutput.shape(); - Dim count = 1; - for(int i = 0; i < idims.ndim(); i++) { - Dim odimSize = i + 1 > odims.ndim() ? 1 : odims[i]; - count *= idims[i] / odimSize; - } - auto grad = + Variable const& gradOutput + ) { + Shape odims = gradOutput.shape(); + Dim count = 1; + for(int i = 0; i < idims.ndim(); i++) { + Dim odimSize = i + 1 > odims.ndim() ? 1 : odims[i]; + count *= idims[i] / odimSize; + } + auto grad = + detail::tileAs( + detail::expandFromReduction(gradOutput.tensor(), axes, keepDims), + idims + ) + / count; + inputs[0].addGrad( + Variable( detail::tileAs( detail::expandFromReduction(gradOutput.tensor(), axes, keepDims), idims ) - / count; - inputs[0].addGrad( - Variable( - detail::tileAs( - detail::expandFromReduction(gradOutput.tensor(), axes, keepDims), - idims - ) - / count, - false - ) - ); - }; + / count, + false + ) + ); + }; return Variable(result, {input.withoutData()}, gradFunc); } Variable var( - const Variable& in, - const std::vector& axes, - const bool isbiased /* = false */, + Variable const& in, + std::vector const& axes, + bool const isbiased /* = false */, bool keepDims /* = false*/ ) { Tensor input = FL_ADJUST_INPUT_TYPE(in.tensor()); @@ -785,30 +828,30 @@ Variable var( result = val * (result - n * avg * avg); auto gradFunc = - [val, axes](std::vector& inputs, const Variable& gradOutput) { - Shape expandedDims = inputs[0].shape(); - Shape tileDims = inputs[0].shape(); - for(auto ax : axes) { - tileDims[ax] = inputs[0].dim(ax); - expandedDims[ax] = 1; - } + [val, axes](std::vector& inputs, 
Variable const& gradOutput) { + Shape expandedDims = inputs[0].shape(); + Shape tileDims = inputs[0].shape(); + for(auto ax : axes) { + tileDims[ax] = inputs[0].dim(ax); + expandedDims[ax] = 1; + } - inputs[0].addGrad( - Variable( - ((2 * val * tileAs(moddims(gradOutput, expandedDims), tileDims)) - * (inputs[0] - - tileAs(moddims(mean(inputs[0], axes), expandedDims), tileDims))) - .tensor(), - false - ) - ); - }; + inputs[0].addGrad( + Variable( + ((2 * val * tileAs(moddims(gradOutput, expandedDims), tileDims)) + * (inputs[0] + - tileAs(moddims(mean(inputs[0], axes), expandedDims), tileDims))) + .tensor(), + false + ) + ); + }; return Variable(result, {in}, gradFunc); } Variable norm( - const Variable& input, - const std::vector& axes, + Variable const& input, + std::vector const& axes, double p /* = 2 */, bool keepDims /* = false */ ) { @@ -823,25 +866,26 @@ Variable norm( auto gradFunc = [sumap, p, axes, keepDims]( std::vector& inputs, - const Variable& gradOutput) { - // correct, but less precise: auto gvar = Variable(fl::power(result, p - 1), - // false); - auto gvar = Variable(fl::power(sumap, 1 - 1 / p), false); - auto normGrad = - (inputs[0].tensor() * fl::pow(fl::abs(inputs[0]), p - 2).tensor() - * detail::tileAs( - detail::expandFromReduction(gradOutput.tensor(), axes, keepDims) - / gvar.tensor(), - inputs[0].shape() - )); - inputs[0].addGrad(Variable(normGrad, false)); - }; + Variable const& gradOutput + ) { + // correct, but less precise: auto gvar = Variable(fl::power(result, p - 1), + // false); + auto gvar = Variable(fl::power(sumap, 1 - 1 / p), false); + auto normGrad = + (inputs[0].tensor() * fl::pow(fl::abs(inputs[0]), p - 2).tensor() + * detail::tileAs( + detail::expandFromReduction(gradOutput.tensor(), axes, keepDims) + / gvar.tensor(), + inputs[0].shape() + )); + inputs[0].addGrad(Variable(normGrad, false)); + }; return Variable(result, {input}, gradFunc); } Variable normalize( - const Variable& in, - const std::vector& axes, + Variable const& 
in, + std::vector const& axes, double p /* = 2 */, double eps /* = 1e-12 */ ) { @@ -851,7 +895,7 @@ Variable normalize( return input / tileAs(invscale, input); } -Variable matmul(const Variable& lhs, const Variable& rhs) { +Variable matmul(Variable const& lhs, Variable const& rhs) { FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); // lhs:Input[0] -- [M, N] // rhs:Input[1] -- [N, K] @@ -859,50 +903,55 @@ Variable matmul(const Variable& lhs, const Variable& rhs) { // -- matmul([M, N], [N, K]) -- [M, K] // result:gradOutput -- [M, K] auto result = fl::matmul(lhs.tensor(), rhs.tensor()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - if(inputs[0].isCalcGrad()) { - Tensor _lhs = gradOutput.tensor(); - if(_lhs.ndim() == 1) - _lhs = fl::reshape(_lhs, {1, _lhs.dim(0)}); - Tensor _rhs = inputs[1].tensor(); - if(_rhs.ndim() == 1) - _rhs = fl::reshape(_rhs, {_rhs.dim(0), 1}); - - // matmulNT(gradOutput, inputs[1]) - // -- matmulNT([M, K], [N, K]) - // -- matmul([M, K], [K, N]) -- [M, K] - auto val = fl::matmul( - _lhs, - _rhs, - /* lhsProp = */ MatrixProperty::None, - /* rhsProp = */ MatrixProperty::Transpose - ); - inputs[0].addGrad(Variable(detail::sumAs(val, inputs[0].shape()), false)); - } - if(inputs[1].isCalcGrad()) { - Tensor _lhs = inputs[0].tensor(); - if(_lhs.ndim() == 1) - _lhs = fl::reshape(_lhs, {1, _lhs.dim(0)}); - Tensor _rhs = gradOutput.tensor(); - if(_rhs.ndim() == 1) - _rhs = fl::reshape(_rhs, {_rhs.dim(0), 1}); - - // matmulTN(inputs[0], gradOutput) - // -- matmulTN([M, N], [M, K]) - // -- matmul([N, M], [M, K]) -- [N, K] - auto val = fl::matmul( - _lhs, - _rhs, - /* lhsProp = */ MatrixProperty::Transpose - ); - inputs[1].addGrad(Variable(detail::sumAs(val, inputs[1].shape()), false)); - } - }; + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + if(inputs[0].isCalcGrad()) { + Tensor _lhs = gradOutput.tensor(); + if(_lhs.ndim() == 1) + _lhs = fl::reshape(_lhs, {1, _lhs.dim(0)}); + Tensor _rhs = 
inputs[1].tensor(); + if(_rhs.ndim() == 1) + _rhs = fl::reshape(_rhs, {_rhs.dim(0), 1}); + + // matmulNT(gradOutput, inputs[1]) + // -- matmulNT([M, K], [N, K]) + // -- matmul([M, K], [K, N]) -- [M, K] + auto val = fl::matmul( + _lhs, + _rhs, + /* lhsProp = */ + MatrixProperty::None, + /* rhsProp = */ + MatrixProperty::Transpose + ); + inputs[0].addGrad(Variable(detail::sumAs(val, inputs[0].shape()), false)); + } + if(inputs[1].isCalcGrad()) { + Tensor _lhs = inputs[0].tensor(); + if(_lhs.ndim() == 1) + _lhs = fl::reshape(_lhs, {1, _lhs.dim(0)}); + Tensor _rhs = gradOutput.tensor(); + if(_rhs.ndim() == 1) + _rhs = fl::reshape(_rhs, {_rhs.dim(0), 1}); + + // matmulTN(inputs[0], gradOutput) + // -- matmulTN([M, N], [M, K]) + // -- matmul([N, M], [M, K]) -- [N, K] + auto val = fl::matmul( + _lhs, + _rhs, + /* lhsProp = */ + MatrixProperty::Transpose + ); + inputs[1].addGrad(Variable(detail::sumAs(val, inputs[1].shape()), false)); + } + }; return Variable(result, {lhs, rhs}, gradFunc); } -Variable matmulTN(const Variable& lhs, const Variable& rhs) { +Variable matmulTN(Variable const& lhs, Variable const& rhs) { FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); // lhs:Input[0] -- [N, M] // rhs:Input[1] -- [N, K] @@ -912,34 +961,39 @@ Variable matmulTN(const Variable& lhs, const Variable& rhs) { // result:gradOutput -- [M, K] auto result = fl::matmul( lhs.tensor(), - rhs.tensor(), /* lhsProp = */ + rhs.tensor(), + /* lhsProp = */ MatrixProperty::Transpose ); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - if(inputs[0].isCalcGrad()) { - // matmulNT(inputs[1], gradOutput) - // -- matmulNT([N, K], [M, K]) - // -- matmul([N, K], [K, M]) -- [N, M] - auto val = fl::matmul( - inputs[1].tensor(), - gradOutput.tensor(), - /* lhsProp = */ MatrixProperty::None, - /* rhsProp = */ MatrixProperty::Transpose - ); - inputs[0].addGrad(Variable(detail::sumAs(val, inputs[0].shape()), false)); - } - if(inputs[1].isCalcGrad()) { - // matmul(inputs[0], gradOutput) - // -- 
matmulNT([N, M], [M, K]) -- [N, K] - auto val = fl::matmul(inputs[0].tensor(), gradOutput.tensor()); - inputs[1].addGrad(Variable(detail::sumAs(val, inputs[1].shape()), false)); - } - }; + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + if(inputs[0].isCalcGrad()) { + // matmulNT(inputs[1], gradOutput) + // -- matmulNT([N, K], [M, K]) + // -- matmul([N, K], [K, M]) -- [N, M] + auto val = fl::matmul( + inputs[1].tensor(), + gradOutput.tensor(), + /* lhsProp = */ + MatrixProperty::None, + /* rhsProp = */ + MatrixProperty::Transpose + ); + inputs[0].addGrad(Variable(detail::sumAs(val, inputs[0].shape()), false)); + } + if(inputs[1].isCalcGrad()) { + // matmul(inputs[0], gradOutput) + // -- matmulNT([N, M], [M, K]) -- [N, K] + auto val = fl::matmul(inputs[0].tensor(), gradOutput.tensor()); + inputs[1].addGrad(Variable(detail::sumAs(val, inputs[1].shape()), false)); + } + }; return Variable(result, {lhs, rhs}, gradFunc); } -Variable matmulNT(const Variable& lhs, const Variable& rhs) { +Variable matmulNT(Variable const& lhs, Variable const& rhs) { FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); // lhs:Input[0] -- [M, N] // rhs:Input[1] -- [K, N] @@ -950,54 +1004,61 @@ Variable matmulNT(const Variable& lhs, const Variable& rhs) { auto result = fl::matmul( lhs.tensor(), rhs.tensor(), - /* lhsProp = */ MatrixProperty::None, - /* rhsProp = */ MatrixProperty::Transpose + /* lhsProp = */ + MatrixProperty::None, + /* rhsProp = */ + MatrixProperty::Transpose ); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - if(inputs[0].isCalcGrad()) { - // matmul(gradOutput, inputs[1]) - // -- matmul([M, K], [K, N]) -- [M, N] - auto val = fl::matmul(gradOutput.tensor(), inputs[1].tensor()); - inputs[0].addGrad(Variable(detail::sumAs(val, inputs[0].shape()), false)); - } - if(inputs[1].isCalcGrad()) { - // matmulTN(gradOutput, inputs[0]) - // -- matmulTN([M, K], [M, N]) - // -- matmul([K, M], [M, N]) -- [K, N] - auto val = fl::matmul( - 
gradOutput.tensor(), - inputs[0].tensor(), - /* lhsProp = */ MatrixProperty::Transpose - ); - inputs[1].addGrad(Variable(detail::sumAs(val, inputs[1].shape()), false)); - } - }; + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + if(inputs[0].isCalcGrad()) { + // matmul(gradOutput, inputs[1]) + // -- matmul([M, K], [K, N]) -- [M, N] + auto val = fl::matmul(gradOutput.tensor(), inputs[1].tensor()); + inputs[0].addGrad(Variable(detail::sumAs(val, inputs[0].shape()), false)); + } + if(inputs[1].isCalcGrad()) { + // matmulTN(gradOutput, inputs[0]) + // -- matmulTN([M, K], [M, N]) + // -- matmul([K, M], [M, N]) -- [K, N] + auto val = fl::matmul( + gradOutput.tensor(), + inputs[0].tensor(), + /* lhsProp = */ + MatrixProperty::Transpose + ); + inputs[1].addGrad(Variable(detail::sumAs(val, inputs[1].shape()), false)); + } + }; return Variable(result, {lhs, rhs}, gradFunc); } -Variable abs(const Variable& input) { +Variable abs(Variable const& input) { auto result = fl::abs(input.tensor()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - // Convert it into -1, 0, 1 - auto sign = fl::sign(inputs[0].tensor()); - inputs[0].addGrad(Variable((sign * gradOutput.tensor()), false)); - }; + auto gradFunc = []( + std::vector& inputs, + Variable const& gradOutput + ) { + // Convert it into -1, 0, 1 + auto sign = fl::sign(inputs[0].tensor()); + inputs[0].addGrad(Variable((sign * gradOutput.tensor()), false)); + }; return Variable(result, {input}, gradFunc); } -Variable flat(const Variable& input) { +Variable flat(Variable const& input) { auto result = input.tensor().flatten(); Shape idims = input.shape(); auto gradFunc = - [idims](std::vector& inputs, const Variable& gradOutput) { - inputs[0].addGrad(Variable(reshape(gradOutput.tensor(), idims), false)); - }; + [idims](std::vector& inputs, Variable const& gradOutput) { + inputs[0].addGrad(Variable(reshape(gradOutput.tensor(), idims), false)); + }; return Variable(result, 
{input.withoutData()}, gradFunc); } -Variable moddims(const Variable& input, const Shape& dims) { +Variable moddims(Variable const& input, Shape const& dims) { if(input.ndim() == 0) return input; Shape inferDims = dims; @@ -1036,13 +1097,14 @@ Variable moddims(const Variable& input, const Shape& dims) { Shape inDims = input.shape(); auto gradFunc = [inDims]( std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad(Variable(moddims(gradOutput, inDims).tensor(), false)); - }; + Variable const& gradOutput + ) { + inputs[0].addGrad(Variable(moddims(gradOutput, inDims).tensor(), false)); + }; return Variable(result, {input.withoutData()}, gradFunc); } -Variable softmax(const Variable& input, const int dim) { +Variable softmax(Variable const& input, int const dim) { Tensor inputArr = FL_ADJUST_INPUT_TYPE(input.tensor()); auto maxvals = amax(inputArr, {dim}, /* keepDims = */ true); Shape tiledims(std::vector(input.ndim(), 1)); @@ -1055,17 +1117,18 @@ Variable softmax(const Variable& input, const int dim) { fl::eval(result); auto gradFunc = [dim, tiledims, result]( std::vector& inputs, - const Variable& gradOutput) { - auto rbyg = gradOutput.tensor() * result; - auto gradSm = rbyg - - result - * fl::tile(fl::sum(rbyg, {dim}, /* keepDims = */ true), tiledims); - inputs[0].addGrad(Variable(gradSm.astype(inputs[0].type()), false)); - }; + Variable const& gradOutput + ) { + auto rbyg = gradOutput.tensor() * result; + auto gradSm = rbyg + - result + * fl::tile(fl::sum(rbyg, {dim}, /* keepDims = */ true), tiledims); + inputs[0].addGrad(Variable(gradSm.asType(inputs[0].type()), false)); + }; return Variable(result, {input.withoutData()}, gradFunc); } -Variable logSoftmax(const Variable& input, const int dim) { +Variable logSoftmax(Variable const& input, int const dim) { Tensor inputArr = FL_ADJUST_INPUT_TYPE(input.tensor()); auto maxvals = amax(inputArr, {dim}, /* keepDims = */ true); // TODO{fl::Tensor}{rewrite} @@ -1077,7 +1140,8 @@ Variable logSoftmax(const 
Variable& input, const int dim) { fl::sum( fl::exp(inputArr - fl::tile(maxvals, tiledims)), {dim}, - /* keepDims = */ true + /* keepDims = */ + true ) ) + maxvals, @@ -1087,28 +1151,29 @@ Variable logSoftmax(const Variable& input, const int dim) { fl::eval(result); auto gradFunc = [dim, tiledims, result]( std::vector& inputs, - const Variable& gradOutput) { - auto gradLsm = gradOutput.tensor() - - fl::exp(result) - * fl::tile( - fl::sum(gradOutput.tensor(), {dim}, /* keepDims = */ true), - tiledims - ); - inputs[0].addGrad(Variable(gradLsm.astype(inputs[0].type()), false)); - }; + Variable const& gradOutput + ) { + auto gradLsm = gradOutput.tensor() + - fl::exp(result) + * fl::tile( + fl::sum(gradOutput.tensor(), {dim}, /* keepDims = */ true), + tiledims + ); + inputs[0].addGrad(Variable(gradLsm.asType(inputs[0].type()), false)); + }; return Variable(result, {input.withoutData()}, gradFunc); } -Variable binaryCrossEntropy(const Variable& inputs, const Variable& targets) { - auto targetsTyped = targets.astype(inputs.type()); +Variable binaryCrossEntropy(Variable const& inputs, Variable const& targets) { + auto targetsTyped = targets.asType(inputs.type()); return negate( targetsTyped * log(inputs) + (1 - targetsTyped) * log(1 - inputs) ); } Variable categoricalCrossEntropy( - const Variable& in, - const Variable& targets, + Variable const& in, + Variable const& targets, ReduceMode reduction /* =ReduceMode::MEAN */, int ignoreIndex /* = -1 */ ) { @@ -1129,7 +1194,7 @@ Variable categoricalCrossEntropy( int C = input.dim(0); int X = targets.elements(); if( - fl::any( + fl::any_of( ((targets.tensor() < 0) || (targets.tensor() >= C)) && (targets.tensor() != ignoreIndex) ) @@ -1155,12 +1220,15 @@ Variable categoricalCrossEntropy( Tensor denominator; if(reduction == ReduceMode::NONE) { result = fl::reshape(result, targets.shape()); // [X1 X2 X3] - } else if(reduction == ReduceMode::MEAN) { - denominator = fl::sum((!ignoreMask).astype(fl::dtype::s32), {0}); + } + else 
if(reduction == ReduceMode::MEAN) { + denominator = fl::sum((!ignoreMask).asType(fl::dtype::s32), {0}); result = fl::sum(result, {0}) / denominator; // [1] - } else if(reduction == ReduceMode::SUM) { + } + else if(reduction == ReduceMode::SUM) { result = fl::sum(result, {0}); // [1] - } else + } + else throw std::invalid_argument( "unknown reduction method for categorical cross entropy" ); @@ -1168,28 +1236,29 @@ Variable categoricalCrossEntropy( auto inputDims = input.shape(); auto gradFunc = [C, X, mask, ignoreMask, denominator, reduction, inputDims]( std::vector& inputs, - const Variable& gradOutput) { - Tensor grad = gradOutput.tensor(); - if(reduction == ReduceMode::NONE) - grad = fl::reshape(grad, {X}); - else if(reduction == ReduceMode::MEAN) - grad = fl::tile(grad / denominator, {X}); - else if(reduction == ReduceMode::SUM) - grad = fl::tile(grad, {X}); - // [1 X] - grad(ignoreMask) = 0.; - grad = fl::reshape(grad, {1, X}); - grad = fl::tile(grad, {C}) * mask; - inputs[0].addGrad(Variable(fl::reshape(grad, inputDims), false)); - }; + Variable const& gradOutput + ) { + Tensor grad = gradOutput.tensor(); + if(reduction == ReduceMode::NONE) + grad = fl::reshape(grad, {X}); + else if(reduction == ReduceMode::MEAN) + grad = fl::tile(grad / denominator, {X}); + else if(reduction == ReduceMode::SUM) + grad = fl::tile(grad, {X}); + // [1 X] + grad(ignoreMask) = 0.; + grad = fl::reshape(grad, {1, X}); + grad = fl::tile(grad, {C}) * mask; + inputs[0].addGrad(Variable(fl::reshape(grad, inputDims), false)); + }; return Variable(result, {input.withoutData(), targets}, gradFunc); } Variable weightedCategoricalCrossEntropy( - const Variable& input, - const Variable& targets, - const Variable& weight, + Variable const& input, + Variable const& targets, + Variable const& weight, int ignoreIndex /* = -1 */ ) { // input -- [C, X1, X2, X3] @@ -1213,7 +1282,7 @@ Variable weightedCategoricalCrossEntropy( int C = input.dim(0); int X = targets.elements(); if( - 
fl::any((targets.tensor() < 0) || (targets.tensor() >= C)) + fl::any_of((targets.tensor() < 0) || (targets.tensor() >= C)) .scalar() ) throw std::invalid_argument( @@ -1234,29 +1303,30 @@ Variable weightedCategoricalCrossEntropy( auto result = mask * x; result = result * weight.tensor(); - auto ignoreMask = (y != ignoreIndex).astype(fl::dtype::s32); // [1, X] + auto ignoreMask = (y != ignoreIndex).asType(fl::dtype::s32); // [1, X] result = ignoreMask * fl::sum(result, {0}, /* keepDims = */ true); // [1, X] result = fl::sum(result, {1}, /* keepDims = */ true) / denominator.tensor(); auto inputDims = input.shape(); auto gradFunc = [C, X, mask, ignoreMask, denominator, inputDims]( std::vector& inputs, - const Variable& gradOutput) { - auto grad = gradOutput.tensor(); - grad = fl::tile(grad / denominator.tensor(), {1, X}); - - auto weightTensor = inputs[2].tensor(); - grad *= ignoreMask; - grad = fl::tile(grad, {C}) * mask; - grad = fl::reshape(grad, inputDims); - grad = grad * weightTensor; - inputs[0].addGrad(Variable(fl::reshape(grad, inputDims), false)); - }; + Variable const& gradOutput + ) { + auto grad = gradOutput.tensor(); + grad = fl::tile(grad / denominator.tensor(), {1, X}); + + auto weightTensor = inputs[2].tensor(); + grad *= ignoreMask; + grad = fl::tile(grad, {C}) * mask; + grad = fl::reshape(grad, inputDims); + grad = grad * weightTensor; + inputs[0].addGrad(Variable(fl::reshape(grad, inputDims), false)); + }; return Variable(result, {input.withoutData(), targets, weight}, gradFunc); } -Variable reorder(const Variable& input, const Shape& shape) { +Variable reorder(Variable const& input, Shape const& shape) { auto result = fl::transpose(input.tensor(), shape); if(!result.isContiguous()) result = result.asContiguousTensor(); @@ -1268,24 +1338,24 @@ Variable reorder(const Variable& input, const Shape& shape) { std::sort(dimGrad.begin(), dimGrad.end()); auto gradFunc = - [dimGrad](std::vector& inputs, const Variable& gradOutput) { - Shape 
reordered(std::vector(dimGrad.size())); - for(unsigned i = 0; i < dimGrad.size(); ++i) - reordered[i] = dimGrad[i].second; + [dimGrad](std::vector& inputs, Variable const& gradOutput) { + Shape reordered(std::vector(dimGrad.size())); + for(unsigned i = 0; i < dimGrad.size(); ++i) + reordered[i] = dimGrad[i].second; - inputs[0].addGrad( - Variable(fl::transpose(gradOutput.tensor(), reordered), false) - ); - }; + inputs[0].addGrad( + Variable(fl::transpose(gradOutput.tensor(), reordered), false) + ); + }; return Variable(result, {input.withoutData()}, gradFunc); } -Variable linear(const Variable& input, const Variable& weight) { - auto dummyBias = Variable(Tensor().astype(input.type()), false); +Variable linear(Variable const& input, Variable const& weight) { + auto dummyBias = Variable(Tensor().asType(input.type()), false); return linear(input, weight, dummyBias); } -Variable linear(const Variable& in, const Variable& wt, const Variable& bs) { +Variable linear(Variable const& in, Variable const& wt, Variable const& bs) { FL_VARIABLE_DTYPES_MATCH_CHECK(in, wt, bs); auto input = FL_ADJUST_INPUT_TYPE(in); auto weight = FL_ADJUST_INPUT_TYPE(wt); @@ -1307,42 +1377,43 @@ Variable linear(const Variable& in, const Variable& wt, const Variable& bs) { auto gradFunc = [hasBias]( std::vector& inputs, - const Variable& gradOutput) { - auto& in = inputs[0]; - auto& wt = inputs[1]; - Tensor wtTensor = wt.tensor(); - Tensor gradOutputTensor = gradOutput.tensor(); - - auto nframes = in.elements() / in.dim(0); - - if(hasBias && inputs[2].isCalcGrad()) { - auto& bs = inputs[2]; - auto biasGrad = sumAs(gradOutput, bs).tensor(); - bs.addGrad(Variable(biasGrad, false)); - } - if(in.isCalcGrad()) { - Shape to2dout({wtTensor.dim(0), nframes}); - auto inGrad = - moddims(matmulTN(wt, moddims(gradOutput, to2dout)), in.shape()) - .tensor(); - in.addGrad(Variable(inGrad, false)); - } - if(wt.isCalcGrad()) { - Shape to2din({wtTensor.dim(1), nframes}); - Shape to2dout({wtTensor.dim(0), nframes}); 
- auto wtGrad = - matmulNT(moddims(gradOutput, to2dout), moddims(in, to2din)).tensor(); - wt.addGrad(Variable(wtGrad, false)); - } - }; + Variable const& gradOutput + ) { + auto& in = inputs[0]; + auto& wt = inputs[1]; + Tensor wtTensor = wt.tensor(); + Tensor gradOutputTensor = gradOutput.tensor(); + + auto nframes = in.elements() / in.dim(0); + + if(hasBias && inputs[2].isCalcGrad()) { + auto& bs = inputs[2]; + auto biasGrad = sumAs(gradOutput, bs).tensor(); + bs.addGrad(Variable(biasGrad, false)); + } + if(in.isCalcGrad()) { + Shape to2dout({wtTensor.dim(0), nframes}); + auto inGrad = + moddims(matmulTN(wt, moddims(gradOutput, to2dout)), in.shape()) + .tensor(); + in.addGrad(Variable(inGrad, false)); + } + if(wt.isCalcGrad()) { + Shape to2din({wtTensor.dim(1), nframes}); + Shape to2dout({wtTensor.dim(0), nframes}); + auto wtGrad = + matmulNT(moddims(gradOutput, to2dout), moddims(in, to2din)).tensor(); + wt.addGrad(Variable(wtGrad, false)); + } + }; if(hasBias) return Variable(output, {input, weight, bias}, gradFunc); return Variable(output, {input, weight}, gradFunc); } Variable conv2d( - const Variable& input, - const Variable& weights, + Variable const& input, + Variable const& weights, int sx, int sy, int px, @@ -1369,9 +1440,9 @@ Variable conv2d( } Variable conv2d( - const Variable& in, - const Variable& wt, - const Variable& bs, + Variable const& in, + Variable const& wt, + Variable const& bs, int sx, int sy, int px, @@ -1407,103 +1478,105 @@ Variable conv2d( auto gradFunc = [sx, sy, px, py, dx, dy, hasBias, groups, benchmarks, payload]( - std::vector& inputs, const Variable& gradOutput) { - // Create benchmarks if needed - auto& autogradExtension = - inputs[0].tensor().backend().getExtension(); - - std::shared_ptr dataBench; - std::shared_ptr filterBench; - std::shared_ptr biasBench; - if(benchmarks && DynamicBenchmark::getBenchmarkMode()) { - if(!benchmarks->bwdFilterBenchmark) { - benchmarks->bwdFilterBenchmark = - 
autogradExtension.createBenchmarkOptions(); - filterBench = benchmarks->bwdFilterBenchmark; - } - if(!benchmarks->bwdDataBenchmark) { - benchmarks->bwdDataBenchmark = - autogradExtension.createBenchmarkOptions(); - dataBench = benchmarks->bwdDataBenchmark; - } - if(!benchmarks->bwdBiasBenchmark) { - benchmarks->bwdBiasBenchmark = - autogradExtension.createBenchmarkOptions(); - biasBench = benchmarks->bwdBiasBenchmark; - } + std::vector& inputs, + Variable const& gradOutput + ) { + // Create benchmarks if needed + auto& autogradExtension = + inputs[0].tensor().backend().getExtension(); + + std::shared_ptr dataBench; + std::shared_ptr filterBench; + std::shared_ptr biasBench; + if(benchmarks && DynamicBenchmark::getBenchmarkMode()) { + if(!benchmarks->bwdFilterBenchmark) { + benchmarks->bwdFilterBenchmark = + autogradExtension.createBenchmarkOptions(); + filterBench = benchmarks->bwdFilterBenchmark; } - - // Bias gradients - Tensor bs; - const bool computeBiasGrad = - inputs.size() > 2 && inputs[2].isCalcGrad(); - if(hasBias && computeBiasGrad) { - bs = inputs[2].tensor(); - // auto biasGrad = - // bs.backend().getExtension().conv2dBackwardBias( - // gradOutput.tensor(), bs, biasBench, payload); - - // inputs[2].addGrad(Variable(biasGrad, false)); // bias + if(!benchmarks->bwdDataBenchmark) { + benchmarks->bwdDataBenchmark = + autogradExtension.createBenchmarkOptions(); + dataBench = benchmarks->bwdDataBenchmark; } - - auto& in = inputs[0].tensor(); - auto& wt = inputs[1].tensor(); - - // Data (input) gradients - if(inputs[0].isCalcGrad()) { - auto dataGrad = - in.backend().getExtension().conv2dBackwardData( - gradOutput.tensor(), - in, - wt, - sx, - sy, - px, - py, - dx, - dy, - groups, - dataBench, - payload - ); - - inputs[0].addGrad(Variable(dataGrad, false)); // input/data + if(!benchmarks->bwdBiasBenchmark) { + benchmarks->bwdBiasBenchmark = + autogradExtension.createBenchmarkOptions(); + biasBench = benchmarks->bwdBiasBenchmark; } + } + + // Bias gradients + 
Tensor bs; + bool const computeBiasGrad = + inputs.size() > 2 && inputs[2].isCalcGrad(); + if(hasBias && computeBiasGrad) { + bs = inputs[2].tensor(); + // auto biasGrad = + // bs.backend().getExtension().conv2dBackwardBias( + // gradOutput.tensor(), bs, biasBench, payload); + + // inputs[2].addGrad(Variable(biasGrad, false)); // bias + } + + auto& in = inputs[0].tensor(); + auto& wt = inputs[1].tensor(); + + // Data (input) gradients + if(inputs[0].isCalcGrad()) { + auto dataGrad = + in.backend().getExtension().conv2dBackwardData( + gradOutput.tensor(), + in, + wt, + sx, + sy, + px, + py, + dx, + dy, + groups, + dataBench, + payload + ); - // Filter (weight) and bias gradients - if(inputs[1].isCalcGrad() || computeBiasGrad) { - auto [filterGrad, biasGrad] = wt.backend() - .getExtension() - .conv2dBackwardFilterBias( - gradOutput.tensor(), - in, - wt, - bs, - sx, - sy, - px, - py, - dx, - dy, - groups, - filterBench, - biasBench, - payload - ); - if(inputs[1].isCalcGrad()) { - inputs[1].addGrad(Variable(filterGrad, false)); // filter/weight - } - if(computeBiasGrad) - inputs[2].addGrad(Variable(biasGrad, false)); + inputs[0].addGrad(Variable(dataGrad, false)); // input/data + } + + // Filter (weight) and bias gradients + if(inputs[1].isCalcGrad() || computeBiasGrad) { + auto [filterGrad, biasGrad] = wt.backend() + .getExtension() + .conv2dBackwardFilterBias( + gradOutput.tensor(), + in, + wt, + bs, + sx, + sy, + px, + py, + dx, + dy, + groups, + filterBench, + biasBench, + payload + ); + if(inputs[1].isCalcGrad()) { + inputs[1].addGrad(Variable(filterGrad, false)); // filter/weight } - }; + if(computeBiasGrad) + inputs[2].addGrad(Variable(biasGrad, false)); + } + }; if(hasBias) return Variable(output, {input, weights, bias}, gradFunc); return Variable(output, {input, weights}, gradFunc); } Variable pool2d( - const Variable& input, + Variable const& input, int wx, int wy, int sx, @@ -1518,40 +1591,41 @@ Variable pool2d( auto gradFunc = [wx, wy, sx, sy, px, py, mode, 
output, payload]( std::vector& inputs, - const Variable& gradOutput) { - auto& in = inputs[0]; - if(!in.isCalcGrad()) - return; + Variable const& gradOutput + ) { + auto& in = inputs[0]; + if(!in.isCalcGrad()) + return; - in.addGrad( - Variable( - in.tensor().backend().getExtension().pool2dBackward( - gradOutput.tensor(), - in.tensor(), - output, - wx, - wy, - sx, - sy, - px, - py, - mode, - payload - ), - false - ) - ); - }; + in.addGrad( + Variable( + in.tensor().backend().getExtension().pool2dBackward( + gradOutput.tensor(), + in.tensor(), + output, + wx, + wy, + sx, + sy, + px, + py, + mode, + payload + ), + false + ) + ); + }; return Variable(output, {input}, gradFunc); } Variable batchnorm( - const Variable& _input, - const Variable& weight, - const Variable& bias, + Variable const& _input, + Variable const& weight, + Variable const& bias, Variable& runningMean, Variable& runningVar, - const std::vector& axes, + std::vector const& axes, bool train, double momentum, double epsilon @@ -1581,41 +1655,41 @@ Variable batchnorm( train, axes, epsilon, - payload](std::vector& inputs, const Variable& _gradOutput) { - auto& in = inputs[0]; - auto& wt = inputs[1]; - auto& bs = inputs[2]; - - auto gradOutput = detail::adjustInputType(_gradOutput, "batchnorm"); - - if(!in.isCalcGrad() && !wt.isCalcGrad() && !bs.isCalcGrad()) - return; - - auto [gradIn, gradWt, gradBs] = - in.tensor() - .backend() - .getExtension() - .batchnormBackward( - gradOutput.tensor(), - saveMean, - saveVar, - detail::adjustInputType(in.tensor(), "batchnorm"), - wt.tensor(), - axes, - train, - epsilon, - payload - ); - - in.addGrad(Variable(gradIn.astype(in.type()), false)); - wt.addGrad(Variable(gradWt.astype(wt.type()), false)); - if(!bs.isEmpty()) - bs.addGrad(Variable(gradBs.astype(bs.type()), false)); - }; + payload](std::vector& inputs, Variable const& _gradOutput) { + auto& in = inputs[0]; + auto& wt = inputs[1]; + auto& bs = inputs[2]; + + auto gradOutput = 
detail::adjustInputType(_gradOutput, "batchnorm"); + + if(!in.isCalcGrad() && !wt.isCalcGrad() && !bs.isCalcGrad()) + return; + + auto [gradIn, gradWt, gradBs] = + in.tensor() + .backend() + .getExtension() + .batchnormBackward( + gradOutput.tensor(), + saveMean, + saveVar, + detail::adjustInputType(in.tensor(), "batchnorm"), + wt.tensor(), + axes, + train, + epsilon, + payload + ); + + in.addGrad(Variable(gradIn.asType(in.type()), false)); + wt.addGrad(Variable(gradWt.asType(wt.type()), false)); + if(!bs.isEmpty()) + bs.addGrad(Variable(gradBs.asType(bs.type()), false)); + }; return Variable(output, {input, weight, bias}, gradFunc); } -Variable gatedlinearunit(const Variable& input, const int dim) { +Variable gatedlinearunit(Variable const& input, int const dim) { if(dim >= input.ndim()) throw std::invalid_argument( "gatedlinearunit - passed dim is great than the " @@ -1643,21 +1717,22 @@ Variable gatedlinearunit(const Variable& input, const int dim) { auto gradFunc = [fhalf, shalf, fhalfout, shalfout, inDims, inType]( std::vector& inputs, - const Variable& gradOutput) { - auto gradGlu = Tensor(inDims, inType); - gradGlu(fhalf) = shalfout * gradOutput.tensor(); - gradGlu(shalf) = - shalfout * (1.0 - shalfout) * fhalfout * gradOutput.tensor(); - inputs[0].addGrad(Variable(gradGlu, false)); - }; + Variable const& gradOutput + ) { + auto gradGlu = Tensor(inDims, inType); + gradGlu(fhalf) = shalfout * gradOutput.tensor(); + gradGlu(shalf) = + shalfout * (1.0 - shalfout) * fhalfout * gradOutput.tensor(); + inputs[0].addGrad(Variable(gradGlu, false)); + }; return Variable(fhalfout * shalfout, {input.withoutData()}, gradFunc); } std::tuple rnn( - const Variable& input, - const Variable& hiddenState, - const Variable& cellState, - const Variable& weights, + Variable const& input, + Variable const& hiddenState, + Variable const& cellState, + Variable const& weights, int hiddenSize, int numLayers, RnnMode mode, @@ -1691,63 +1766,65 @@ std::tuple rnn( dropProb, gradData, 
payload]( - std::vector& inputs, - const Variable& /* gradOutput */) { - auto& input = inputs[0]; - auto& hiddenState = inputs[1]; - auto& cellState = inputs[2]; - auto& weights = inputs[3]; - - if( - !(input.isCalcGrad() || hiddenState.isCalcGrad() + std::vector& inputs, + Variable const& /* gradOutput */ + + ) { + auto& input = inputs[0]; + auto& hiddenState = inputs[1]; + auto& cellState = inputs[2]; + auto& weights = inputs[3]; + + if( + !(input.isCalcGrad() || hiddenState.isCalcGrad() || cellState.isCalcGrad() || weights.isCalcGrad()) - ) - return; - - auto [dy, dhy, dcy, dweights] = - input.tensor().backend().getExtension().rnnBackward( - input.tensor(), - hiddenState.tensor(), - cellState.tensor(), - weights.tensor(), - gradData, - output, - numLayers, - hiddenSize, - mode, - bidirectional, - dropProb, - payload - ); + ) + return; + + auto [dy, dhy, dcy, dweights] = + input.tensor().backend().getExtension().rnnBackward( + input.tensor(), + hiddenState.tensor(), + cellState.tensor(), + weights.tensor(), + gradData, + output, + numLayers, + hiddenSize, + mode, + bidirectional, + dropProb, + payload + ); - input.addGrad(Variable(dy.astype(input.type()), false)); - hiddenState.addGrad(Variable(dhy.astype(hiddenState.type()), false)); - cellState.addGrad(Variable(dcy.astype(cellState.type()), false)); - weights.addGrad(Variable(dweights.astype(weights.type()), false)); - }; + input.addGrad(Variable(dy.asType(input.type()), false)); + hiddenState.addGrad(Variable(dhy.asType(hiddenState.type()), false)); + cellState.addGrad(Variable(dcy.asType(cellState.type()), false)); + weights.addGrad(Variable(dweights.asType(weights.type()), false)); + }; Variable dummy(Tensor(), {input, hiddenState, cellState, weights}, gradFunc); auto dyGradFunc = - [gradData](std::vector& inputs, const Variable& gradOutput) { - if(!inputs[0].isGradAvailable()) - inputs[0].addGrad(Variable(Tensor(), false)); - gradData->dy = gradOutput.tensor().asContiguousTensor(); - }; + 
[gradData](std::vector& inputs, Variable const& gradOutput) { + if(!inputs[0].isGradAvailable()) + inputs[0].addGrad(Variable(Tensor(), false)); + gradData->dy = gradOutput.tensor().asContiguousTensor(); + }; auto dhyGradFunc = - [gradData](std::vector& inputs, const Variable& gradOutput) { - if(!inputs[0].isGradAvailable()) - inputs[0].addGrad(Variable(Tensor(), false)); - gradData->dhy = gradOutput.tensor().asContiguousTensor(); - }; + [gradData](std::vector& inputs, Variable const& gradOutput) { + if(!inputs[0].isGradAvailable()) + inputs[0].addGrad(Variable(Tensor(), false)); + gradData->dhy = gradOutput.tensor().asContiguousTensor(); + }; auto dcyGradFunc = - [gradData](std::vector& inputs, const Variable& gradOutput) { - if(!inputs[0].isGradAvailable()) - inputs[0].addGrad(Variable(Tensor(), false)); - gradData->dcy = gradOutput.tensor().asContiguousTensor(); - }; + [gradData](std::vector& inputs, Variable const& gradOutput) { + if(!inputs[0].isGradAvailable()) + inputs[0].addGrad(Variable(Tensor(), false)); + gradData->dcy = gradOutput.tensor().asContiguousTensor(); + }; Variable yv(output, {dummy}, dyGradFunc); // output Variable hyv(hiddenOut, {dummy}, dhyGradFunc); // hidden state output @@ -1755,63 +1832,67 @@ std::tuple rnn( return std::make_tuple(yv, hyv, cyv); } -Variable embedding(const Variable& input, const Variable& embeddings) { +Variable embedding(Variable const& input, Variable const& embeddings) { // TODO{fl::Tensor}{4-dims} - relax this if(input.ndim() >= 4) - throw std::invalid_argument("embedding input must have 3 or fewer dims"); + throw std::invalid_argument{"embedding input must have 3 or fewer dims"}; - auto idxs = input.tensor().flatten(); + auto const idxs = input.tensor().flatten(); auto inDims = input.shape(); std::vector rDims(input.ndim() + 1); rDims[0] = embeddings.dim(0); - for(unsigned i = 1; i < input.ndim() + 1; i++) + for(Dim i = 1; i < input.ndim() + 1; i++) rDims[i] = inDims[i - 1]; - Shape resultDims(rDims); - Tensor 
result = fl::reshape(embeddings.tensor()(fl::span, idxs), resultDims); - - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - auto& w = inputs[1]; - if(!w.isCalcGrad()) - return; - - auto ip = inputs[0].tensor().flatten(); - unsigned size = ip.elements(); - auto deltas = fl::reshape(gradOutput.tensor(), {w.dim(0), size}); - - // Sparse Tensor - auto sp = Tensor( - ip.elements(), - w.dim(1), - fl::full({size}, 1, deltas.type()), - fl::arange({size + 1}, 0, fl::dtype::s32), - ip.astype(fl::dtype::s32), - fl::StorageType::CSR - ); - - auto grad = transpose( - fl::matmul( - sp, - transpose(deltas), /* lhsProp = */ - MatrixProperty::Transpose - ) - ); - w.addGrad(Variable(grad, false)); - }; - - return Variable(result, {input, embeddings}, gradFunc); + + Shape const resultDims{rDims}; + auto const result = fl::reshape(embeddings.tensor()(fl::span, idxs), resultDims); + + auto grad_func = []( + std::vector& inputs, + Variable const& gradOutput + ) { + auto& w = inputs[1]; + if(!w.isCalcGrad()) + return; + + auto const ip = inputs[0].tensor().flatten(); + auto size = static_cast(ip.elements()); + auto const deltas = fl::reshape(gradOutput.tensor(), {w.dim(0), size}); + + // Sparse Tensor + auto const sp = Tensor{ + static_cast(ip.elements()), + w.dim(1), + fl::full({size}, 1, deltas.type()), + fl::arange({size + 1}, 0, fl::dtype::s32), + ip.asType(fl::dtype::s32), + fl::StorageType::CSR + }; + + auto const grad = transpose( + fl::matmul( + sp, + transpose(deltas), + /* lhsProp = */ + MatrixProperty::Transpose + ) + ); + w.addGrad(Variable{grad, false}); + }; + + return Variable{result, {input, embeddings}, grad_func}; } Variable padding( - const Variable& input, + Variable const& input, std::vector> pad, double val ) { if(pad.size() > input.ndim()) - throw std::invalid_argument( + throw std::invalid_argument{ "padding: number of padding dimensions exceeds number " "of input dimensions" - ); + }; Shape opDims = input.shape(); std::vector 
inSeq(input.ndim(), fl::span); @@ -1823,33 +1904,34 @@ Variable padding( result(inSeq) = input.tensor(); auto gradFunc = - [inSeq](std::vector& inputs, const Variable& gradOutput) { - inputs[0].addGrad(Variable(gradOutput.tensor()(inSeq), false)); - }; + [inSeq](std::vector& inputs, Variable const& gradOutput) { + inputs[0].addGrad(Variable(gradOutput.tensor()(inSeq), false)); + }; return Variable(result, {input.withoutData()}, gradFunc); } -Variable dropout(const Variable& input, double p) { +Variable dropout(Variable const& input, double p) { if(p > 0.0) { auto mask = Variable( - (fl::rand(input.shape(), input.type()) > p).astype(input.type()), + (fl::rand(input.shape(), input.type()) > p).asType(input.type()), false ); return 1.0 / (1.0 - p) * mask * input; - } else + } + else return input; } -Variable relu(const Variable& input) { return max(input, 0.0); } +Variable relu(Variable const& input) { return max(input, 0.0); } -Variable gelu(const Variable& in) { +Variable gelu(Variable const& in) { auto input = FL_ADJUST_INPUT_TYPE(in); return 0.5 * input - * (1.0 - + fl::tanh(0.7978845608 * (input + 0.044715 * input * input * input))); + * (1.0 + + fl::tanh(0.7978845608 * (input + 0.044715 * input * input * input))); } -fl::Variable relativePositionEmbeddingRotate(const fl::Variable& input) { +fl::Variable relativePositionEmbeddingRotate(fl::Variable const& input) { if(input.ndim() != 3) throw std::invalid_argument( "relativePositionEmbeddingRotate - " @@ -1870,31 +1952,32 @@ fl::Variable relativePositionEmbeddingRotate(const fl::Variable& input) { data = fl::reshape(data, {d0 + d1 - 1, d1, d2}); auto gradFunc = [d0, d1, d2]( std::vector& inputs, - const fl::Variable& gradOutput) { - auto gradData = gradOutput.tensor(); - gradData = fl::reshape(gradData, {(d0 + d1 - 1) * d1, 1, d2}); - gradData = fl::concatenate( - 0, - gradData, - fl::full({d1, 1, d2}, 0.0, gradData.type()) - ); - gradData = reshape(gradData, {d0 + d1, d1, d2}); - gradData = Variable(gradData, 
false)(fl::range(0, d0)).tensor(); - inputs[0].addGrad(fl::Variable(gradData, false)); - }; + fl::Variable const& gradOutput + ) { + auto gradData = gradOutput.tensor(); + gradData = fl::reshape(gradData, {(d0 + d1 - 1) * d1, 1, d2}); + gradData = fl::concatenate( + 0, + gradData, + fl::full({d1, 1, d2}, 0.0, gradData.type()) + ); + gradData = reshape(gradData, {d0 + d1, d1, d2}); + gradData = Variable(gradData, false)(fl::range(0, d0)).tensor(); + inputs[0].addGrad(fl::Variable(gradData, false)); + }; return fl::Variable(data, {input}, gradFunc); } fl::Variable multiheadAttention( - const fl::Variable& query, - const fl::Variable& key, - const fl::Variable& value, - const fl::Variable& posEmb, - const fl::Variable& mask, - const fl::Variable& padMask, - const int32_t nHeads, - const double pDropout, - const int32_t offset /* = 0 */ + fl::Variable const& query, + fl::Variable const& key, + fl::Variable const& value, + fl::Variable const& posEmb, + fl::Variable const& mask, + fl::Variable const& padMask, + int32_t const nHeads, + double const pDropout, + int32_t const offset /* = 0 */ ) { if(query.ndim() != 3) throw std::invalid_argument( @@ -1925,12 +2008,12 @@ fl::Variable multiheadAttention( if(!posEmb.isEmpty()) { int n = posEmb.dim(0) / 2 - offset; auto pscores = - relativePositionEmbeddingRotate(matmulNT(posEmb.astype(q.type()), q)); + relativePositionEmbeddingRotate(matmulNT(posEmb.asType(q.type()), q)); scores = scores + transpose(pscores(fl::range(n, n + k.dim(0))), {1, 0, 2}); } if(!mask.isEmpty()) - scores = scores + tileAs(mask.astype(scores.type()), scores); + scores = scores + tileAs(mask.asType(scores.type()), scores); if(!padMask.isEmpty()) { if(padMask.dim(0) != query.dim(0)) throw std::invalid_argument( @@ -1941,13 +2024,13 @@ fl::Variable multiheadAttention( tileAs(padMaskTile, {padMask.dim(0), padMask.dim(0), nHeads, bsz}); scores = scores + moddims( - padMaskTile.astype(scores.type()), + padMaskTile.asType(scores.type()), {padMask.dim(0), 
padMask.dim(0), nHeads * bsz} ); } auto attn = dropout(softmax(scores, 1), pDropout); - auto result = matmul(attn.astype(v.type()), v); + auto result = matmul(attn.asType(v.type()), v); result = moddims(result, {-1, headDim * nHeads, bsz}); return result; } diff --git a/flashlight/fl/autograd/Functions.h b/flashlight/fl/autograd/Functions.h index b2d23a0..e406a59 100644 --- a/flashlight/fl/autograd/Functions.h +++ b/flashlight/fl/autograd/Functions.h @@ -71,11 +71,11 @@ namespace detail { && optimLevel != OptimLevel::DEFAULT ) // Not in the excluded list - cast to f16 - res = in.astype(fl::dtype::f16); + res = in.asType(fl::dtype::f16); else { // Upcast to f32 only if we have an f16 input - otherwise, leave as is if(in.type() == fl::dtype::f16) - res = in.astype(fl::dtype::f32); + res = in.asType(fl::dtype::f32); else res = in; } @@ -449,7 +449,7 @@ FL_API Variable concatenate(const std::vector& concatInputs, int dim); * divisible, last chunk of smaller splitSize will be included. * @param dim dimension along which to split the Variable */ -FL_API std::vector split(const Variable& input, long splitSize, int dim); +FL_API std::vector split(const Variable& input, Dim splitSize, int dim); /** * Splits a Variable into smaller chunks. @@ -458,7 +458,7 @@ FL_API std::vector split(const Variable& input, long splitSize, int di * @param splitSizes vector of integers specifying the sizes for each split * @param dim dimension along which to split the Variable */ -FL_API std::vector split(const Variable& input, const std::vector& splitSizes, int dim); +FL_API std::vector split(const Variable& input, std::vector const& splitSizes, int dim); /** * Repeats the tensor `input` along specific dimensions. 
The number of diff --git a/flashlight/fl/autograd/Variable.cpp b/flashlight/fl/autograd/Variable.cpp index bd1fb6a..fae4ec8 100644 --- a/flashlight/fl/autograd/Variable.cpp +++ b/flashlight/fl/autograd/Variable.cpp @@ -38,9 +38,7 @@ Variable::Variable( std::any_of( inputs.begin(), inputs.end(), - [](const Variable& input) { - return input.isCalcGrad(); - } + [](const Variable& input) { return input.isCalcGrad(); } ) ) { sharedGrad_->calcGrad = true; @@ -56,15 +54,16 @@ Variable Variable::operator()(const std::vector& indices) const { auto gradFunc = [indices, inDims, inType]( std::vector& inputs, - const Variable& gradOutput) { - if(!inputs[0].isGradAvailable()) { - auto grad = fl::full(inDims, 0.0, inType); - inputs[0].addGrad(Variable(grad, false)); - } - - auto& grad = inputs[0].grad().tensor(); - grad(indices) += gradOutput.tensor(); - }; + const Variable& gradOutput + ) { + if(!inputs[0].isGradAvailable()) { + auto grad = fl::full(inDims, 0.0, inType); + inputs[0].addGrad(Variable(grad, false)); + } + + auto& grad = inputs[0].grad().tensor(); + grad(indices) += gradOutput.tensor(); + }; return Variable(result, {this->withoutData()}, gradFunc); } @@ -75,34 +74,33 @@ Variable Variable::flat(const fl::Index& index) const { auto gradFunc = [index, inDims, inType]( std::vector& inputs, - const Variable& gradOutput) { - if(!inputs[0].isGradAvailable()) { - auto grad = fl::full(inDims, 0.0, inType); - inputs[0].addGrad(Variable(grad, false)); - } - auto& grad = inputs[0].grad().tensor(); - grad.flat(index) += gradOutput.tensor(); - }; + const Variable& gradOutput + ) { + if(!inputs[0].isGradAvailable()) { + auto grad = fl::full(inDims, 0.0, inType); + inputs[0].addGrad(Variable(grad, false)); + } + auto& grad = inputs[0].grad().tensor(); + grad.flat(index) += gradOutput.tensor(); + }; return Variable(result, {this->withoutData()}, gradFunc); } -Tensor& Variable::tensor() const { - return sharedData_->data; -} +Tensor& Variable::tensor() const { return 
sharedData_->data; } -Variable Variable::copy() const { - return Variable(sharedData_->data, sharedGrad_->calcGrad); -} +Variable Variable::copy() const { return Variable(sharedData_->data, sharedGrad_->calcGrad); } -Variable Variable::astype(fl::dtype newType) const { - auto output = tensor().astype(newType); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - auto& input = inputs[0]; - // Cast the grad output to match the type of the input's grad - input.addGrad(Variable(gradOutput.tensor().astype(input.type()), false)); - }; +Variable Variable::asType(fl::dtype newType) const { + auto output = tensor().asType(newType); + auto gradFunc = []( + std::vector& inputs, + const Variable& gradOutput + ) { + auto& input = inputs[0]; + // Cast the grad output to match the type of the input's grad + input.addGrad(Variable(gradOutput.tensor().asType(input.type()), false)); + }; return Variable(output, {this->withoutData()}, gradFunc); } @@ -116,13 +114,9 @@ Variable& Variable::grad() const { return *sharedGrad_->grad; } -std::vector& Variable::getInputs() const { - return sharedGrad_->inputs; -} +std::vector& Variable::getInputs() const { return sharedGrad_->inputs; } -bool Variable::isCalcGrad() const { - return sharedGrad_->calcGrad; -} +bool Variable::isCalcGrad() const { return sharedGrad_->calcGrad; } bool Variable::isGradAvailable() const { if(!sharedGrad_->calcGrad) @@ -130,17 +124,11 @@ bool Variable::isGradAvailable() const { return sharedGrad_->grad != nullptr; } -Shape Variable::shape() const { - return tensor().shape(); -} +Shape Variable::shape() const { return tensor().shape(); } -bool Variable::isEmpty() const { - return tensor().isEmpty(); -} +bool Variable::isEmpty() const { return tensor().isEmpty(); } -bool Variable::isContiguous() const { - return tensor().isContiguous(); -} +bool Variable::isContiguous() const { return tensor().isContiguous(); } Variable Variable::asContiguous() const { if(!isEmpty() && !isContiguous()) @@ 
-148,29 +136,17 @@ Variable Variable::asContiguous() const { return *this; } -fl::dtype Variable::type() const { - return tensor().type(); -} +fl::dtype Variable::type() const { return tensor().type(); } -Dim Variable::elements() const { - return tensor().elements(); -} +Dim Variable::elements() const { return tensor().elements(); } -size_t Variable::bytes() const { - return tensor().bytes(); -} +size_t Variable::bytes() const { return tensor().bytes(); } -unsigned Variable::ndim() const { - return tensor().ndim(); -} +unsigned Variable::ndim() const { return tensor().ndim(); } -Dim Variable::dim(unsigned dim) const { - return tensor().dim(dim); -} +Dim Variable::dim(unsigned dim) const { return tensor().dim(dim); } -void Variable::eval() const { - fl::eval(tensor()); -} +void Variable::eval() const { fl::eval(tensor()); } void Variable::zeroGrad() { sharedGrad_->grad.reset(); } @@ -190,8 +166,8 @@ void Variable::addGrad(const Variable& childGrad) { if(childGrad.type() != this->type()) { std::stringstream ss; ss << "Variable::addGrad: attempted to add child gradient of type " - << childGrad.type() << " to a Variable of type " << this->type() - << ". You might be performing an operation with " + << childGrad.type() << " to a Variable of type " << this->type() + << ". 
You might be performing an operation with " "two inputs of different types."; throw std::invalid_argument(ss.str()); } @@ -199,8 +175,8 @@ void Variable::addGrad(const Variable& childGrad) { std::stringstream ss; ss << "Variable::addGrad: given gradient has dimensions not equal " "to this Variable's dimensions: this variable has shape " - << this->shape() << " whereas the child gradient has dimensions " - << childGrad.shape() << std::endl; + << this->shape() << " whereas the child gradient has dimensions " + << childGrad.shape() << std::endl; throw std::invalid_argument(ss.str()); } if(sharedGrad_->grad) @@ -273,14 +249,14 @@ Variable::DAG Variable::build() const { // Topological sort recurse = [&](const Variable& var) { - auto id = var.sharedGrad_.get(); - if(cache.find(id) != cache.end()) - return; - for(const auto& input : var.getInputs()) - recurse(input); - cache.insert(id); - dag.push_back(var); - }; + auto id = var.sharedGrad_.get(); + if(cache.find(id) != cache.end()) + return; + for(const auto& input : var.getInputs()) + recurse(input); + cache.insert(id); + dag.push_back(var); + }; recurse(*this); return dag; diff --git a/flashlight/fl/autograd/Variable.h b/flashlight/fl/autograd/Variable.h index 60fb040..2ce498d 100644 --- a/flashlight/fl/autograd/Variable.h +++ b/flashlight/fl/autograd/Variable.h @@ -128,7 +128,12 @@ class FL_API Variable { * * @return returns the casted variable. */ - Variable astype(fl::dtype type) const; + Variable asType(fl::dtype type) const; + + /** + * @deprecated use @ref Variable::asType(fl::dtype) const instead + */ + Variable astype(fl::dtype type) const { return asType(type); } /** * @return a reference to the underlying gradient Variable. @@ -207,25 +212,19 @@ class FL_API Variable { * Must eventually be freed manually via `free` or a related call. 
*/ template - T* host() const { - return tensor().host(); - } + T* host() const { return tensor().host(); } /** * Copies the array to the existing host pointer `ptr` */ template - void host(T* ptr) const { - tensor().host(ptr); - } + void host(T* ptr) const { tensor().host(ptr); } /** * Get the first element of the array as a scalar */ template - T scalar() const { - return tensor().scalar(); - } + T scalar() const { return tensor().scalar(); } /** * Remove the gradient stored by the Variable diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp index 9f6b315..bb4e7ad 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp @@ -1,8 +1,8 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. + * SPDX-License-Identifier: MIT * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. + * Original code: Copyright (c) Meta Platforms, Inc. 
(see FLASHLIGHT_LICENSE) + * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE) */ #include "flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.h" @@ -48,15 +48,15 @@ namespace { if(minAxis == 0) { modeOut = CUDNN_BATCHNORM_PER_ACTIVATION; - inDescDimsOut = Shape( + inDescDimsOut = Shape{ { 1, 1, nfeatures, - static_cast(input.elements() / nfeatures) + static_cast(input.elements() / nfeatures) } - ); - wtDescDimsOut = Shape({1, 1, nfeatures}); + }; + wtDescDimsOut = Shape{1, 1, nfeatures}; } else { modeOut = CUDNN_BATCHNORM_SPATIAL; #if CUDNN_VERSION >= 7003 @@ -67,15 +67,15 @@ namespace { int batchsz = 1; for(int i = maxAxis + 1; i < input.ndim(); ++i) batchsz *= input.dim(i); - inDescDimsOut = Shape( + inDescDimsOut = Shape{ { 1, - static_cast(input.elements() / (nfeatures * batchsz)), + static_cast(input.elements() / (nfeatures * batchsz)), nfeatures, batchsz, } - ); - wtDescDimsOut = Shape({1, 1, nfeatures}); + }; + wtDescDimsOut = Shape{1, 1, nfeatures}; } } @@ -101,7 +101,7 @@ Tensor CudnnAutogradExtension::batchnorm( ); FL_TENSOR_DTYPES_MATCH_CHECK(weight, bias, runningMean, runningVar); - auto output = Tensor(input.shape(), input.type()); + auto output = Tensor{input.shape(), input.type()}; cudnnBatchNormMode_t mode; Shape inDescDims, wtDescDims; @@ -115,15 +115,15 @@ Tensor CudnnAutogradExtension::batchnorm( // Weight, bias, and running mean/var arrays can't be fp16 (must be fp32) Tensor weightArray = weight.isEmpty() ? fl::full(wtDescDims, 1.0, fl::dtype::f32) - : weight.astype(fl::dtype::f32); + : weight.asType(fl::dtype::f32); Tensor biasArray = bias.isEmpty() ? fl::full(wtDescDims, 0.0, fl::dtype::f32) - : bias.astype(fl::dtype::f32); + : bias.asType(fl::dtype::f32); fl::dtype scalarsType = input.type() == fl::dtype::f16 ? 
fl::dtype::f32 : input.type(); - auto inDesc = TensorDescriptor(input.type(), inDescDims); - auto wtDesc = TensorDescriptor(weightArray.type(), wtDescDims); + auto inDesc = TensorDescriptor{input.type(), inDescDims}; + auto wtDesc = TensorDescriptor{weightArray.type(), wtDescDims}; { DevicePtr inRaw(input); @@ -140,8 +140,8 @@ Tensor CudnnAutogradExtension::batchnorm( ); if(train) { - saveMean = Tensor({wtDescDims[2]}, scalarsType); - saveVar = Tensor({wtDescDims[2]}, scalarsType); + saveMean = Tensor{{wtDescDims[2]}, scalarsType}; + saveVar = Tensor{{wtDescDims[2]}, scalarsType}; DevicePtr saveMeanRaw(saveMean); DevicePtr saveVarRaw(saveVar); @@ -153,11 +153,11 @@ Tensor CudnnAutogradExtension::batchnorm( mode, kOne(scalarsType), kZero(scalarsType), - inDesc.descriptor, + inDesc.get(), inRaw.get(), - inDesc.descriptor, + inDesc.get(), outRaw.get(), - wtDesc.descriptor, + wtDesc.get(), wtRaw.get(), bsRaw.get(), momentum, @@ -175,11 +175,11 @@ Tensor CudnnAutogradExtension::batchnorm( mode, kOne(scalarsType), kZero(scalarsType), - inDesc.descriptor, + inDesc.get(), inRaw.get(), - inDesc.descriptor, + inDesc.get(), outRaw.get(), - wtDesc.descriptor, + wtDesc.get(), wtRaw.get(), bsRaw.get(), runMeanRaw.get(), @@ -223,13 +223,13 @@ std::tuple CudnnAutogradExtension::batchnormBackward( const void* one1 = kOne(scalarsType); const void* zero0 = kZero(scalarsType); - auto iDesc = TensorDescriptor(input.type(), inDescDims); - auto wDesc = TensorDescriptor(wt.type(), wtDescDims); + auto iDesc = TensorDescriptor{input.type(), inDescDims}; + auto wDesc = TensorDescriptor{wt.type(), wtDescDims}; // CuDNN doesn't support calculating only the gradients // required for batchnorm - auto gradIn = Tensor(input.shape(), input.type()); - auto gradWt = Tensor(wt.shape(), wt.type()); - auto gradBs = Tensor(wt.shape(), wt.type()); + auto gradIn = Tensor{input.shape(), input.type()}; + auto gradWt = Tensor{wt.shape(), wt.type()}; + auto gradBs = Tensor{wt.shape(), wt.type()}; { DevicePtr 
iRaw(input); DevicePtr wRaw(wt); @@ -257,13 +257,13 @@ std::tuple CudnnAutogradExtension::batchnormBackward( zero0, one1, zero0, - iDesc.descriptor, + iDesc.get(), iRaw.get(), - iDesc.descriptor, + iDesc.get(), gradOpRaw.get(), - iDesc.descriptor, + iDesc.get(), gradInRaw.get(), - wDesc.descriptor, + wDesc.get(), wRaw.get(), gradWtRaw.get(), gradBsRaw.get(), diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CMakeLists.txt b/flashlight/fl/autograd/tensor/backend/cudnn/CMakeLists.txt index 49660c9..2d7f083 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/CMakeLists.txt +++ b/flashlight/fl/autograd/tensor/backend/cudnn/CMakeLists.txt @@ -8,25 +8,27 @@ target_sources( ${CMAKE_CURRENT_LIST_DIR}/Conv2D.cpp ${CMAKE_CURRENT_LIST_DIR}/CudnnUtils.h ${CMAKE_CURRENT_LIST_DIR}/CudnnUtils.cpp + ${CMAKE_CURRENT_LIST_DIR}/CudnnRnnUtils.h + ${CMAKE_CURRENT_LIST_DIR}/CudnnRnnUtils.cpp ${CMAKE_CURRENT_LIST_DIR}/Pool2D.cpp ${CMAKE_CURRENT_LIST_DIR}/RNN.cpp - ) +) target_link_libraries( flashlight PUBLIC ${CUDNN_LIBRARY_PATH} - ) +) target_include_directories( flashlight PUBLIC ${CUDNN_INCLUDE_PATH} - ) +) target_compile_definitions( flashlight PUBLIC "-DNO_CUDNN_DESTROY_HANDLE" "-DNO_CUDNN_DESTROY_STREAM" - ) +) diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp index bb89e61..e6560e5 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp @@ -1,8 +1,8 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. + * SPDX-License-Identifier: MIT * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. + * Original code: Copyright (c) Meta Platforms, Inc. 
(see FLASHLIGHT_LICENSE) + * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE) */ #include "flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.h" @@ -270,7 +270,7 @@ namespace { ) { CUDNN_CHECK_ERR( cudnnSetConvolutionMathType( - cDesc.descriptor, + cDesc.get(), kKernelModesToCudnnMathType.at(kernelOptions->currentOption()) ) ); @@ -280,13 +280,13 @@ namespace { if(input.type() == fl::dtype::f16) CUDNN_CHECK_ERR( cudnnSetConvolutionMathType( - cDesc.descriptor, + cDesc.get(), CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION ) ); else CUDNN_CHECK_ERR( - cudnnSetConvolutionMathType(cDesc.descriptor, CUDNN_DEFAULT_MATH) + cudnnSetConvolutionMathType(cDesc.get(), CUDNN_DEFAULT_MATH) ); } @@ -314,42 +314,42 @@ Tensor CudnnAutogradExtension::conv2d( auto hasBias = bias.elements() > 0; - auto inDesc = TensorDescriptor(input); - auto wtDesc = FilterDescriptor(weights); - auto convDesc = ConvDescriptor(input.type(), px, py, sx, sy, dx, dy, groups); + auto inDesc = TensorDescriptor{input}; + auto wtDesc = FilterDescriptor{weights}; + auto convDesc = ConvDescriptor{input.type(), px, py, sx, sy, dx, dy, groups}; if(input.type() == fl::dtype::f16) CUDNN_CHECK_ERR( cudnnSetConvolutionMathType( - convDesc.descriptor, + convDesc.get(), CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION ) ); else CUDNN_CHECK_ERR( - cudnnSetConvolutionMathType(convDesc.descriptor, CUDNN_DEFAULT_MATH) + cudnnSetConvolutionMathType(convDesc.get(), CUDNN_DEFAULT_MATH) ); std::array odims; CUDNN_CHECK_ERR( cudnnGetConvolutionNdForwardOutputDim( - convDesc.descriptor, - inDesc.descriptor, - wtDesc.descriptor, + convDesc.get(), + inDesc.get(), + wtDesc.get(), 4, odims.data() ) ); - auto output = Tensor({odims[3], odims[2], odims[1], odims[0]}, input.type()); - auto outDesc = TensorDescriptor(output); + auto output = Tensor{{odims[3], odims[2], odims[1], odims[0]}, input.type()}; + auto outDesc = TensorDescriptor{output}; auto handle = getCudnnHandle(); const auto& cudnnStream = getCudnnStream(); auto 
fwdAlgoBestPerf = getFwdAlgo( - inDesc.descriptor, - wtDesc.descriptor, - convDesc.descriptor, - outDesc.descriptor, + inDesc.get(), + wtDesc.get(), + convDesc.get(), + outDesc.get(), input.type() ); @@ -357,22 +357,22 @@ Tensor CudnnAutogradExtension::conv2d( try { wspace = - Tensor({static_cast(fwdAlgoBestPerf.memory)}, fl::dtype::b8); + Tensor{{static_cast(fwdAlgoBestPerf.memory)}, fl::dtype::b8}; } catch(const std::exception&) { fwdAlgoBestPerf.algo = kFwdDefaultAlgo; CUDNN_CHECK_ERR( cudnnGetConvolutionForwardWorkspaceSize( handle, - inDesc.descriptor, - wtDesc.descriptor, - convDesc.descriptor, - outDesc.descriptor, + inDesc.get(), + wtDesc.get(), + convDesc.get(), + outDesc.get(), fwdAlgoBestPerf.algo, &fwdAlgoBestPerf.memory ) ); wspace = - Tensor({static_cast(fwdAlgoBestPerf.memory)}, fl::dtype::b8); + Tensor{{static_cast(fwdAlgoBestPerf.memory)}, fl::dtype::b8}; } { DevicePtr inPtr(input); @@ -390,22 +390,22 @@ Tensor CudnnAutogradExtension::conv2d( cudnnConvolutionForward( handle, one, - inDesc.descriptor, + inDesc.get(), inPtr.get(), - wtDesc.descriptor, + wtDesc.get(), wtPtr.get(), - convDesc.descriptor, + convDesc.get(), fwdAlgoBestPerf.algo, wspacePtr.get(), fwdAlgoBestPerf.memory, zero, - outDesc.descriptor, + outDesc.get(), outPtr.get() ) ); if(hasBias) { - auto bsDesc = TensorDescriptor(bias); + auto bsDesc = TensorDescriptor{bias}; DevicePtr bsPtr(bias); // ensure cudnn compute stream waits on stream of bias tensor relativeSync(cudnnStream, {bias}); @@ -413,10 +413,10 @@ Tensor CudnnAutogradExtension::conv2d( cudnnAddTensor( handle, one, - bsDesc.descriptor, + bsDesc.get(), bsPtr.get(), one, - outDesc.descriptor, + outDesc.get(), outPtr.get() ) ); @@ -453,10 +453,10 @@ Tensor CudnnAutogradExtension::conv2dBackwardData( // benchmarking suggests input or weight casting should occur, these // descriptors may not be used/new ones with the correct types will be // used instead. 
- auto iDesc = TensorDescriptor(input); - auto wDesc = FilterDescriptor(weight); - auto cDesc = ConvDescriptor(input.type(), px, py, sx, sy, dx, dy, groups); - auto oDesc = TensorDescriptor(gradOutput); + auto iDesc = TensorDescriptor{input}; + auto wDesc = FilterDescriptor{weight}; + auto cDesc = ConvDescriptor{input.type(), px, py, sx, sy, dx, dy, groups}; + auto oDesc = TensorDescriptor{gradOutput}; setDefaultMathType(cDesc, input); @@ -481,40 +481,40 @@ Tensor CudnnAutogradExtension::conv2dBackwardData( relativeSync(cudnnStream, {wtTensor}); bool isStrided = (dx * dy) > 1; auto bwdDataAlgoBestPerf = getBwdDataAlgo( - iDesc.descriptor, - wDesc.descriptor, - cDesc.descriptor, - oDesc.descriptor, + iDesc.get(), + wDesc.get(), + cDesc.get(), + oDesc.get(), isStrided, inTensor.type() ); Tensor ws; try { - ws = Tensor( - {static_cast(bwdDataAlgoBestPerf.memory)}, + ws = Tensor{ + {static_cast(bwdDataAlgoBestPerf.memory)}, fl::dtype::b8 - ); + }; } catch(const std::exception&) { bwdDataAlgoBestPerf.algo = kBwdDataDefaultAlgo; CUDNN_CHECK_ERR( cudnnGetConvolutionBackwardDataWorkspaceSize( hndl, - wDesc.descriptor, - oDesc.descriptor, - cDesc.descriptor, - iDesc.descriptor, + wDesc.get(), + oDesc.get(), + cDesc.get(), + iDesc.get(), bwdDataAlgoBestPerf.algo, &bwdDataAlgoBestPerf.memory ) ); - ws = Tensor( - {static_cast(bwdDataAlgoBestPerf.memory)}, + ws = Tensor{ + {static_cast(bwdDataAlgoBestPerf.memory)}, fl::dtype::b8 - ); + }; } - auto gradInput = Tensor(inTensor.shape(), inTensor.type()); + auto gradInput = Tensor{inTensor.shape(), inTensor.type()}; { DevicePtr gradInputPtr(gradInput); DevicePtr gradResultPtr(gradOutputTensor); @@ -525,16 +525,16 @@ Tensor CudnnAutogradExtension::conv2dBackwardData( cudnnConvolutionBackwardData( hndl, oneg, - wDesc.descriptor, + wDesc.get(), wPtr.get(), - oDesc.descriptor, + oDesc.get(), gradResultPtr.get(), - cDesc.descriptor, + cDesc.get(), bwdDataAlgoBestPerf.algo, wsPtr.get(), bwdDataAlgoBestPerf.memory, zerog, - 
iDesc.descriptor, + iDesc.get(), gradInputPtr.get() ) ); @@ -570,18 +570,18 @@ Tensor CudnnAutogradExtension::conv2dBackwardData( &wtTensorF32, &gradOutput, &gradOutputTensorF32]() { - inTensorF32 = input.astype(fl::dtype::f32); - wtTensorF32 = weight.astype(fl::dtype::f32); - gradOutputTensorF32 = gradOutput.astype(fl::dtype::f32); + inTensorF32 = input.asType(fl::dtype::f32); + wtTensorF32 = weight.asType(fl::dtype::f32); + gradOutputTensorF32 = gradOutput.asType(fl::dtype::f32); }, /* incrementCount = */ false ); - auto iDescF32 = TensorDescriptor(inTensorF32); - auto wDescF32 = FilterDescriptor(wtTensorF32); + auto iDescF32 = TensorDescriptor{inTensorF32}; + auto wDescF32 = FilterDescriptor{wtTensorF32}; auto cDescF32 = - ConvDescriptor(fl::dtype::f32, px, py, sx, sy, dx, dy, groups); - auto oDescF32 = TensorDescriptor(gradOutputTensorF32); + ConvDescriptor{fl::dtype::f32, px, py, sx, sy, dx, dy, groups}; + auto oDescF32 = TensorDescriptor{gradOutputTensorF32}; // core bwd data computation dataGradBenchmark->audit( [&dataGradOut, @@ -671,10 +671,10 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( // benchmarking suggests input or weight casting should occur, these // descriptors may not be used/new ones with the correct types will be // used instead. 
- auto iDesc = TensorDescriptor(input); - auto wDesc = FilterDescriptor(weight); - auto cDesc = ConvDescriptor(input.type(), px, py, sx, sy, dx, dy, groups); - auto oDesc = TensorDescriptor(gradOutput); + auto iDesc = TensorDescriptor{input}; + auto wDesc = FilterDescriptor{weight}; + auto cDesc = ConvDescriptor{input.type(), px, py, sx, sy, dx, dy, groups}; + auto oDesc = TensorDescriptor{gradOutput}; setDefaultMathType(cDesc, input); @@ -699,39 +699,39 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( // ensure cudnn compute stream waits on stream of input tensor relativeSync(cudnnStream, {inTensor}); auto bwdFilterAlgoBestPerf = getBwdFilterAlgo( - iDesc.descriptor, - wDesc.descriptor, - cDesc.descriptor, - oDesc.descriptor, + iDesc.get(), + wDesc.get(), + cDesc.get(), + oDesc.get(), inTensor.type() ); Tensor ws; try { - ws = Tensor( - {static_cast(bwdFilterAlgoBestPerf.memory)}, + ws = Tensor{ + {static_cast(bwdFilterAlgoBestPerf.memory)}, fl::dtype::b8 - ); + }; } catch(const std::exception&) { bwdFilterAlgoBestPerf.algo = kBwdFilterDefaultAlgo; CUDNN_CHECK_ERR( cudnnGetConvolutionBackwardFilterWorkspaceSize( hndl, - iDesc.descriptor, - oDesc.descriptor, - cDesc.descriptor, - wDesc.descriptor, + iDesc.get(), + oDesc.get(), + cDesc.get(), + wDesc.get(), bwdFilterAlgoBestPerf.algo, &bwdFilterAlgoBestPerf.memory ) ); - ws = Tensor( - {static_cast(bwdFilterAlgoBestPerf.memory)}, + ws = Tensor{ + {static_cast(bwdFilterAlgoBestPerf.memory)}, fl::dtype::b8 - ); + }; } - auto gradWeight = Tensor(wtTensor.shape(), wtTensor.type()); + auto gradWeight = Tensor{wtTensor.shape(), wtTensor.type()}; { DevicePtr gradWeightPtr(gradWeight); DevicePtr gradResultPtr(gradOutputTensor); @@ -742,16 +742,16 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( cudnnConvolutionBackwardFilter( hndl, oneg, - iDesc.descriptor, + iDesc.get(), iPtr.get(), - oDesc.descriptor, + oDesc.get(), gradResultPtr.get(), - cDesc.descriptor, + cDesc.get(), 
bwdFilterAlgoBestPerf.algo, wsPtr.get(), bwdFilterAlgoBestPerf.memory, zerog, - wDesc.descriptor, + wDesc.get(), gradWeightPtr.get() ) ); @@ -787,18 +787,18 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( &wtTensorF32, &gradOutput, &gradOutputTensorF32]() { - inTensorF32 = input.astype(fl::dtype::f32); - wtTensorF32 = weight.astype(fl::dtype::f32); - gradOutputTensorF32 = gradOutput.astype(fl::dtype::f32); + inTensorF32 = input.asType(fl::dtype::f32); + wtTensorF32 = weight.asType(fl::dtype::f32); + gradOutputTensorF32 = gradOutput.asType(fl::dtype::f32); }, /* incrementCount = */ false ); - auto iDescF32 = TensorDescriptor(inTensorF32); - auto wDescF32 = FilterDescriptor(wtTensorF32); + auto iDescF32 = TensorDescriptor{inTensorF32}; + auto wDescF32 = FilterDescriptor{wtTensorF32}; auto cDescF32 = - ConvDescriptor(fl::dtype::f32, px, py, sx, sy, dx, dy, groups); - auto oDescF32 = TensorDescriptor(gradOutputTensorF32); + ConvDescriptor{fl::dtype::f32, px, py, sx, sy, dx, dy, groups}; + auto oDescF32 = TensorDescriptor{gradOutputTensorF32}; // core bwd data computation filterGradBenchmark->audit( [&filterGradOut, @@ -860,21 +860,21 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( const Tensor& bsTensor, const Tensor& gradOutput, const TensorDescriptor& oDesc) -> Tensor { - auto gradBias = Tensor(bsTensor.shape(), bsTensor.type()); + auto gradBias = Tensor{bsTensor.shape(), bsTensor.type()}; { DevicePtr gradBiasPtr(gradBias); DevicePtr gradResultPtr(gradOutput); // ensure cudnn compute stream waits on gradient tensor streams relativeSync(cudnnStream, {gradOutput, gradBias}); - auto bDesc = TensorDescriptor(bsTensor); + auto bDesc = TensorDescriptor{bsTensor}; CUDNN_CHECK_ERR( cudnnConvolutionBackwardBias( hndl, oneg, - oDesc.descriptor, + oDesc.get(), gradResultPtr.get(), zerog, - bDesc.descriptor, + bDesc.get(), gradBiasPtr.get() ) ); @@ -906,12 +906,12 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( // Time cast bias and 
grad output if benchmarking biasGradBenchmark->audit( [&bias, &gradOutput, &biasF32, &gradOutputF32]() { - biasF32 = bias.astype(fl::dtype::f32); - gradOutputF32 = gradOutput.astype(fl::dtype::f32); + biasF32 = bias.asType(fl::dtype::f32); + gradOutputF32 = gradOutput.asType(fl::dtype::f32); }, /* incrementCount = */ false ); - auto oDescF32 = TensorDescriptor(gradOutputF32); + auto oDescF32 = TensorDescriptor{gradOutputF32}; // Perform bias gradient computation biasGradBenchmark->audit( [&biasGradOut, diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp index 305a6cc..eaac140 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp @@ -16,13 +16,11 @@ namespace fl { std::shared_ptr CudnnAutogradExtension::createBenchmarkOptions() { return std::make_shared( std::make_shared>( - std::vector( - { - KernelMode::F32, - KernelMode::F32_ALLOW_CONVERSION, - KernelMode::F16 - } - ), + std::vector{ + KernelMode::F32, + KernelMode::F32_ALLOW_CONVERSION, + KernelMode::F16 + }, fl::kDynamicBenchmarkDefaultCount ) ); diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.h b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.h index b960c30..2edfd89 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.h +++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.h @@ -1,8 +1,8 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. + * SPDX-License-Identifier: MIT * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. + * Original code: Copyright (c) Meta Platforms, Inc. 
(see FLASHLIGHT_LICENSE) + * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE) */ #pragma once @@ -19,96 +19,96 @@ class CudnnAutogradExtension : public AutogradExtension { // TODO(jacobkahn): implement getCudnnHandle public: - bool isDataTypeSupported(const fl::dtype& dtype) const override; + bool isDataTypeSupported(fl::dtype const& dtype) const override; - enum class KernelMode {F32 = 0, F32_ALLOW_CONVERSION = 1, F16 = 2}; + enum class KernelMode { F32 = 0, F32_ALLOW_CONVERSION = 1, F16 = 2 }; std::shared_ptr createBenchmarkOptions() override; /**************************** Forward ****************************/ Tensor conv2d( - const Tensor& input, - const Tensor& weights, - const Tensor& bias, - const int sx, - const int sy, - const int px, - const int py, - const int dx, - const int dy, - const int groups, + Tensor const& input, + Tensor const& weights, + Tensor const& bias, + int sx, + int sy, + int px, + int py, + int dx, + int dy, + int groups, std::shared_ptr payload ) override; Tensor pool2d( - const Tensor& input, - const int wx, - const int wy, - const int sx, - const int sy, - const int px, - const int py, - const PoolingMode mode, + Tensor const& input, + int wx, + int wy, + int sx, + int sy, + int px, + int py, + PoolingMode mode, std::shared_ptr payload ) override; Tensor batchnorm( Tensor& saveMean, Tensor& saveVar, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, + Tensor const& input, + Tensor const& weight, + Tensor const& bias, Tensor& runningMean, Tensor& runningVar, - const std::vector& axes, - const bool train, - const double momentum, - const double epsilon, + std::vector const& axes, + bool train, + double momentum, + double epsilon, std::shared_ptr payload ) override; std::tuple rnn( - const Tensor& input, - const Tensor& hiddenState, - const Tensor& cellState, - const Tensor& weights, - const int hiddenSize, - const int numLayers, - const RnnMode mode, - const bool bidirectional, - const float dropout, - 
std::shared_ptr payload + Tensor const& input, + Tensor const& hiddenState, + Tensor const& cellState, + Tensor const& weights, + int hiddenSize, + int numLayers, + RnnMode mode, + bool bidirectional, + float dropProb, + std::shared_ptr autogradPayload ) override; /**************************** Backward ****************************/ // ]----- Convolution Tensor conv2dBackwardData( - const Tensor& gradOutput, - const Tensor& input, - const Tensor& weight, - const int sx, - const int sy, - const int px, - const int py, - const int dx, - const int dy, - const int groups, + Tensor const& gradOutput, + Tensor const& input, + Tensor const& weight, + int sx, + int sy, + int px, + int py, + int dx, + int dy, + int groups, std::shared_ptr dataGradBenchmark, std::shared_ptr payload ) override; std::pair conv2dBackwardFilterBias( - const Tensor& gradOutput, - const Tensor& input, - const Tensor& weights, - const Tensor& bias, - const int sx, - const int sy, - const int px, - const int py, - const int dx, - const int dy, - const int groups, + Tensor const& gradOutput, + Tensor const& input, + Tensor const& weights, + Tensor const& bias, + int sx, + int sy, + int px, + int py, + int dx, + int dy, + int groups, std::shared_ptr filterBench, std::shared_ptr biasBench, std::shared_ptr autogradPayload @@ -116,47 +116,59 @@ class CudnnAutogradExtension : public AutogradExtension { // ]----- pool2D Tensor pool2dBackward( - const Tensor& gradOutput, - const Tensor& input, - const Tensor& poolOutput, - const int wx, - const int wy, - const int sx, - const int sy, - const int px, - const int py, - const PoolingMode mode, + Tensor const& gradOutput, + Tensor const& input, + Tensor const& poolOutput, + int wx, + int wy, + int sx, + int sy, + int px, + int py, + PoolingMode mode, std::shared_ptr payload ) override; // ]----- batchnorm std::tuple batchnormBackward( - const Tensor& gradOutput, - const Tensor& saveMean, - const Tensor& saveVar, - const Tensor& input, - const Tensor& weight, - 
const std::vector& axes, - const bool train, - const float epsilon, + Tensor const& gradOutput, + Tensor const& saveMean, + Tensor const& saveVar, + Tensor const& input, + Tensor const& weight, + std::vector const& axes, + bool train, + float epsilon, std::shared_ptr payload ) override; // ]----- rnn std::tuple rnnBackward( - const Tensor& input, - const Tensor& hiddenState, - const Tensor& cellState, - const Tensor& weights, - const std::shared_ptr gradData, - const Tensor& output, - const int numLayers, - const int hiddenSize, - const RnnMode mode, - const bool bidirectional, - const float dropProb, - std::shared_ptr payload + Tensor const& input, + Tensor const& hiddenState, + Tensor const& cellState, + Tensor const& weights, + std::shared_ptr gradData, + Tensor const& output, + int numLayers, + int hiddenSize, + RnnMode mode, + bool bidirectional, + float dropProb, + std::shared_ptr autogradPayload ) override; + +private: + + static void checkHiddenStateDims(int hiddenSize, Tensor const& hiddenState, int batchSize, int totalLayers); + static void checkCellStateDims( + int hiddenSize, + RnnMode mode, + Tensor const& cellState, + int batchSize, + int totalLayers + ); + }; } // namespace fl diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnRnnUtils.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnRnnUtils.cpp new file mode 100644 index 0000000..c64c234 --- /dev/null +++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnRnnUtils.cpp @@ -0,0 +1,305 @@ +#include "CudnnRnnUtils.h" + +#include "flashlight/fl/common/DevicePtr.h" +#include "flashlight/fl/tensor/Compute.h" + + +namespace fl { +namespace { + struct temp_space_sizes { + size_t size; + size_t reserveSize; + }; + + temp_space_sizes rnn_temp_space_sizes( + cudnnHandle_t handle, + RNNDescriptor const& rnnDescriptor, + RNNDataDescriptor const& xDescriptor, + cudnnForwardMode_t mode + ) { + temp_space_sizes sizes{}; + + CUDNN_CHECK_ERR( + cudnnGetRNNTempSpaceSizes( + handle, + 
rnnDescriptor.get(), + mode, + xDescriptor.get(), + &sizes.size, + &sizes.reserveSize + ) + ); + + return sizes; + } + + size_t rnn_weight_space_size( + cudnnHandle_t handle, + RNNDescriptor const& rnnDescriptor + ) { + size_t size = 0; + + CUDNN_CHECK_ERR( + cudnnGetRNNWeightSpaceSize(handle,rnnDescriptor.get(),&size) + ); + return size; + } + + std::optional create_dev_seq_lengths(int batchSize, int seqLength) { + //see cudnn docs for cudnnRNNForward as explanation +#if CUDNN_VERSION >= 8901 + return std::nullopt; +#else + return fl::full({batchSize}, seqLength, fl::dtype::s32); +#endif + } + +} +} + +namespace fl { +void cudnn_rnn_forward( + int batchSize, + int seqLength, + bool train, + RNNDescriptor const& rnnDesc, + Tensor const& x, + Tensor const& y, + Tensor const& weights, + TensorDescriptor const& cxDesc, + TensorDescriptor const& hxDesc, + Tensor const& hy, + Tensor const& cy, + Tensor const& hiddenState, + Tensor const& cellState, + Tensor& reserveSpace +) { + RNNDataDescriptor xDesc{x.type(), x.shape()}; + RNNDataDescriptor yDesc{y.type(), y.shape()}; + + auto handle = getCudnnHandle(); + + size_t weightSpaceSize = rnn_weight_space_size(handle, rnnDesc); + + if(weightSpaceSize != weights.bytes()) + throw std::invalid_argument("invalid # of parameters or wrong input shape for RNN"); + + auto const forwardMode = train ? 
CUDNN_FWD_MODE_TRAINING : CUDNN_FWD_MODE_INFERENCE; + + auto [workspaceSize, reserveSize] = rnn_temp_space_sizes(handle, rnnDesc, xDesc, forwardMode); + + Tensor workspace({static_cast(workspaceSize)}, fl::dtype::b8); + // Space must be reused between forward and backward for cuDNN + + reserveSpace = Tensor{{static_cast(reserveSize)}, fl::dtype::b8}; + + auto devSeqLengths = create_dev_seq_lengths(batchSize, seqLength); + + auto const& cudnnStream = getCudnnStream(); + + { + auto contiguousX = x.asContiguousTensor(); + auto contiguousWeights = weights.asContiguousTensor(); + DevicePtr xRaw(contiguousX); + DevicePtr hxRaw(hiddenState); + DevicePtr cxRaw(cellState); + DevicePtr weightSpaceRaw(contiguousWeights); + DevicePtr yRaw(y); + DevicePtr hyRaw(hy); + DevicePtr cyRaw(cy); + DevicePtr workspaceRaw(workspace); + DevicePtr reserveSpaceRaw(reserveSpace); + + std::optional devSeqLengthsRaw{}; + + if(devSeqLengths) + devSeqLengthsRaw.emplace(*devSeqLengths); + + // ensure cudnn compute stream waits greaterThanEqual(&on input/output tensor streams + + std::vector waits{ + contiguousX, + hiddenState, + cellState, + contiguousWeights, + y, + hy, + cy, + workspace, + reserveSpace, + }; + if(devSeqLengths) + waits.push_back(*devSeqLengths); + + relativeSync(cudnnStream, waits); + + + CUDNN_CHECK_ERR( + cudnnRNNForward( + handle, + rnnDesc.get(), + forwardMode, + devSeqLengthsRaw ? 
devSeqLengthsRaw->getAs() : nullptr, + + xDesc.get(), + xRaw.get(), + yDesc.get(), + yRaw.get(), + + hxDesc.get(), + hxRaw.get(), + hyRaw.get(), + cxDesc.get(), + cxRaw.get(), + cyRaw.get(), + + weightSpaceSize, + weightSpaceRaw.get(), + + workspaceSize, + workspaceRaw.get(), + + reserveSize, + reserveSpaceRaw.get() + ) + ); + } + + // ensure output tensor streams wait on cudnn compute stream + relativeSync({y, hy, cy}, cudnnStream); +} + +void cudnn_rnn_backward( + int batchSize, + int seqLength, + RNNDescriptor const& rnnDesc, + + Tensor const& x, + Tensor const& y, + Tensor const& dy, + Tensor const& weights, + TensorDescriptor const& cxDesc, + TensorDescriptor const& hxDesc, + Tensor const& dhy, + Tensor const& dcy, + Tensor const& hiddenState, + Tensor const& cellState, + Tensor const& dx, + Tensor const& dhx, + Tensor const& dcx, + Tensor const& dw, + + Tensor const& reserveSpace +) { + auto handle = getCudnnHandle(); + auto const& cudnnStream = getCudnnStream(); + + RNNDataDescriptor xDesc{x.type(), x.shape()}; + RNNDataDescriptor yDesc{y.type(), y.shape()}; + + size_t weightSpaceSize = rnn_weight_space_size(handle, rnnDesc); + auto [workspaceSize, reserveSize] = rnn_temp_space_sizes(handle, rnnDesc, xDesc, CUDNN_FWD_MODE_TRAINING); + + Tensor workspace({static_cast(workspaceSize)}, fl::dtype::b8); + + auto devSeqLengths = create_dev_seq_lengths(batchSize, seqLength); + + std::vector waits = {y, workspace, reserveSpace}; + if(devSeqLengths) + waits.push_back(*devSeqLengths); + + // ensure cudnn compute stream waits on input/output tensor streams + relativeSync(cudnnStream, waits); + + DevicePtr yRaw(y); + DevicePtr workspaceRaw(workspace); + DevicePtr reserveSpaceRaw(reserveSpace); + + std::optional devSeqLengthsRaw{}; + if(devSeqLengths) + devSeqLengthsRaw.emplace(*devSeqLengths); + + { + DevicePtr dyRaw(dy); // Has to be set to 0 if empty + DevicePtr dhyRaw(dhy); + DevicePtr dcyRaw(dcy); + + DevicePtr wRaw(weights); + + DevicePtr hxRaw(hiddenState); + 
DevicePtr cxRaw(cellState); + + DevicePtr dxRaw(dx); + DevicePtr dhxRaw(dhx); + DevicePtr dcxRaw(dcx); + + // ensure cudnn compute stream waits on input/output tensor streams + relativeSync( + cudnnStream, + {dy, dhy, dcy, weights, hiddenState, cellState, dx, dhx, dcx} + ); + + /* We need to update reserveSpace even if we just want the + * weight gradients. */ + CUDNN_CHECK_ERR( + cudnnRNNBackwardData_v8( + handle, + rnnDesc.get(), + devSeqLengthsRaw ? devSeqLengthsRaw->getAs() : nullptr, + yDesc.get(), + yRaw.get(), + dyRaw.get(), + xDesc.get(), + dxRaw.get(), + hxDesc.get(), + hxRaw.get(), + dhyRaw.get(), + dhxRaw.get(), + cxDesc.get(), + cxRaw.get(), + dcyRaw.get(), + dcxRaw.get(), + weightSpaceSize, + wRaw.get(), + workspaceSize, + workspaceRaw.get(), + reserveSpace.bytes(), + reserveSpaceRaw.get() + ) + ); + } + + { + DevicePtr xRaw(x); + DevicePtr dwRaw(dw); + DevicePtr hxRaw(hiddenState); + + // ensure cudnn compute stream waits on input/output tensor streams + relativeSync(cudnnStream, {x, dw, hiddenState}); + + CUDNN_CHECK_ERR( + cudnnRNNBackwardWeights_v8( + handle, + rnnDesc.get(), + CUDNN_WGRAD_MODE_ADD, + devSeqLengthsRaw ? 
devSeqLengthsRaw->getAs() : nullptr, + xDesc.get(), + xRaw.get(), + hxDesc.get(), + hxRaw.get(), + yDesc.get(), + yRaw.get(), + weightSpaceSize, + dwRaw.get(), + workspaceSize, + workspaceRaw.get(), + reserveSpace.bytes(), + reserveSpaceRaw.get() + ) + ); + } + + // ensure output tensor streams wait on cudnn compute stream + relativeSync({dx, dhx, dcx, dw}, cudnnStream); +} +} diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnRnnUtils.h b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnRnnUtils.h new file mode 100644 index 0000000..3ed5b07 --- /dev/null +++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnRnnUtils.h @@ -0,0 +1,45 @@ +#pragma once +#include "CudnnUtils.h" + +namespace fl { +void cudnn_rnn_forward( + int batchSize, + int seqLength, + bool train, + RNNDescriptor const& rnnDesc, + + Tensor const& x, + Tensor const& y, + Tensor const& weights, + TensorDescriptor const& cxDesc, + TensorDescriptor const& hxDesc, + Tensor const& hy, + Tensor const& cy, + Tensor const& hiddenState, + Tensor const& cellState, + + Tensor& reserveSpace // out +); +void cudnn_rnn_backward( + int batchSize, + int seqLength, + RNNDescriptor const& rnnDesc, + + Tensor const& x, + Tensor const& y, + Tensor const& dy, + Tensor const& weights, + TensorDescriptor const& cxDesc, + TensorDescriptor const& hxDesc, + Tensor const& dhy, + Tensor const& dcy, + Tensor const& hiddenState, + Tensor const& cellState, + Tensor const& dx, + Tensor const& dhx, + Tensor const& dcx, + Tensor const& dw, + + Tensor const& reserveSpace +); +} diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp index 82cadcb..4a20900 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp @@ -1,8 +1,8 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * SPDX-License-Identifier: MIT * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. + * Original code: Copyright (c) Meta Platforms, Inc. (see FLASHLIGHT_LICENSE) + * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE) */ #include "flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.h" @@ -25,16 +25,16 @@ struct DeviceHandle { std::shared_ptr stream; explicit DeviceHandle(std::shared_ptr _stream) : cudnnHandle(nullptr), - stream(_stream) { + stream(_stream) { CUDNN_CHECK_ERR(cudnnCreate(&cudnnHandle)); CUDNN_CHECK_ERR(cudnnSetStream(cudnnHandle, stream->handle())); } ~DeviceHandle() { if(cudnnHandle) { -// See https://git.io/fNQnM - sometimes, at exit, the CUDA context -// (or something) is already destroyed by the time a handle gets destroyed -// because of an issue with the destruction order. + // See https://git.io/fNQnM - sometimes, at exit, the CUDA context + // (or something) is already destroyed by the time a handle gets destroyed + // because of an issue with the destruction order. 
#ifdef NO_CUDNN_DESTROY_HANDLE #else CUDNN_CHECK_ERR(cudnnDestroy(cudnnHandle)); @@ -43,16 +43,16 @@ struct DeviceHandle { } }; -const float kFloatZero = 0.0; -const float kFloatOne = 1.0; +constexpr float kFloatZero = 0.0; +constexpr float kFloatOne = 1.0; -const double kDoubleZero = 0.0; -const double kDoubleOne = 1.0; +constexpr double kDoubleZero = 0.0; +constexpr double kDoubleOne = 1.0; // TODO: move this to CudnnAutogradExtension if we make it a singleton std::unordered_map handles; -const DeviceHandle& getActiveDeviceHandle() { +DeviceHandle const& getActiveDeviceHandle() { auto& manager = fl::DeviceManager::getInstance(); auto& cudaDevice = manager.getActiveDevice(fl::DeviceType::CUDA).impl(); @@ -88,58 +88,43 @@ namespace fl { void cudnnCheckErr(cudnnStatus_t status) { if(status == CUDNN_STATUS_SUCCESS) return; - const char* err = cudnnGetErrorString(status); + char const* err = cudnnGetErrorString(status); switch(status) { - case CUDNN_STATUS_BAD_PARAM: - throw std::invalid_argument(err); - default: - throw std::runtime_error(err); + case CUDNN_STATUS_BAD_PARAM: throw std::invalid_argument(err); + default: throw std::runtime_error(err); } } -cudnnDataType_t cudnnMapToType(const fl::dtype& t) { +cudnnDataType_t cudnnMapToType(fl::dtype const& t) { switch(t) { - case fl::dtype::f16: - return CUDNN_DATA_HALF; - case fl::dtype::f32: - return CUDNN_DATA_FLOAT; - case fl::dtype::f64: - return CUDNN_DATA_DOUBLE; - default: - throw std::invalid_argument("unsupported data type for cuDNN"); + case fl::dtype::f16: return CUDNN_DATA_HALF; + case fl::dtype::f32: return CUDNN_DATA_FLOAT; + case fl::dtype::f64: return CUDNN_DATA_DOUBLE; + default: throw std::invalid_argument("unsupported data type for cuDNN"); } } -cudnnPoolingMode_t cudnnMapToPoolingMode(const PoolingMode mode) { +cudnnPoolingMode_t cudnnMapToPoolingMode(PoolingMode const mode) { switch(mode) { - case PoolingMode::MAX: - return CUDNN_POOLING_MAX; - case PoolingMode::AVG_INCLUDE_PADDING: - return 
CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; - case PoolingMode::AVG_EXCLUDE_PADDING: - return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; - default: - throw std::invalid_argument("unsupported pooling mode for cuDNN"); + case PoolingMode::MAX: return CUDNN_POOLING_MAX; + case PoolingMode::AVG_INCLUDE_PADDING: return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + case PoolingMode::AVG_EXCLUDE_PADDING: return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; + default: throw std::invalid_argument("unsupported pooling mode for cuDNN"); } } -cudnnRNNMode_t cudnnMapToRNNMode(const RnnMode mode) { +cudnnRNNMode_t cudnnMapToRNNMode(RnnMode const mode) { switch(mode) { - case RnnMode::RELU: - return CUDNN_RNN_RELU; - case RnnMode::TANH: - return CUDNN_RNN_TANH; - case RnnMode::LSTM: - return CUDNN_LSTM; - case RnnMode::GRU: - return CUDNN_GRU; - default: - throw std::invalid_argument("unsupported RNN mode for cuDNN"); + case RnnMode::RELU: return CUDNN_RNN_RELU; + case RnnMode::TANH: return CUDNN_RNN_TANH; + case RnnMode::LSTM: return CUDNN_LSTM; + case RnnMode::GRU: return CUDNN_GRU; + default: throw std::invalid_argument("unsupported RNN mode for cuDNN"); } } -TensorDescriptor::TensorDescriptor(const fl::dtype type, const Shape& flDims) { - CUDNN_CHECK_ERR(cudnnCreateTensorDescriptor(&descriptor)); +TensorDescriptor::TensorDescriptor(fl::dtype const type, Shape const& flDims) { + CUDNN_CHECK_ERR(cudnnCreateTensorDescriptor(&_handle)); cudnnDataType_t cudnntype = cudnnMapToType(type); std::array dims = {1, 1, 1, 1}; @@ -156,7 +141,7 @@ TensorDescriptor::TensorDescriptor(const fl::dtype type, const Shape& flDims) { CUDNN_CHECK_ERR( cudnnSetTensorNdDescriptor( - descriptor, + _handle, cudnntype, dims.size(), dims.data(), @@ -165,8 +150,8 @@ TensorDescriptor::TensorDescriptor(const fl::dtype type, const Shape& flDims) { ); } -TensorDescriptor::TensorDescriptor(const Tensor& input) { - CUDNN_CHECK_ERR(cudnnCreateTensorDescriptor(&descriptor)); 
+TensorDescriptor::TensorDescriptor(Tensor const& input) { + CUDNN_CHECK_ERR(cudnnCreateTensorDescriptor(&_handle)); cudnnDataType_t cudnntype = cudnnMapToType(input.type()); auto flStrides = input.strides(); @@ -185,7 +170,7 @@ TensorDescriptor::TensorDescriptor(const Tensor& input) { CUDNN_CHECK_ERR( cudnnSetTensorNdDescriptor( - descriptor /* descriptor handle */, + _handle /* descriptor handle */, cudnntype /* = dataType */, 4, dims.data(), @@ -194,21 +179,19 @@ TensorDescriptor::TensorDescriptor(const Tensor& input) { ); } -TensorDescriptor::~TensorDescriptor() { - CUDNN_CHECK_ERR(cudnnDestroyTensorDescriptor(descriptor)); -} +TensorDescriptor::~TensorDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyTensorDescriptor(_handle)); } TensorDescriptorArray::TensorDescriptorArray( int size, - const fl::dtype type, - const Shape& dims + fl::dtype const type, + Shape const& dims ) { - desc_vec.reserve(size); + _descVec.reserve(size); for(int i = 0; i < size; i++) { - desc_vec.emplace_back(type, dims); - desc_raw_vec.push_back(desc_vec.back().descriptor); + _descVec.emplace_back(type, dims); + _descRawVec.push_back(_descVec.back().get()); } - descriptors = desc_raw_vec.data(); + descriptors = _descRawVec.data(); } TensorDescriptorArray::~TensorDescriptorArray() = default; @@ -222,7 +205,7 @@ PoolingDescriptor::PoolingDescriptor( int py, PoolingMode mode ) { - CUDNN_CHECK_ERR(cudnnCreatePoolingDescriptor(&descriptor)); + CUDNN_CHECK_ERR(cudnnCreatePoolingDescriptor(&_handle)); std::array window = {static_cast(wy), static_cast(wx)}; std::array padding = {static_cast(py), static_cast(px)}; std::array stride = {static_cast(sy), static_cast(sx)}; @@ -230,7 +213,7 @@ PoolingDescriptor::PoolingDescriptor( auto cudnnpoolingmode = cudnnMapToPoolingMode(mode); CUDNN_CHECK_ERR( cudnnSetPoolingNdDescriptor( - descriptor, + _handle, cudnnpoolingmode, CUDNN_PROPAGATE_NAN, 2, @@ -241,12 +224,10 @@ PoolingDescriptor::PoolingDescriptor( ); } -PoolingDescriptor::~PoolingDescriptor() { - 
CUDNN_CHECK_ERR(cudnnDestroyPoolingDescriptor(descriptor)); -} +PoolingDescriptor::~PoolingDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyPoolingDescriptor(_handle)); } -FilterDescriptor::FilterDescriptor(const Tensor& input) { - CUDNN_CHECK_ERR(cudnnCreateFilterDescriptor(&descriptor)); +FilterDescriptor::FilterDescriptor(Tensor const& input) { + CUDNN_CHECK_ERR(cudnnCreateFilterDescriptor(&_handle)); cudnnDataType_t cudnntype = cudnnMapToType(input.type()); auto flDims = input.shape(); @@ -258,7 +239,7 @@ FilterDescriptor::FilterDescriptor(const Tensor& input) { CUDNN_CHECK_ERR( cudnnSetFilterNdDescriptor( - descriptor, + _handle, cudnntype, CUDNN_TENSOR_NCHW, 4, @@ -267,121 +248,144 @@ FilterDescriptor::FilterDescriptor(const Tensor& input) { ); } -FilterDescriptor::~FilterDescriptor() { - CUDNN_CHECK_ERR(cudnnDestroyFilterDescriptor(descriptor)); -} +FilterDescriptor::~FilterDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyFilterDescriptor(_handle)); } + +DropoutDescriptor::DropoutDescriptor(float dropProb) { + CUDNN_CHECK_ERR(cudnnCreateDropoutDescriptor(&_handle)); + + auto const cudnnHandle = getCudnnHandle(); + constexpr int64_t seed = 0; + size_t stateSize; -DropoutDescriptor::DropoutDescriptor(float drop_prob) { - CUDNN_CHECK_ERR(cudnnCreateDropoutDescriptor(&descriptor)); - auto cudnnHandle = getCudnnHandle(); - unsigned long long seed = 0; - size_t state_size; - CUDNN_CHECK_ERR(cudnnDropoutGetStatesSize(cudnnHandle, &state_size)); - auto& dropout_states = getDropoutStates(); - if(dropout_states.isEmpty()) { - dropout_states = - Tensor({static_cast(state_size)}, fl::dtype::b8); - DevicePtr statesraw(dropout_states); + CUDNN_CHECK_ERR(cudnnDropoutGetStatesSize(cudnnHandle, &stateSize)); + + auto& dropoutStates = getDropoutStates(); + + if(dropoutStates.isEmpty()) { + dropoutStates = + Tensor{{static_cast(stateSize)}, fl::dtype::b8}; + DevicePtr statesraw(dropoutStates); CUDNN_CHECK_ERR( cudnnSetDropoutDescriptor( - descriptor, + _handle, cudnnHandle, - drop_prob, + 
dropProb, statesraw.get(), - state_size, + stateSize, seed ) ); - } else { - DevicePtr statesraw(dropout_states); -// See https://git.io/fp9oo for an explanation. -#if CUDNN_VERSION >= 7000 + } + else { + DevicePtr statesraw(dropoutStates); CUDNN_CHECK_ERR( cudnnRestoreDropoutDescriptor( - descriptor, + _handle, cudnnHandle, - drop_prob, + dropProb, statesraw.get(), - state_size, + stateSize, seed ) ); -#else - auto dropout_struct = reinterpret_cast(descriptor); - dropout_struct->dropout = drop_prob; - dropout_struct->nstates = state_size; - dropout_struct->states = statesraw.get(); -#endif } } -DropoutDescriptor::~DropoutDescriptor() { - CUDNN_CHECK_ERR(cudnnDestroyDropoutDescriptor(descriptor)); -} +DropoutDescriptor::~DropoutDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyDropoutDescriptor(_handle)); } Tensor& DropoutDescriptor::getDropoutStates() { - thread_local Tensor dropout_states; - return dropout_states; + thread_local Tensor dropoutStates; + return dropoutStates; } RNNDescriptor::RNNDescriptor( fl::dtype type, - int hidden_size, - int num_layers, + int inputSize, + int hiddenSize, + int numLayers, RnnMode mode, bool bidirectional, DropoutDescriptor& dropout ) { - CUDNN_CHECK_ERR(cudnnCreateRNNDescriptor(&descriptor)); - - auto cudnnHandle = getCudnnHandle(); + CUDNN_CHECK_ERR(cudnnCreateRNNDescriptor(&_handle)); - cudnnRNNInputMode_t in_mode = CUDNN_LINEAR_INPUT; + constexpr auto inMode = CUDNN_LINEAR_INPUT; + constexpr auto algo = CUDNN_RNN_ALGO_STANDARD; - cudnnDirectionMode_t dir = - bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; + auto const dir = bidirectional ? 
CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; - cudnnRNNMode_t cell = cudnnMapToRNNMode(mode); - cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD; - cudnnDataType_t cudnntype = cudnnMapToType(type); + auto const cell = cudnnMapToRNNMode(mode); + auto const dataType = cudnnMapToType(type); -#if CUDNN_VERSION >= 7000 && CUDNN_VERSION < 8000 CUDNN_CHECK_ERR( - cudnnSetRNNDescriptor( - cudnnHandle, - descriptor, - hidden_size, - num_layers, - dropout.descriptor, - in_mode, - dir, - cell, + //https://docs.nvidia.com/deeplearning/cudnn/archives/cudnn-892/api/index.html#cudnnSetRNNDescriptor_v8 + cudnnSetRNNDescriptor_v8( + _handle, algo, - cudnntype + cell, + CUDNN_RNN_DOUBLE_BIAS, //TODO review; double is default for old cudnn + dir, + inMode, + dataType, + dataType, // math precision + mathType(type), + inputSize, + hiddenSize, + hiddenSize, //projection size (unused) + numLayers, + dropout.get(), + 0 ) ); -#else +} + +RNNDescriptor::~RNNDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyRNNDescriptor(_handle)); } + +} + +namespace fl { + + +RNNDataDescriptor::RNNDataDescriptor(fl::dtype type, Shape const& dims) { + create(); + + auto sizes = max(dims, {1, 1, 1}); + + auto const inputSize = static_cast(sizes[0]); + auto const batchSize = static_cast(sizes[1]); + auto const maxSeqSize = static_cast(sizes[2]); + + std::vector seqSizes(batchSize, maxSeqSize); + + set(type, inputSize, maxSeqSize, seqSizes); +} + +RNNDataDescriptor::~RNNDataDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyRNNDataDescriptor(_handle)); } +void RNNDataDescriptor::create() { CUDNN_CHECK_ERR(cudnnCreateRNNDataDescriptor(&_handle)); } +void RNNDataDescriptor::set( + fl::dtype type, + int inputSize, + int maxSeqSize, + std::span sequenceSizes +) const { CUDNN_CHECK_ERR( - cudnnSetRNNDescriptor_v6( - cudnnHandle, - descriptor, - hidden_size, - num_layers, - dropout.descriptor, - in_mode, - dir, - cell, - algo, - cudnntype + cudnnSetRNNDataDescriptor( + _handle, + cudnnMapToType(type), + 
CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED, + maxSeqSize, + sequenceSizes.size(), //batch size + inputSize, + sequenceSizes.data(), + nullptr //no padding ) ); -#endif } -RNNDescriptor::~RNNDescriptor() { - CUDNN_CHECK_ERR(cudnnDestroyRNNDescriptor(descriptor)); } +namespace fl { + ConvDescriptor::ConvDescriptor( fl::dtype type, int px, @@ -392,7 +396,7 @@ ConvDescriptor::ConvDescriptor( int dy, int groups ) { - CUDNN_CHECK_ERR(cudnnCreateConvolutionDescriptor(&descriptor)); + CUDNN_CHECK_ERR(cudnnCreateConvolutionDescriptor(&_handle)); cudnnDataType_t cudnntype = cudnnMapToType(type); std::array padding = {static_cast(py), static_cast(px)}; std::array stride = {static_cast(sy), static_cast(sx)}; @@ -400,7 +404,7 @@ ConvDescriptor::ConvDescriptor( CUDNN_CHECK_ERR( cudnnSetConvolutionNdDescriptor( - descriptor, + _handle, 2, padding.data(), stride.data(), @@ -410,39 +414,34 @@ ConvDescriptor::ConvDescriptor( ) ); - CUDNN_CHECK_ERR(cudnnSetConvolutionGroupCount(descriptor, groups)); + CUDNN_CHECK_ERR(cudnnSetConvolutionGroupCount(_handle, groups)); } -ConvDescriptor::~ConvDescriptor() { - CUDNN_CHECK_ERR(cudnnDestroyConvolutionDescriptor(descriptor)); +ConvDescriptor::~ConvDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyConvolutionDescriptor(_handle)); } } +namespace fl { + cudnnHandle_t getCudnnHandle() { return getActiveDeviceHandle().cudnnHandle; } -const CUDAStream& getCudnnStream() { return *getActiveDeviceHandle().stream; } +CUDAStream const& getCudnnStream() { return *getActiveDeviceHandle().stream; } -const void* kOne(const fl::dtype t) { +void const* kOne(fl::dtype const t) { switch(t) { case fl::dtype::f16: - case fl::dtype::f32: - return &kFloatOne; - case fl::dtype::f64: - return &kDoubleOne; - default: - throw std::invalid_argument("unsupported data type for cuDNN"); + case fl::dtype::f32: return &kFloatOne; + case fl::dtype::f64: return &kDoubleOne; + default: throw std::invalid_argument("unsupported data type for cuDNN"); } } -const void* kZero(const 
fl::dtype t) { +void const* kZero(fl::dtype const t) { switch(t) { case fl::dtype::f16: - case fl::dtype::f32: - return &kFloatZero; - case fl::dtype::f64: - return &kDoubleZero; - default: - throw std::invalid_argument("unsupported data type for cuDNN"); + case fl::dtype::f32: return &kFloatZero; + case fl::dtype::f64: return &kDoubleZero; + default: throw std::invalid_argument("unsupported data type for cuDNN"); } } -} // namespace fl +} // namespace fl \ No newline at end of file diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.h b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.h index fca9969..76e9318 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.h +++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.h @@ -1,10 +1,9 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. + * SPDX-License-Identifier: MIT * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. + * Original code: Copyright (c) Meta Platforms, Inc. 
(see FLASHLIGHT_LICENSE) + * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE) */ - #pragma once #include @@ -13,35 +12,46 @@ #include "flashlight/fl/runtime/CUDAStream.h" #include "flashlight/fl/tensor/TensorBase.h" +#include + namespace fl { class TensorDescriptor { public: - explicit TensorDescriptor(const Tensor& a); - - TensorDescriptor(const fl::dtype type, const Shape& af_dims); + explicit TensorDescriptor(Tensor const& a); - cudnnTensorDescriptor_t descriptor; + TensorDescriptor(fl::dtype const type, Shape const& afDims); ~TensorDescriptor(); + +private: + cudnnTensorDescriptor_t _handle; + +public: + [[nodiscard]] constexpr auto get() const { return _handle; } }; class TensorDescriptorArray { public: - TensorDescriptorArray(int size, const fl::dtype type, const Shape& dims); + TensorDescriptorArray(int size, fl::dtype const type, Shape const& dims); cudnnTensorDescriptor_t* descriptors; ~TensorDescriptorArray(); private: - std::vector desc_vec; - std::vector desc_raw_vec; + std::vector _descVec; + std::vector _descRawVec; }; class FilterDescriptor { public: - explicit FilterDescriptor(const Tensor& a); - cudnnFilterDescriptor_t descriptor; + explicit FilterDescriptor(Tensor const& input); ~FilterDescriptor(); + +private: + cudnnFilterDescriptor_t _handle; + +public: + [[nodiscard]] constexpr auto get() const { return _handle; } }; class ConvDescriptor { @@ -56,8 +66,13 @@ class ConvDescriptor { int dy, int groups = 1 ); - cudnnConvolutionDescriptor_t descriptor; ~ConvDescriptor(); + +private: + cudnnConvolutionDescriptor_t _handle; + +public: + [[nodiscard]] constexpr auto get() const { return _handle; } }; class PoolingDescriptor { @@ -71,45 +86,99 @@ class PoolingDescriptor { int py, PoolingMode mode ); - cudnnPoolingDescriptor_t descriptor; ~PoolingDescriptor(); + +private: + cudnnPoolingDescriptor_t _handle; + +public: + [[nodiscard]] constexpr auto get() const { return _handle; } }; class DropoutDescriptor { public: - explicit 
DropoutDescriptor(float drop_prob); - cudnnDropoutDescriptor_t descriptor; + explicit DropoutDescriptor(float dropProb); ~DropoutDescriptor(); Tensor& getDropoutStates(); + +private: + cudnnDropoutDescriptor_t _handle; + +public: + [[nodiscard]] constexpr auto get() const { return _handle; } + }; class RNNDescriptor { public: RNNDescriptor( fl::dtype type, - int hidden_size, - int num_layers, + int inputSize, + int hiddenSize, + int numLayers, RnnMode mode, bool bidirectional, DropoutDescriptor& dropout ); - cudnnRNNDescriptor_t descriptor; ~RNNDescriptor(); + +private: + cudnnRNNDescriptor_t _handle = nullptr; + + static constexpr auto mathType(fl::dtype type) { + return type == fl::dtype::f16 ? CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION : CUDNN_DEFAULT_MATH; + } + +public: + /** + * @return descriptor handle + */ + constexpr auto get() const { return _handle; } }; + +class RNNDataDescriptor { +public: + RNNDataDescriptor( + fl::dtype type, + Shape const& dims + ); + + ~RNNDataDescriptor(); + +private: + void create(); + void set(dtype type, int inputSize, int maxSeqSize, std::span sequenceSizes) const; + + cudnnRNNDataDescriptor_t _handle = nullptr; + +public: + /** + * @return descriptor handle + */ + constexpr auto get() const { return _handle; } +}; + +} + +namespace fl { + #define CUDNN_CHECK_ERR(expr) ::fl::cudnnCheckErr((expr)) void cudnnCheckErr(cudnnStatus_t status); -cudnnDataType_t cudnnMapToType(const fl::dtype& t); +cudnnDataType_t cudnnMapToType(fl::dtype const& t); -const void* kOne(const fl::dtype t); +void const* kOne(fl::dtype const t); -const void* kZero(const fl::dtype t); +void const* kZero(fl::dtype const t); // TODO: move this to CudnnAutogradExtension if we make it a singleton cudnnHandle_t getCudnnHandle(); -const CUDAStream& getCudnnStream(); +CUDAStream const& getCudnnStream(); + } // namespace fl + + diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp index 
24b08c8..255620f 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp @@ -25,10 +25,10 @@ Tensor CudnnAutogradExtension::pool2d( const PoolingMode mode, std::shared_ptr ) { - auto inDesc = TensorDescriptor(input); + auto inDesc = TensorDescriptor{input}; // init pooling descriptor - auto poolDesc = PoolingDescriptor(wx, wy, sx, sy, px, py, mode); + auto poolDesc = PoolingDescriptor{wx, wy, sx, sy, px, py, mode}; // init output descriptor auto ix = input.dim(0); @@ -36,7 +36,7 @@ Tensor CudnnAutogradExtension::pool2d( auto ox = 1 + (ix + 2 * px - wx) / sx; auto oy = 1 + (iy + 2 * py - wy) / sy; - auto output = Tensor( + auto output = Tensor{ { ox, oy, @@ -44,8 +44,8 @@ Tensor CudnnAutogradExtension::pool2d( input.ndim() < 4 ? 1 : input.dim(3) }, input.type() - ); - auto outDesc = TensorDescriptor(output); + }; + auto outDesc = TensorDescriptor{output}; { DevicePtr inputraw(input); DevicePtr resultraw(output); @@ -60,12 +60,12 @@ Tensor CudnnAutogradExtension::pool2d( CUDNN_CHECK_ERR( cudnnPoolingForward( handle, - poolDesc.descriptor, + poolDesc.get(), one, - inDesc.descriptor, + inDesc.get(), inputraw.get(), zero, - outDesc.descriptor, + outDesc.get(), resultraw.get() ) ); @@ -90,11 +90,11 @@ Tensor CudnnAutogradExtension::pool2dBackward( const PoolingMode mode, std::shared_ptr ) { - auto i_desc = TensorDescriptor(input); - auto o_desc = TensorDescriptor(poolOutput); - auto p_desc = PoolingDescriptor(wx, wy, sx, sy, px, py, mode); + auto i_desc = TensorDescriptor{input}; + auto o_desc = TensorDescriptor{poolOutput}; + auto p_desc = PoolingDescriptor{wx, wy, sx, sy, px, py, mode}; - auto gradInput = Tensor(input.shape(), input.type()); + auto gradInput = Tensor{input.shape(), input.type()}; auto hndl = getCudnnHandle(); const auto& cudnnStream = getCudnnStream(); @@ -112,16 +112,16 @@ Tensor CudnnAutogradExtension::pool2dBackward( CUDNN_CHECK_ERR( cudnnPoolingBackward( hndl, - 
p_desc.descriptor, + p_desc.get(), oneg, - o_desc.descriptor, + o_desc.get(), outraw.get(), - o_desc.descriptor, + o_desc.get(), gradresultraw.get(), - i_desc.descriptor, + i_desc.get(), inraw.get(), zerog, - i_desc.descriptor, + i_desc.get(), gradinputraw.get() ) ); diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp index 17b242b..ccdeda0 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp @@ -5,404 +5,210 @@ * LICENSE file in the root directory of this source tree. */ + #include "flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.h" #include +#include "flashlight/fl/autograd/tensor/backend/cudnn/CudnnRnnUtils.h" #include "flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.h" -#include "flashlight/fl/common/DevicePtr.h" -#include "flashlight/fl/tensor/Compute.h" + namespace fl { namespace { - size_t getWorkspaceSize( - cudnnHandle_t handle, - const RNNDescriptor& rnnDesc, - const int seqLength, - const TensorDescriptorArray& xDescs - ) { - size_t workspaceSize; - CUDNN_CHECK_ERR( - cudnnGetRNNWorkspaceSize( - handle, - rnnDesc.descriptor, - seqLength, - xDescs.descriptors, - &workspaceSize - ) - ); - return workspaceSize; - } - - size_t getReserveSize( - cudnnHandle_t handle, - const RNNDescriptor& rnnDesc, - const int seqLength, - const TensorDescriptorArray& xDescs - ) { - size_t reserveSize; - CUDNN_CHECK_ERR( - cudnnGetRNNTrainingReserveSize( - handle, - rnnDesc.descriptor, - seqLength, - xDescs.descriptors, - &reserveSize - ) - ); - return reserveSize; - } - - void setCudnnRnnMathType(const Tensor& input, const RNNDescriptor& rnnDesc) { - if(input.type() == fl::dtype::f16) - CUDNN_CHECK_ERR( - cudnnSetRNNMatrixMathType( - rnnDesc.descriptor, - CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION - ) - ); - else - CUDNN_CHECK_ERR( - cudnnSetRNNMatrixMathType(rnnDesc.descriptor, CUDNN_DEFAULT_MATH) - ); - } - struct 
CudnnRnnAutogradPayload : public detail::AutogradPayloadData { Tensor reserveSpace; }; - } // namespace +void CudnnAutogradExtension::checkHiddenStateDims( + int const hiddenSize, + Tensor const& hiddenState, + int batchSize, + int totalLayers +) { + auto const& hxDims = hiddenState.shape(); + int const hxHiddenSize = static_cast(hxDims[0]); + int const hxBatchSize = hiddenState.ndim() < 2 ? 1 : static_cast(hxDims[1]); + int const hxTotalLayers = hiddenState.ndim() < 3 ? 1 : static_cast(hxDims[2]); + + if( + hxHiddenSize != hiddenSize || hxBatchSize != batchSize + || hxTotalLayers != totalLayers + ) + throw std::invalid_argument("invalid hidden state dims for RNN"); +} +void CudnnAutogradExtension::checkCellStateDims( + int const hiddenSize, + RnnMode const mode, + Tensor const& cellState, + int batchSize, + int totalLayers +) { + if(mode != RnnMode::LSTM || cellState.dim(0) != hiddenSize + || cellState.dim(1) != batchSize || cellState.dim(2) != totalLayers) + throw std::invalid_argument("invalid cell state dims for RNN"); +} + std::tuple CudnnAutogradExtension::rnn( - const Tensor& input, - const Tensor& hiddenStateIn, - const Tensor& cellStateIn, - const Tensor& weights, - const int hiddenSize, - const int numLayers, - const RnnMode mode, - const bool bidirectional, - const float dropProb, + Tensor const& input, + Tensor const& hiddenState, + Tensor const& cellState, + Tensor const& weights, + int const hiddenSize, + int const numLayers, + RnnMode const mode, + bool const bidirectional, + float const dropProb, std::shared_ptr autogradPayload ) { - FL_TENSOR_DTYPES_MATCH_CHECK(input, hiddenStateIn, cellStateIn, weights); + FL_TENSOR_DTYPES_MATCH_CHECK(input, hiddenState, cellState, weights); - bool train = (autogradPayload != nullptr); - auto payload = std::make_shared(); + bool const train = (autogradPayload != nullptr); + auto const payload = std::make_shared(); if(train) autogradPayload->data = payload; - Tensor x = input.asContiguousTensor(); - Tensor 
hiddenState = hiddenStateIn.asContiguousTensor(); - Tensor cellState = cellStateIn.asContiguousTensor(); + auto const x = input.asContiguousTensor(); + + auto const cHiddenState = hiddenState.asContiguousTensor(); + auto const cCellState = cellState.asContiguousTensor(); + + DropoutDescriptor dropout{dropProb}; - DropoutDescriptor dropout(dropProb); - RNNDescriptor rnnDesc( - input.type(), hiddenSize, numLayers, mode, bidirectional, dropout); - setCudnnRnnMathType(input, rnnDesc); + auto const& dims = max(x.shape(), {1, 1, 1}); + + + auto const inputSize = static_cast(dims[0]); + auto batchSize = static_cast(dims[1]); + auto seqLength = static_cast(dims[2]); + + + RNNDescriptor const rnnDesc{ + input.type(), + inputSize, + hiddenSize, + numLayers, + mode, + bidirectional, + dropout + }; - auto dims = x.shape(); - int inputSize = dims[0]; - int batchSize = dims.ndim() < 2 ? 1 : dims[1]; - int seqLength = dims.ndim() < 3 ? 1 : dims[2]; int totalLayers = numLayers * (bidirectional ? 2 : 1); int outSize = hiddenSize * (bidirectional ? 2 : 1); - TensorDescriptorArray xDescs( - seqLength, x.type(), {1, 1, inputSize, batchSize}); + if(!cHiddenState.isEmpty()) + checkHiddenStateDims(hiddenSize, cHiddenState, batchSize, totalLayers); - if(!hiddenState.isEmpty()) { - auto hxDims = hiddenState.shape(); - int hxHiddenSize = hxDims[0]; - int hxBatchSize = hiddenState.ndim() < 2 ? 1 : hxDims[1]; - int hxTotalLayers = hiddenState.ndim() < 3 ? 
1 : hxDims[2]; + if(!cCellState.isEmpty()) + checkCellStateDims(hiddenSize, mode, cCellState, batchSize, totalLayers); - if( - !(hxHiddenSize == hiddenSize && hxBatchSize == batchSize - && hxTotalLayers == totalLayers) - ) - throw std::invalid_argument("invalid hidden state dims for RNN"); - } - if( - !cellState.isEmpty() - && !(mode == RnnMode::LSTM && cellState.dim(0) == hiddenSize - && cellState.dim(1) == batchSize && cellState.dim(2) == totalLayers) - ) - throw std::invalid_argument("invalid cell state dims for RNN"); + Shape const hDims = {1, hiddenSize, batchSize, totalLayers}; + TensorDescriptor const hxDesc{x.type(), hDims}; + TensorDescriptor const cxDesc{x.type(), hDims}; - Shape hDims = {1, hiddenSize, batchSize, totalLayers}; - TensorDescriptor hxDesc(x.type(), hDims); - TensorDescriptor cxDesc(x.type(), hDims); - - auto handle = getCudnnHandle(); - const auto& cudnnStream = getCudnnStream(); - - size_t paramSize; - CUDNN_CHECK_ERR( - cudnnGetRNNParamsSize( - handle, - rnnDesc.descriptor, - xDescs.descriptors[0], - ¶mSize, - cudnnMapToType(weights.type()) - ) - ); - if(paramSize != weights.bytes()) - throw std::invalid_argument( - "invalid # of parameters or wrong input shape for RNN" - ); - FilterDescriptor wDesc(weights); + Tensor y{{outSize, batchSize, seqLength}, input.type()}; - Tensor y({outSize, batchSize, seqLength}, input.type()); - TensorDescriptorArray yDesc(seqLength, y.type(), {1, 1, outSize, batchSize}); + Tensor hy{{hiddenSize, batchSize, totalLayers}, x.type()}; - Tensor hy({hiddenSize, batchSize, totalLayers}, x.type()); - TensorDescriptor hyDesc(x.type(), hDims); - - Tensor cy; + Tensor cy{}; if(mode == RnnMode::LSTM) - cy = Tensor(hy.shape(), x.type()); - - TensorDescriptor cyDesc(x.type(), hDims); - - size_t workspaceSize = getWorkspaceSize(handle, rnnDesc, seqLength, xDescs); - size_t reserveSize = getReserveSize(handle, rnnDesc, seqLength, xDescs); - - Tensor workspace({static_cast(workspaceSize)}, fl::dtype::b8); - // Space must 
be reused between forward and backward for cuDNN - payload->reserveSpace = Tensor({static_cast(reserveSize)}, fl::dtype::b8); - - { - auto contiguousX = x.asContiguousTensor(); - auto contiguousWeights = weights.asContiguousTensor(); - DevicePtr xRaw(contiguousX); - DevicePtr hxRaw(hiddenState); - DevicePtr cxRaw(cellState); - DevicePtr wRaw(contiguousWeights); - DevicePtr yRaw(y); - DevicePtr hyRaw(hy); - DevicePtr cyRaw(cy); - DevicePtr workspaceRaw(workspace); - DevicePtr reserveSpaceRaw(payload->reserveSpace); - // ensure cudnn compute stream waits on input/output tensor streams - relativeSync( - cudnnStream, - { - contiguousX, hiddenState, cellState, contiguousWeights, y, hy, cy, - workspace, payload->reserveSpace, - } - ); - - CUDNN_CHECK_ERR( - cudnnRNNForwardTraining( - handle, - rnnDesc.descriptor, - seqLength, - xDescs.descriptors, - xRaw.get(), - hxDesc.descriptor, - hxRaw.get(), - cxDesc.descriptor, - cxRaw.get(), - wDesc.descriptor, - wRaw.get(), - yDesc.descriptors, - yRaw.get(), - hyDesc.descriptor, - hyRaw.get(), - cyDesc.descriptor, - cyRaw.get(), - workspaceRaw.get(), - workspaceSize, - reserveSpaceRaw.get(), - reserveSize - ) - ); - } + cy = Tensor{hy.shape(), x.type()}; + + cudnn_rnn_forward( + batchSize, + seqLength, + train, + rnnDesc, + x, + y, + weights, + cxDesc, + hxDesc, + hy, + cy, + cHiddenState, + cCellState, + payload->reserveSpace // output + ); - // ensure output tensor streams wait on cudnn compute stream - relativeSync({y, hy, cy}, cudnnStream); return std::make_tuple(y, hy, cy); } std::tuple CudnnAutogradExtension::rnnBackward( - const Tensor& input, - const Tensor& hiddenState, - const Tensor& cellState, - const Tensor& weights, - const std::shared_ptr gradData, - const Tensor& output, - const int numLayers, - const int hiddenSize, - const RnnMode mode, - const bool bidirectional, - const float dropProb, + Tensor const& input, + Tensor const& hiddenState, + Tensor const& cellState, + Tensor const& weights, + std::shared_ptr 
const gradData, + Tensor const& output, + int const numLayers, + int const hiddenSize, + RnnMode const mode, + bool const bidirectional, + float const dropProb, std::shared_ptr autogradPayload ) { if(!autogradPayload) throw std::invalid_argument( "CudnnAutogradExtension::rnnBackward given null detail::AutogradPayload" ); - auto payload = - std::static_pointer_cast(autogradPayload->data); + auto const payload = std::static_pointer_cast(autogradPayload->data); - auto handle = getCudnnHandle(); - const auto& cudnnStream = getCudnnStream(); - - auto x = input.asContiguousTensor(); + auto const x = input.asContiguousTensor(); auto& y = output; - auto dims = x.shape(); - int inputSize = dims[0]; - int batchSize = dims.ndim() < 2 ? 1 : dims[1]; - int seqLength = dims.ndim() < 3 ? 1 : dims[2]; - int totalLayers = numLayers * (bidirectional ? 2 : 1); - int outSize = hiddenSize * (bidirectional ? 2 : 1); - - DropoutDescriptor dropout(dropProb); - RNNDescriptor rnnDesc(input.type(), hiddenSize, numLayers, mode, bidirectional, dropout); - setCudnnRnnMathType(input, rnnDesc); - - TensorDescriptorArray yDesc(seqLength, y.type(), {1, 1, outSize, batchSize}); - TensorDescriptorArray dyDesc(seqLength, y.type(), {1, 1, outSize, batchSize}); - - Shape hDims = {1, hiddenSize, batchSize, totalLayers}; - TensorDescriptor dhyDesc(x.type(), hDims); - TensorDescriptor dcyDesc(x.type(), hDims); - TensorDescriptor hxDesc(x.type(), hDims); - TensorDescriptor cxDesc(x.type(), hDims); + auto const& dims = x.shape(); + int const inputSize = dims[0]; + int const batchSize = dims.ndim() < 2 ? 1 : dims[1]; + int const seqLength = dims.ndim() < 3 ? 1 : dims[2]; + int const totalLayers = numLayers * (bidirectional ? 
2 : 1); - Tensor dhx(hiddenState.shape(), hiddenState.type()); - Tensor dcx(cellState.shape(), cellState.type()); - TensorDescriptor dhxDesc(x.type(), hDims); - TensorDescriptor dcxDesc(x.type(), hDims); + DropoutDescriptor dropout{dropProb}; + RNNDescriptor const rnnDesc{input.type(), inputSize, hiddenSize, numLayers, mode, bidirectional, dropout}; - FilterDescriptor wDesc(weights); + Shape const hDims = {1, hiddenSize, batchSize, totalLayers}; + TensorDescriptor const hxDesc{x.type(), hDims}; + TensorDescriptor const cxDesc{x.type(), hDims}; - Tensor dx(input.shape(), input.type()); - TensorDescriptorArray dxDescs( - seqLength, dx.type(), {1, 1, inputSize, batchSize}); + Tensor dhx{hiddenState.shape(), hiddenState.type()}; + Tensor dcx{cellState.shape(), cellState.type()}; - size_t workspaceSize = - getWorkspaceSize(handle, rnnDesc, seqLength, dxDescs); - Tensor workspace({static_cast(workspaceSize)}, fl::dtype::b8); + Tensor dx{input.shape(), input.type()}; + Tensor dw = fl::full(weights.shape(), 0, weights.type()); auto& dy = gradData->dy; if(dy.isEmpty()) dy = fl::full(y.shape(), 0.0, y.type()); - auto& dhy = gradData->dhy; - auto& dcy = gradData->dcy; - - DevicePtr yRaw(output); - DevicePtr workspaceRaw(workspace); - DevicePtr reserveSpaceRaw(payload->reserveSpace); - // ensure cudnn compute stream waits on input/output tensor streams - relativeSync(cudnnStream, {output, workspace, payload->reserveSpace}); - - { - DevicePtr dyRaw(dy); // Has to be set to 0 if empty - DevicePtr dhyRaw(dhy); - DevicePtr dcyRaw(dcy); - - DevicePtr wRaw(weights); - - DevicePtr hxRaw(hiddenState); - DevicePtr cxRaw(cellState); - - DevicePtr dxRaw(dx); - DevicePtr dhxRaw(dhx); - DevicePtr dcxRaw(dcx); - // ensure cudnn compute stream waits on input/output tensor streams - relativeSync( - cudnnStream, - {dy, dhy, dcy, weights, hiddenState, cellState, dx, dhx, dcx} - ); - - /* We need to update reserveSpace even if we just want the - * weight gradients. 
*/ - CUDNN_CHECK_ERR( - cudnnRNNBackwardData( - handle, - rnnDesc.descriptor, - seqLength, - yDesc.descriptors, - yRaw.get(), - dyDesc.descriptors, - dyRaw.get(), - dhyDesc.descriptor, - dhyRaw.get(), - dcyDesc.descriptor, - dcyRaw.get(), - wDesc.descriptor, - wRaw.get(), - hxDesc.descriptor, - hxRaw.get(), - cxDesc.descriptor, - cxRaw.get(), - dxDescs.descriptors, - dxRaw.get(), - dhxDesc.descriptor, - dhxRaw.get(), - dcxDesc.descriptor, - dcxRaw.get(), - workspaceRaw.get(), - workspaceSize, - reserveSpaceRaw.get(), - payload->reserveSpace.bytes() - ) - ); - } - - if(input.type() == fl::dtype::f16) - CUDNN_CHECK_ERR( - cudnnSetRNNMatrixMathType( - rnnDesc.descriptor, - CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION - ) - ); - else - CUDNN_CHECK_ERR( - cudnnSetRNNMatrixMathType(rnnDesc.descriptor, CUDNN_DEFAULT_MATH) - ); - TensorDescriptorArray xDescs( - seqLength, x.type(), {1, 1, inputSize, batchSize}); - Tensor dw = fl::full(weights.shape(), 0, weights.type()); - - FilterDescriptor dwDesc(dw); - - { - DevicePtr xRaw(x); - DevicePtr dwRaw(dw); - DevicePtr hxRaw(hiddenState); - // ensure cudnn compute stream waits on input/output tensor streams - relativeSync(cudnnStream, {x, dw, hiddenState}); - - CUDNN_CHECK_ERR( - cudnnRNNBackwardWeights( - handle, - rnnDesc.descriptor, - seqLength, - xDescs.descriptors, - xRaw.get(), - hxDesc.descriptor, - hxRaw.get(), - yDesc.descriptors, - yRaw.get(), - workspaceRaw.get(), - workspaceSize, - dwDesc.descriptor, - dwRaw.get(), - reserveSpaceRaw.get(), - payload->reserveSpace.bytes() - ) - ); - } + auto const& dhy = gradData->dhy; + auto const& dcy = gradData->dcy; + + cudnn_rnn_backward( + batchSize, + seqLength, + rnnDesc, + x, + y, + dy, + weights, + cxDesc, + hxDesc, + dhy, + dcy, + hiddenState, + cellState, + dx, + dhx, + dcx, + dw, + payload->reserveSpace + ); - // ensure output tensor streams wait on cudnn compute stream - relativeSync({dx, dhx, dcx, dw}, cudnnStream); return std::make_tuple(dx, dhx, dcx, dw); } + + } // 
namespace fl diff --git a/flashlight/fl/autograd/tensor/backend/onednn/BatchNorm.cpp b/flashlight/fl/autograd/tensor/backend/onednn/BatchNorm.cpp index b673281..29cf2dc 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/BatchNorm.cpp +++ b/flashlight/fl/autograd/tensor/backend/onednn/BatchNorm.cpp @@ -49,7 +49,7 @@ namespace { 1, 1, nfeatures, - static_cast(input.elements() / nfeatures) + static_cast(input.elements() / nfeatures) } ); else { @@ -59,7 +59,7 @@ namespace { inDescDims = Shape( { 1, - static_cast(input.elements() / (nfeatures * batchsz)), + static_cast(input.elements() / (nfeatures * batchsz)), nfeatures, batchsz } diff --git a/flashlight/fl/autograd/tensor/backend/onednn/RNN.cpp b/flashlight/fl/autograd/tensor/backend/onednn/RNN.cpp index f866f1c..06fe93c 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/RNN.cpp +++ b/flashlight/fl/autograd/tensor/backend/onednn/RNN.cpp @@ -97,7 +97,7 @@ namespace { } } - auto weightsFlat = weights.flatten().astype(weights.type()); + auto weightsFlat = weights.flatten().asType(weights.type()); // cuDNN RNN weights, for each layer, are arranged with a chunk of // input-hidden weights for each layer followed by a chunk of hidden-hidden // weights for each layer: diff --git a/flashlight/fl/common/DevicePtr.h b/flashlight/fl/common/DevicePtr.h index 4666413..3bdedea 100644 --- a/flashlight/fl/common/DevicePtr.h +++ b/flashlight/fl/common/DevicePtr.h @@ -1,8 +1,8 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. + * SPDX-License-Identifier: MIT * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. + * Original code: Copyright (c) Meta Platforms, Inc. 
(see FLASHLIGHT_LICENSE) + * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE) */ #pragma once @@ -68,7 +68,7 @@ class FL_API DevicePtr { template T* getAs() const { - return reinterpret_cast(ptr_); + return static_cast(ptr_); } protected: diff --git a/flashlight/fl/common/Serialization-inl.h b/flashlight/fl/common/Serialization-inl.h index 7777890..e55380d 100644 --- a/flashlight/fl/common/Serialization-inl.h +++ b/flashlight/fl/common/Serialization-inl.h @@ -209,9 +209,7 @@ void save( throw cereal::Exception( "Serialzation of sparse Tensor is not supported yet!" ); - std::vector vec(tensor.bytes()); - tensor.host(vec.data()); - ar(tensor.shape(), tensor.type(), vec); + ar(tensor.shape(), tensor.type(), tensor.host()); } template diff --git a/flashlight/fl/common/Types.h b/flashlight/fl/common/Types.h index dcb3168..3537402 100644 --- a/flashlight/fl/common/Types.h +++ b/flashlight/fl/common/Types.h @@ -16,7 +16,7 @@ namespace fl { namespace detail { - + // TODO remove, somebody smoked something before writing this /** * Precision specifications for autograd operators based on optimization level. 
*/ diff --git a/flashlight/fl/contrib/modules/AdaptiveEmbedding.cpp b/flashlight/fl/contrib/modules/AdaptiveEmbedding.cpp index 4fe16aa..a042370 100644 --- a/flashlight/fl/contrib/modules/AdaptiveEmbedding.cpp +++ b/flashlight/fl/contrib/modules/AdaptiveEmbedding.cpp @@ -83,7 +83,7 @@ Variable AdaptiveEmbedding::forward(const Variable& input) { for(int tailIdx = 1; tailIdx < cutoff_.size(); tailIdx++) { Tensor tailMask = flatInput.tensor() < cutoff_[tailIdx] && flatInput.tensor() >= cutoff_[tailIdx - 1]; - if(fl::any(tailMask).asScalar()) { + if(fl::any_of(tailMask).asScalar()) { auto tailEmbedding = embedding( flatInput(tailMask) - cutoff_[tailIdx - 1], reorder(params_[tailIdx * 2], {1, 0}) diff --git a/flashlight/fl/contrib/modules/CMakeLists.txt b/flashlight/fl/contrib/modules/CMakeLists.txt index c1f56a6..3da2d83 100644 --- a/flashlight/fl/contrib/modules/CMakeLists.txt +++ b/flashlight/fl/contrib/modules/CMakeLists.txt @@ -13,4 +13,4 @@ target_sources( ${CMAKE_CURRENT_LIST_DIR}/Transformer.cpp ${CMAKE_CURRENT_LIST_DIR}/TDSBlock.cpp ${CMAKE_CURRENT_LIST_DIR}/SpecAugment.cpp - ) +) diff --git a/flashlight/fl/contrib/modules/Conformer.cpp b/flashlight/fl/contrib/modules/Conformer.cpp index 6559c08..5a7b16a 100644 --- a/flashlight/fl/contrib/modules/Conformer.cpp +++ b/flashlight/fl/contrib/modules/Conformer.cpp @@ -189,7 +189,7 @@ Variable Conformer::mhsa(const Variable& input, const Variable& inputPadMask) { Variable mask, posEmb; if(posEmbContextSize_ > 0) - posEmb = tile(params_[0].astype(input.type()), {1, 1, nHeads_ * bsz}); + posEmb = tile(params_[0].asType(input.type()), {1, 1, nHeads_ * bsz}); fl::Variable padMask; // TODO{fl::Tensor}{resize} - emulate the ArrayFire resize operation for @@ -222,7 +222,7 @@ Variable Conformer::conv(const Variable& _input) { // input C x T x B x 1 // apply first pointwise conv auto result = gatedlinearunit( - (*conv1_)(((*normConv1_)(input)).astype(input.type())), + (*conv1_)(((*normConv1_)(input)).asType(input.type())), 0 
); result = reorder(result, {1, 3, 0, 2}); @@ -231,7 +231,7 @@ Variable Conformer::conv(const Variable& _input) { result = (*convDepthWise_)(result); result = reorder(result, {2, 0, 3, 1}); // C x T x B x 1 - result = fl::swish(((*normConv2_)(result)).astype(input.type()), 1.); + result = fl::swish(((*normConv2_)(result)).asType(input.type()), 1.); // apply second pointwise conv result = dropout((*conv2_)(result), pDropout); return moddims(result, _input.shape()); @@ -260,7 +260,7 @@ std::vector Conformer::forward(const std::vector& input) { auto ffn1 = dropout( (*w12_)( dropout( - fl::swish((*w11_)(((*norm1_)(x)).astype(x.type())), 1.), + fl::swish((*w11_)(((*norm1_)(x)).asType(x.type())), 1.), pDropout ) ), @@ -275,14 +275,14 @@ std::vector Conformer::forward(const std::vector& input) { auto ffn2 = dropout( (*w22_)( dropout( - fl::swish((*w21_)(((*norm2_)(x)).astype(x.type())), 1.), + fl::swish((*w21_)(((*norm2_)(x)).asType(x.type())), 1.), pDropout ) ), pDropout ); x = x + f * 0.5 * ffn2; - x = ((*norm3_)(x)).astype(x.type()); + x = ((*norm3_)(x)).asType(x.type()); return {x}; } diff --git a/flashlight/fl/contrib/modules/PositionEmbedding.cpp b/flashlight/fl/contrib/modules/PositionEmbedding.cpp index a95bb08..66af82b 100644 --- a/flashlight/fl/contrib/modules/PositionEmbedding.cpp +++ b/flashlight/fl/contrib/modules/PositionEmbedding.cpp @@ -50,7 +50,7 @@ std::vector PositionEmbedding::forward( int n = input[0].dim(1); Variable posEmb = tileAs( - params_[0].astype(input[0].type())(fl::span, fl::range(0, n)), input[0]); + params_[0].asType(input[0].type())(fl::span, fl::range(0, n)), input[0]); if(dropout_ > 0.0 && train_) return {input[0] + dropout(posEmb, dropout_)}; else diff --git a/flashlight/fl/contrib/modules/Residual.cpp b/flashlight/fl/contrib/modules/Residual.cpp index 31f800b..f7d6a61 100644 --- a/flashlight/fl/contrib/modules/Residual.cpp +++ b/flashlight/fl/contrib/modules/Residual.cpp @@ -99,7 +99,7 @@ Variable Residual::forward(const Variable& 
input) { connectionOut = modules_[shortcut.second] ->forward({outputs[shortcut.first]}) .front(); - output = output + connectionOut.astype(output.type()); + output = output + connectionOut.asType(output.type()); } output = modules_[moduleIndex] ->forward({applyScale(output, layerIndex)}) @@ -115,7 +115,7 @@ Variable Residual::forward(const Variable& input) { connectionOut = modules_[shortcut.second] ->forward({outputs[shortcut.first]}) .front(); - output = output + connectionOut.astype(output.type()); + output = output + connectionOut.asType(output.type()); } return applyScale(output, nLayers); } diff --git a/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp b/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp index 22ddc2a..5c55b41 100644 --- a/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp +++ b/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp @@ -79,8 +79,8 @@ std::vector SinusoidalPositionEmbedding::forward( // Generate the embedding transformation with the precomputed scale and shift // factors. positions = fl::sin( - positions * fl::tile(scale_.astype(numType), {1, nPositions}) - + fl::tile(cosShifts_.astype(numType), {1, nPositions}) + positions * fl::tile(scale_.asType(numType), {1, nPositions}) + + fl::tile(cosShifts_.asType(numType), {1, nPositions}) ); // Convert the positional embedding into a variable (for gradient tracking). 
Variable embeddingsPos = Variable(positions, false); diff --git a/flashlight/fl/contrib/modules/TDSBlock.cpp b/flashlight/fl/contrib/modules/TDSBlock.cpp index 3720ca9..bd7a86f 100644 --- a/flashlight/fl/contrib/modules/TDSBlock.cpp +++ b/flashlight/fl/contrib/modules/TDSBlock.cpp @@ -69,9 +69,9 @@ TDSBlock::TDSBlock( std::vector TDSBlock::forward(const std::vector& inputs) { auto out = inputs[0]; - out = module(0)->forward({out})[0].astype(out.type()) + out; + out = module(0)->forward({out})[0].asType(out.type()) + out; out = module(1)->forward({out})[0]; - out = module(2)->forward({out})[0].astype(out.type()) + out; + out = module(2)->forward({out})[0].asType(out.type()) + out; return module(3)->forward({out}); } diff --git a/flashlight/fl/contrib/modules/Transformer.cpp b/flashlight/fl/contrib/modules/Transformer.cpp index 5eff578..474ef1c 100644 --- a/flashlight/fl/contrib/modules/Transformer.cpp +++ b/flashlight/fl/contrib/modules/Transformer.cpp @@ -129,7 +129,7 @@ Variable Transformer::selfAttention(const std::vector& input) { Variable mask, posEmb; if(bptt_ > 0) posEmb = - tile(params_[0].astype(encoderInput.type()), {1, 1, nHeads_ * bsz}); + tile(params_[0].asType(encoderInput.type()), {1, 1, nHeads_ * bsz}); if(useMask_ && encoderInput.dim(1) > 1) // mask future if we use the previous state (then n is previous time) mask = getMask(n, input.size() == 3); @@ -201,11 +201,11 @@ std::vector Transformer::forward(const std::vector& input) { if(train_ && (fl::rand({1}).scalar() < pLayerdrop_)) f = 0.0; if(preLN_) { - auto h = (f * (*norm1_)(selfAttention(input))).astype(x.type()) + x; - return {f* (*norm2_)(mlp(h)).astype(h.type()) + h}; + auto h = (f * (*norm1_)(selfAttention(input))).asType(x.type()) + x; + return {f* (*norm2_)(mlp(h)).asType(h.type()) + h}; } else { - auto h = (*norm1_)((f* selfAttention(input)).astype(x.type()) + x); - return {(*norm2_)((f* mlp(h)).astype(h.type()) + h)}; + auto h = (*norm1_)((f* selfAttention(input)).asType(x.type()) + x); 
+ return {(*norm2_)((f* mlp(h)).asType(h.type()) + h)}; } } diff --git a/flashlight/fl/dataset/BlobDataset.cpp b/flashlight/fl/dataset/BlobDataset.cpp index 67ae662..96533d4 100644 --- a/flashlight/fl/dataset/BlobDataset.cpp +++ b/flashlight/fl/dataset/BlobDataset.cpp @@ -20,9 +20,7 @@ BlobDatasetEntryBuffer::BlobDatasetEntryBuffer() = default; void BlobDatasetEntryBuffer::clear() { data_.clear(); } -int64_t BlobDatasetEntryBuffer::size() const { - return data_.size() / nFieldPerEntry_; -} +int64_t BlobDatasetEntryBuffer::size() const { return data_.size() / nFieldPerEntry_; } void BlobDatasetEntryBuffer::resize(int64_t size) { data_.resize(size * nFieldPerEntry_); } @@ -52,15 +50,11 @@ void BlobDatasetEntryBuffer::add(const BlobDatasetEntry& e) { char* BlobDatasetEntryBuffer::data() { return (char*) data_.data(); } -int64_t BlobDatasetEntryBuffer::bytes() const { - return data_.size() * sizeof(int64_t); -}; +int64_t BlobDatasetEntryBuffer::bytes() const { return data_.size() * sizeof(int64_t); }; BlobDataset::BlobDataset() = default; -int64_t BlobDataset::size() const { - return offsets_.size(); -} +int64_t BlobDataset::size() const { return offsets_.size(); } std::vector BlobDataset::get(const int64_t idx) const { std::vector sample; @@ -128,12 +122,12 @@ void BlobDataset::add(const BlobDataset& blob, int64_t chunkSize) { int64_t remainCopySize = copySize - nChunk * chunkSize; std::vector buffer; auto copyChunk = [&buffer, &blob, this, &blobOffset](int64_t size) { - buffer.resize(size); - blob.readData(blobOffset, buffer.data(), size); - blobOffset += size; - this->writeData(indexOffset_, buffer.data(), size); - this->indexOffset_ += size; - }; + buffer.resize(size); + blob.readData(blobOffset, buffer.data(), size); + blobOffset += size; + this->writeData(indexOffset_, buffer.data(), size); + this->indexOffset_ += size; + }; for(int64_t i = 0; i < nChunk; i++) copyChunk(chunkSize); if(remainCopySize > 0) @@ -168,14 +162,15 @@ Tensor BlobDataset::readArray(const 
BlobDatasetEntry& e, int i) const { ); else return keyval->second(buffer.data(), e.dims, e.type); - } else + } + else return Tensor(); } void BlobDataset::writeArray(const BlobDatasetEntry& e, const Tensor& array) { - std::vector buffer(array.bytes()); - array.host(buffer.data()); - writeData(e.offset, (char*) buffer.data(), buffer.size()); + auto const tensorData = array.host(); + + writeData(e.offset, tensorData.data(), tensorData.size()); } void BlobDataset::writeIndex() { diff --git a/flashlight/fl/dataset/BlobDataset.h b/flashlight/fl/dataset/BlobDataset.h index 7401fa2..28531d7 100644 --- a/flashlight/fl/dataset/BlobDataset.h +++ b/flashlight/fl/dataset/BlobDataset.h @@ -96,12 +96,15 @@ class FL_API BlobDataset : public Dataset { protected: void readIndex(); + //TODO modernize with std::span + /** * Write raw data in the blob. * Implementation must be thread-safe. * @param[in] offset Offset in the blob in bytes. * @param[in] data Raw data bytes. * @param[in] size Raw data size in bytes. 
+ * @return byte size of written data */ virtual int64_t writeData(int64_t offset, const char* data, int64_t size) const = 0; diff --git a/flashlight/fl/distributed/backend/cpu/DistributedBackend.cpp b/flashlight/fl/distributed/backend/cpu/DistributedBackend.cpp index c73eaaf..f703cd9 100644 --- a/flashlight/fl/distributed/backend/cpu/DistributedBackend.cpp +++ b/flashlight/fl/distributed/backend/cpu/DistributedBackend.cpp @@ -110,7 +110,7 @@ void allReduce(fl::Tensor& tensor, bool async /* = false */) { size_t tensorSize = tensor.elements() * fl::getTypeSize(tensor.type()); if(tensorSize > cacheTensor_.elements()) cacheTensor_ = - fl::Tensor({static_cast(tensorSize)}, fl::dtype::b8); + fl::Tensor({static_cast(tensorSize)}, fl::dtype::b8); DevicePtr tensorPtr(tensor); DevicePtr cacheTensorPtr(cacheTensor_); memcpy(cacheTensorPtr.get(), tensorPtr.get(), tensorSize); diff --git a/flashlight/fl/examples/Benchmark.cpp b/flashlight/fl/examples/Benchmark.cpp index dde158e..b8d9710 100644 --- a/flashlight/fl/examples/Benchmark.cpp +++ b/flashlight/fl/examples/Benchmark.cpp @@ -70,7 +70,7 @@ double embedding() { int num_elems = 400; Variable input( - (fl::rand({num_elems}) * vocab_size).astype(fl::dtype::s32), false); + (fl::rand({num_elems}) * vocab_size).asType(fl::dtype::s32), false); Variable grad_output( fl::randn({embed_dim, num_elems}, fl::dtype::f32), false); diff --git a/flashlight/fl/examples/CMakeLists.txt b/flashlight/fl/examples/CMakeLists.txt index 187e738..41f49f9 100644 --- a/flashlight/fl/examples/CMakeLists.txt +++ b/flashlight/fl/examples/CMakeLists.txt @@ -1,7 +1,8 @@ cmake_minimum_required(VERSION 3.16) project(flashlight-examples LANGUAGES CXX C VERSION 0.4.0) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) # If building in source, we already have the flashlight target if (NOT TARGET flashlight) diff --git a/flashlight/fl/examples/Mnist.cpp b/flashlight/fl/examples/Mnist.cpp 
index 481e77e..6d23657 100644 --- a/flashlight/fl/examples/Mnist.cpp +++ b/flashlight/fl/examples/Mnist.cpp @@ -221,7 +221,7 @@ int read_int(std::ifstream& f) { template Tensor load_data( const std::string& im_file, - const std::vector& dims + const std::vector& dims ) { std::ifstream file(im_file, std::ios::binary); if(!file.is_open()) @@ -243,7 +243,7 @@ Tensor load_data( data.push_back(tmp); } - std::vector rdims(dims.rbegin(), dims.rend()); + std::vector rdims(dims.rbegin(), dims.rend()); // af is column-major return Tensor::fromBuffer(Shape(rdims), data.data(), MemoryLocation::Host); } diff --git a/flashlight/fl/meter/EditDistanceMeter.cpp b/flashlight/fl/meter/EditDistanceMeter.cpp index 65d2ccd..6871b5b 100644 --- a/flashlight/fl/meter/EditDistanceMeter.cpp +++ b/flashlight/fl/meter/EditDistanceMeter.cpp @@ -36,13 +36,9 @@ void EditDistanceMeter::add(const Tensor& output, const Tensor& target) { int len1 = output.dim(0); int len2 = target.dim(0); - int* in1raw = output.host(); - int* in2raw = target.host(); - auto err_state = levensteinDistance(in1raw, in2raw, len1, len2); - free(in1raw); - in1raw = nullptr; - free(in2raw); - in2raw = nullptr; + auto in1raw = output.host(); + auto in2raw = target.host(); + auto err_state = levensteinDistance(in1raw[0], in2raw[0], len1, len2); add(err_state, target.dim(0)); } diff --git a/flashlight/fl/meter/TopKMeter.cpp b/flashlight/fl/meter/TopKMeter.cpp index 399cff0..c5b9b6b 100644 --- a/flashlight/fl/meter/TopKMeter.cpp +++ b/flashlight/fl/meter/TopKMeter.cpp @@ -27,7 +27,7 @@ void TopKMeter::add(const Tensor& output, const Tensor& target) { Tensor maxVals, maxIds, match; topk(maxVals, maxIds, output, k_, 0); match = maxIds == fl::reshape(target, {1, target.dim(0), 1, 1}); - const Tensor correct = fl::any(match, {0}); + const Tensor correct = fl::any_of(match, {0}); correct_ += fl::countNonzero(correct).asScalar(); const int batchsize = target.dim(0); diff --git a/flashlight/fl/nn/Init.cpp b/flashlight/fl/nn/Init.cpp 
index 4644b33..af458cc 100644 --- a/flashlight/fl/nn/Init.cpp +++ b/flashlight/fl/nn/Init.cpp @@ -71,7 +71,7 @@ namespace detail { } Tensor erfinv(const Tensor& y) { - if(fl::any(fl::abs(y) >= 1.).scalar()) + if(fl::any_of(fl::abs(y) >= 1.).scalar()) throw std::runtime_error("[erfinv] input is out of range (-1, 1)"); double a[4] = {0.886226899, -1.645349621, 0.914624893, -0.140543331}; double b[4] = {-2.118377725, 1.442710462, -0.329097515, 0.012229801}; @@ -90,7 +90,7 @@ namespace detail { num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0]; dem = (d[1] * z + d[0]) * z + 1.0; // TODO{fl::Tensor}{operator} - check af::sign - zero case? - z = fl::sign(y).astype(fl::dtype::f32); // -1 for negative, 1 for positive + z = fl::sign(y).asType(fl::dtype::f32); // -1 for negative, 1 for positive z = z * num / dem; x = x + z * !centralMask; @@ -98,8 +98,8 @@ namespace detail { x = x - (fl::erf(x) - y) / ((2.0 / std::sqrt(M_PI)) * fl::exp(-x * x)); x = x - (fl::erf(x) - y) / ((2.0 / std::sqrt(M_PI)) * fl::exp(-x * x)); if( - fl::any(fl::isnan(x)).asScalar() - || fl::any(fl::isinf(x)).asScalar() + fl::any_of(fl::isnan(x)).asScalar() + || fl::any_of(fl::isinf(x)).asScalar() ) throw std::runtime_error("[erfinv] invalid result"); return x; diff --git a/flashlight/fl/nn/Init.h b/flashlight/fl/nn/Init.h index 6bb0422..b8cc6ac 100644 --- a/flashlight/fl/nn/Init.h +++ b/flashlight/fl/nn/Init.h @@ -289,7 +289,7 @@ FL_API Variable constant( * \ingroup nn_init_utils */ template -Variable scalar(T val, fl::dtype type = dtype_traits::ctype, bool calcGrad = true) { +Variable scalar(T val, fl::dtype type = dtype_traits::fl_type, bool calcGrad = true) { return Variable(fromScalar(val, type), calcGrad); } diff --git a/flashlight/fl/nn/Utils.cpp b/flashlight/fl/nn/Utils.cpp index fc9e8b2..7230307 100644 --- a/flashlight/fl/nn/Utils.cpp +++ b/flashlight/fl/nn/Utils.cpp @@ -24,14 +24,14 @@ int64_t numTotalParams(std::shared_ptr module) { } bool allParamsClose( - const Module& a, - const Module& b, + 
Module const& a, + Module const& b, double absTolerance /* = 1e-5 */ ) { if(a.params().size() != b.params().size()) return false; - const auto aParams = a.params(); - const auto bParams = b.params(); + auto const aParams = a.params(); + auto const bParams = b.params(); for(int p = 0; p < aParams.size(); ++p) if(!allClose(aParams[p], bParams[p], absTolerance)) return false; @@ -74,8 +74,7 @@ namespace detail { break; case RnnMode::RELU: case RnnMode::TANH: - default: - break; + default: break; } return n_params; @@ -97,52 +96,49 @@ int derivePadding(int inSz, int filterSz, int stride, int pad, int dilation) { } Tensor join( - const std::vector& inputs, + std::vector const& inputs, double padValue /* = 0.0 */, int batchDim /* = -1 */ ) { if(inputs.empty()) - return Tensor(); + return {}; - Dim maxNumDims = 0; - for(const auto& in : inputs) - if(in.ndim() > maxNumDims) - maxNumDims = in.ndim(); + size_t maxNumDims = 0; + for(auto const& in : inputs) + maxNumDims = std::max(in.ndim(), maxNumDims); // If the batch dim > the max number of dims, make those dims singleton - int outNdims = std::max(batchDim + 1, static_cast(maxNumDims)); + int const outNdims = std::max(batchDim + 1, static_cast(maxNumDims)); - Shape maxDims(std::vector(outNdims, 1)); + Shape maxDims{std::vector(outNdims, 1)}; - fl::dtype type = inputs[0].type(); + auto const type = inputs[0].type(); bool isEmpty = true; - for(const auto& in : inputs) { + + for(auto const& in : inputs) { isEmpty = isEmpty && in.isEmpty(); for(int d = 0; d < in.ndim(); ++d) { maxDims[d] = std::max(maxDims[d], in.dim(d)); if(in.type() != type) - throw std::invalid_argument( - "join: all arrays should of same type for join" - ); + throw std::invalid_argument{"join: all arrays should of same type for join"}; } } if(batchDim < 0) batchDim = maxDims.ndim() - 1; if(batchDim < maxDims.ndim() && maxDims[batchDim] > 1) - throw std::invalid_argument( - "join: no singleton dim available for batching" - ); + throw 
std::invalid_argument{"join: no singleton dim available for batching"}; + maxDims[batchDim] = inputs.size(); if(isEmpty) return Tensor(maxDims, type); auto padSeq = fl::full(maxDims, padValue, type); - std::vector sel( - std::max(maxNumDims, static_cast(batchDim + 1)), fl::span); + std::vector sel{std::max(static_cast(maxNumDims), static_cast(batchDim + 1)), fl::span}; + for(int i = 0; i < inputs.size(); ++i) { for(int d = 0; d < maxNumDims; ++d) sel[d] = fl::range(inputs[i].dim(d)); - sel[batchDim] = fl::range(i, i + 1); + sel[batchDim] = fl::range(i, static_cast(i + 1)); if(!inputs[i].isEmpty()) padSeq(sel) = inputs[i]; } diff --git a/flashlight/fl/nn/modules/AdaptiveSoftMax.cpp b/flashlight/fl/nn/modules/AdaptiveSoftMax.cpp index 8ab04ea..c2cf989 100644 --- a/flashlight/fl/nn/modules/AdaptiveSoftMax.cpp +++ b/flashlight/fl/nn/modules/AdaptiveSoftMax.cpp @@ -66,7 +66,7 @@ Variable AdaptiveSoftMax::getFullLogProb( Tensor output({outputSize, batchSize}, inputs.type()); output( - fl::range(0, cutoff_[0] + static_cast(cutoff_.size()) - 1) + fl::range(0, cutoff_[0] + static_cast(cutoff_.size()) - 1) ) = headOutput.tensor(); @@ -115,7 +115,7 @@ Variable AdaptiveSoftMax::predict(const Variable& inputs) const { auto notInShortlist = (prediction >= cutoff_[0]); Variable ret = Variable(prediction, false); - if(fl::any(notInShortlist).asScalar()) { + if(fl::any_of(notInShortlist).asScalar()) { headOutput = logSoftmax(headOutput, 0); auto logProbTailPositions = getFullLogProb( inputsFlattened(fl::span, notInShortlist), diff --git a/flashlight/fl/nn/modules/Conv2D.cpp b/flashlight/fl/nn/modules/Conv2D.cpp index 83ec842..427f536 100644 --- a/flashlight/fl/nn/modules/Conv2D.cpp +++ b/flashlight/fl/nn/modules/Conv2D.cpp @@ -147,8 +147,8 @@ Variable Conv2D::forward(const Variable& input) { if(bias_) return conv2d( input, - params_[0].astype(input.type()), - params_[1].astype(input.type()), + params_[0].asType(input.type()), + params_[1].asType(input.type()), xStride_, yStride_, 
px, @@ -161,7 +161,7 @@ Variable Conv2D::forward(const Variable& input) { else return conv2d( input, - params_[0].astype(input.type()), + params_[0].asType(input.type()), xStride_, yStride_, px, diff --git a/flashlight/fl/nn/modules/Dropout.cpp b/flashlight/fl/nn/modules/Dropout.cpp index 8e40f6f..c0b2268 100644 --- a/flashlight/fl/nn/modules/Dropout.cpp +++ b/flashlight/fl/nn/modules/Dropout.cpp @@ -26,7 +26,7 @@ std::unique_ptr Dropout::clone() const { } std::string Dropout::prettyString() const { - return "Dropout (" + std::to_string(ratio_) + ")"; + return std::format("Dropout ({0})", ratio_); } } // namespace fl diff --git a/flashlight/fl/nn/modules/LayerNorm.cpp b/flashlight/fl/nn/modules/LayerNorm.cpp index e4dd845..b4f0d47 100644 --- a/flashlight/fl/nn/modules/LayerNorm.cpp +++ b/flashlight/fl/nn/modules/LayerNorm.cpp @@ -108,8 +108,8 @@ Variable LayerNorm::forward(const Variable& _input) { } if(affine_) { - Variable weight = params_[0].astype(output.type()); - Variable bias = params_[1].astype(output.type()); + Variable weight = params_[0].asType(output.type()); + Variable bias = params_[1].asType(output.type()); if(axisSize_ != kLnVariableAxisSize) { Shape affineDims = input.shape(); for(int ax : axisComplement_) @@ -118,8 +118,8 @@ Variable LayerNorm::forward(const Variable& _input) { throw std::invalid_argument( "[LayerNorm] Input size along the norm axis doesn't with axisSize." 
); - weight = moddims(params_[0].astype(output.type()), affineDims); - bias = moddims(params_[1].astype(output.type()), affineDims); + weight = moddims(params_[0].asType(output.type()), affineDims); + bias = moddims(params_[1].asType(output.type()), affineDims); } output = tileAs(weight, input) * output + tileAs(bias, input); } diff --git a/flashlight/fl/nn/modules/Linear.cpp b/flashlight/fl/nn/modules/Linear.cpp index d6414bb..e690788 100644 --- a/flashlight/fl/nn/modules/Linear.cpp +++ b/flashlight/fl/nn/modules/Linear.cpp @@ -53,10 +53,10 @@ Variable Linear::forward(const Variable& input) { if(bias_) return linear( input, - params_[0].astype(input.type()), - params_[1].astype(input.type()) + params_[0].asType(input.type()), + params_[1].asType(input.type()) ); - return linear(input, params_[0].astype(input.type())); + return linear(input, params_[0].asType(input.type())); } void Linear::initialize() { diff --git a/flashlight/fl/nn/modules/Loss.cpp b/flashlight/fl/nn/modules/Loss.cpp index 2e38298..7d6de66 100644 --- a/flashlight/fl/nn/modules/Loss.cpp +++ b/flashlight/fl/nn/modules/Loss.cpp @@ -166,12 +166,12 @@ Variable AdaptiveSoftMaxLoss::forward( // Tail forwawrd for(int i = 0; i < cutoff.size() - 1; i++) { auto mask = (target >= cutoff[i]) && (target < cutoff[i + 1]); - if(!fl::any(mask.tensor()).scalar()) + if(!fl::any_of(mask.tensor()).scalar()) continue; auto indicesArray = fl::nonzero(mask.tensor()); headTarget = - headTarget + (mask * (cutoff[0] + i)).astype(headTarget.type()); + headTarget + (mask * (cutoff[0] + i)).asType(headTarget.type()); auto tailTarget = target(indicesArray) - cutoff[i]; auto selectedInput = embedding(Variable(indicesArray, false), input); auto tailOutput = matmul(params_[1 + i * 2], selectedInput); diff --git a/flashlight/fl/nn/modules/PrecisionCast.cpp b/flashlight/fl/nn/modules/PrecisionCast.cpp index 1669a5b..e016972 100644 --- a/flashlight/fl/nn/modules/PrecisionCast.cpp +++ b/flashlight/fl/nn/modules/PrecisionCast.cpp @@ 
-18,7 +18,7 @@ std::vector PrecisionCast::forward( ) { std::vector outputs; for(const auto& input : inputs) { - auto output = input.astype(targetType_); + auto output = input.asType(targetType_); outputs.push_back(output); } return outputs; diff --git a/flashlight/fl/nn/modules/RNN.cpp b/flashlight/fl/nn/modules/RNN.cpp index 3de1125..839e12e 100644 --- a/flashlight/fl/nn/modules/RNN.cpp +++ b/flashlight/fl/nn/modules/RNN.cpp @@ -80,9 +80,9 @@ std::vector RNN::forward(const std::vector& inputs) { auto rnnRes = rnn( input, - hiddenState.astype(input.type()), - cellState.astype(input.type()), - params_[0].astype(input.type()), + hiddenState.asType(input.type()), + cellState.asType(input.type()), + params_[0].asType(input.type()), hiddenSize_, numLayers_, mode_, diff --git a/flashlight/fl/runtime/DeviceType.cpp b/flashlight/fl/runtime/DeviceType.cpp index 94bb697..e97a3b2 100644 --- a/flashlight/fl/runtime/DeviceType.cpp +++ b/flashlight/fl/runtime/DeviceType.cpp @@ -9,16 +9,7 @@ namespace fl { -std::string deviceTypeToString(const DeviceType type) { - switch(type) { - case DeviceType::x64: return "x64"; - case DeviceType::CUDA: return "CUDA"; - } -} - -std::ostream& operator<<(std::ostream& os, const DeviceType& type) { return os << deviceTypeToString(type); } - -const std::unordered_set& getDeviceTypes() { +std::unordered_set const& getDeviceTypes() { static std::unordered_set types = { DeviceType::x64, DeviceType::CUDA diff --git a/flashlight/fl/runtime/DeviceType.h b/flashlight/fl/runtime/DeviceType.h index 7f906e7..7f1ac53 100644 --- a/flashlight/fl/runtime/DeviceType.h +++ b/flashlight/fl/runtime/DeviceType.h @@ -7,7 +7,9 @@ #pragma once +#include #include +#include #include #include @@ -20,10 +22,41 @@ namespace fl { * NOTE update `fl::getAllDeviceTypes` after changing enum values. 
*/ enum class DeviceType { - x64, + DEVICE_TYPES_FIRST, + x64 = DEVICE_TYPES_FIRST, CUDA, + DEVICE_TYPES_SIZE, }; +namespace detail { + [[nodiscard]] constexpr auto to_index(DeviceType t) { return static_cast>(t); } + + [[nodiscard]] constexpr auto device_types_size() { return to_index(DeviceType::DEVICE_TYPES_SIZE); } + + constexpr std::array DEVICE_TYPES = [] { + std::array(DeviceType::DEVICE_TYPES_SIZE)> types{}; + + for(auto i = to_index(DeviceType::DEVICE_TYPES_FIRST); i < types.size(); i++) + types[i] = static_cast(i); + + return types; + }(); +} + + +/** + * Gets string representation of device type + * + * @return std::string_view to constexpr string literal + */ +[[nodiscard]] FL_API constexpr std::string_view to_string(DeviceType e) { + switch(e) { + case DeviceType::x64: return "x64"; + case DeviceType::CUDA: return "CUDA"; + default: return "unknown"; + } +} + #if FL_BACKEND_CUDA constexpr DeviceType kDefaultDeviceType = DeviceType::CUDA; #else @@ -31,22 +64,28 @@ constexpr DeviceType kDefaultDeviceType = DeviceType::x64; #endif /** - * Return a readable string representation of the given device type. - * - * @return a string that represents the given device type. + * @deprecated use @ref fl::to_string(DeviceType) instead */ -FL_API std::string deviceTypeToString(const DeviceType type); +FL_API inline std::string deviceTypeToString(DeviceType const type) { return std::string{to_string(type)}; } + + /** * Output a string representation of `type` to `os`. */ -FL_API std::ostream& operator<<(std::ostream& os, const DeviceType& type); +FL_API inline std::ostream& operator<<(std::ostream& os, DeviceType const& type) { return (os << to_string(type)); } /** * Returns all device types. * - * @return an immutable reference to a set of all device types. + * @return span of immutable device types. 
*/ -FL_API const std::unordered_set& getDeviceTypes(); +[[nodiscard]] FL_API constexpr std::span device_types() { return detail::DEVICE_TYPES; } + +/** + * @deprecated use @ref device_types() instead + */ +FL_API std::unordered_set const& getDeviceTypes(); + } // namespace fl diff --git a/flashlight/fl/tensor/CMakeLists.txt b/flashlight/fl/tensor/CMakeLists.txt index 18fff1b..e16abc9 100644 --- a/flashlight/fl/tensor/CMakeLists.txt +++ b/flashlight/fl/tensor/CMakeLists.txt @@ -74,7 +74,7 @@ target_sources( ${CMAKE_CURRENT_LIST_DIR}/TensorBase.cpp ${CMAKE_CURRENT_LIST_DIR}/TensorAdapter.cpp ${CMAKE_CURRENT_LIST_DIR}/TensorExtension.cpp - ${CMAKE_CURRENT_LIST_DIR}/Types.cpp + ${CMAKE_CURRENT_LIST_DIR}/DTypes.cpp ) # Profiling -- TODO: move this to runtime things @@ -95,14 +95,9 @@ if(FL_USE_CUDA) endif() # Link CUDA components - if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.17) - find_package(CUDAToolkit REQUIRED cublas) - target_link_libraries(flashlight PRIVATE CUDA::cublas) - else() - # Remove old branch when requiring CMake >= 3.17 - target_link_libraries(flashlight PRIVATE ${CUDA_LIBRARIES}) - target_include_directories(flashlight PRIVATE ${CUDA_INCLUDE_DIRS}) - endif() + find_package(CUDAToolkit REQUIRED cublas) + target_link_libraries(flashlight PRIVATE CUDA::cublas) + if(FL_BUILD_PROFILING) # Try to find NVTX diff --git a/flashlight/fl/tensor/DTypes.cpp b/flashlight/fl/tensor/DTypes.cpp new file mode 100644 index 0000000..8cc152f --- /dev/null +++ b/flashlight/fl/tensor/DTypes.cpp @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: MIT + * + * Original code: Copyright (c) Meta Platforms, Inc. 
(see FLASHLIGHT_LICENSE) + * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE) + */ + +#include "flashlight/fl/tensor/Types.h" + +#include +#include + +namespace fl { + +auto const STRING_DTYPE_MAP = [] { + std::unordered_map map{}; + map.reserve(detail::DTYPES_SIZE); + for(size_t i = 0; i < detail::DTYPES_SIZE; i++) { + auto type = static_cast(i); + map.emplace(to_string(type), type); + } + return map; +}(); + + +std::optional dtype_from_string(std::string_view str) { + auto const it = STRING_DTYPE_MAP.find(str); + + if(it == STRING_DTYPE_MAP.end()) + return {}; + + return {it->second}; +} + + + +} // namespace fl diff --git a/flashlight/fl/tensor/DTypes.h b/flashlight/fl/tensor/DTypes.h new file mode 100644 index 0000000..8299f7d --- /dev/null +++ b/flashlight/fl/tensor/DTypes.h @@ -0,0 +1,200 @@ +/* + * SPDX-License-Identifier: MIT + * + * Original code: Copyright (c) Meta Platforms, Inc. (see FLASHLIGHT_LICENSE) + * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE) + */ +#pragma once +#include "flashlight/fl/common/Defines.h" + +#include +#include +#include +#include +#include + +namespace fl { + + +/** + * Enumeration of all supported types + */ +enum class dtype { + f16, // 16-bit float + f32, // 32-bit float + f64, // 64-bit float + b8, // 8-bit boolean + s16, // 16-bit signed integer + s32, // 32-bit signed integer + s64, // 64-bit signed integer + u8, // 8-bit unsigned integer + u16, // 16-bit unsigned integer + u32, // 32-bit unsigned integer + u64, // 64-bit unsigned integer + + DTYPES_SIZE +}; +/** + * Enumeration of the different type groups in @ref dtype + */ +enum class dtype_group { + FLOAT, + BOOL, + SIGNED, + UNSIGNED, + + DTYPE_GROUPS_SIZE +}; + + + +constexpr std::string_view to_string(dtype e) { + switch(e) { + case dtype::f16: return "f16"; + case dtype::f32: return "f32"; + case dtype::f64: return "f64"; + case dtype::b8: return "b8"; + case dtype::s16: return "s16"; + case dtype::s32: return "s32"; + case dtype::s64: 
return "s64"; + case dtype::u8: return "u8"; + case dtype::u16: return "u16"; + case dtype::u32: return "u32"; + case dtype::u64: return "u64"; + default: return "unknown"; + } +} + +[[nodiscard]] constexpr auto to_index(dtype d) { return static_cast>(d); } + +[[nodiscard]] constexpr auto to_index(dtype_group d) { return static_cast>(d); } + + +/** + * Library details, may change + */ +namespace detail { + constexpr size_t DTYPES_SIZE = to_index(dtype::DTYPES_SIZE); + + /** + * Array of dtype byte sizes + */ + constexpr auto dtype_sizes = [] { + std::array sizes{}; + sizes[to_index(dtype::f16)] = 2; + sizes[to_index(dtype::f32)] = 4; + sizes[to_index(dtype::f64)] = 8; + sizes[to_index(dtype::b8)] = 1; + sizes[to_index(dtype::s16)] = 2; + sizes[to_index(dtype::s32)] = 4; + sizes[to_index(dtype::s64)] = 8; + sizes[to_index(dtype::u8)] = 1; + sizes[to_index(dtype::u16)] = 2; + sizes[to_index(dtype::u32)] = 4; + sizes[to_index(dtype::u64)] = 8; + return sizes; + }(); + + constexpr size_t DTYPE_GROUPS = to_index(dtype_group::DTYPE_GROUPS_SIZE); + + /** + * Gets the dtype group for a c++ standard type + * @tparam T to get group for + * @return dtype group + */ + template + constexpr dtype_group dtype_group_from_type() { + if constexpr(std::is_floating_point_v) return dtype_group::FLOAT; + else if constexpr(std::same_as || std::same_as) return dtype_group::BOOL; + else if constexpr(std::is_signed_v) return dtype_group::SIGNED; + else if constexpr(std::is_unsigned_v) return dtype_group::UNSIGNED; + else + static_assert(DTYPE_GROUPS != 4, "unknown type group"); + return dtype_group{0}; + } + + constexpr auto dtype_group_begins = [] { + std::array begins{}; + begins[to_index(dtype_group::FLOAT)] = dtype::f16; + begins[to_index(dtype_group::BOOL)] = dtype::b8; + begins[to_index(dtype_group::SIGNED)] = dtype::s16; + begins[to_index(dtype_group::UNSIGNED)] = dtype::u8; + return begins; + }(); + + constexpr auto dtype_group_lasts = [] { + std::array lasts{}; + 
lasts[to_index(dtype_group::FLOAT)] = dtype::f64; + lasts[to_index(dtype_group::BOOL)] = dtype::b8; + lasts[to_index(dtype_group::SIGNED)] = dtype::s64; + lasts[to_index(dtype_group::UNSIGNED)] = dtype::u64; + return lasts; + }(); +} + +/** + * Gets the dtypes size in bytes + * @param[in] type to get size of + */ +[[nodiscard]] FL_API constexpr size_t size_of(dtype type) { return detail::dtype_sizes[to_index(type)]; } + +/** + * Gets the dtype groups first dtype enum index + * @param[in] group dtype group + */ +[[nodiscard]] FL_API constexpr size_t begin_of(dtype_group group) { + return to_index(detail::dtype_group_begins[to_index(group)]); +} +/** + * Gets the groups dtype enum end index (exclusive) + */ +[[nodiscard]] FL_API constexpr size_t end_of(dtype_group group) { + return to_index(detail::dtype_group_lasts[to_index(group)]) + 1; +} +/** + * Gets the size of the dtype group in the dtype enum + */ +[[nodiscard]] FL_API constexpr size_t size_of(dtype_group group) { return end_of(group) - begin_of(group); } + + +/** + * Returns the size of the type in bytes. + * + * @param[in] type the input type to query. + * @deprecated use @ref size_of(dtype) instead + */ +FL_API inline size_t getTypeSize(dtype type) { return size_of(type); } + +/** + * Convert a dtype to its string representation. + * @deprecated use @ref to_string(fl::dtype) instead + */ +FL_API inline std::string dtypeToString(dtype type) { return std::string{to_string(type)}; } + +/** + * Tries to parse dtype from string + * @param str type name + * @return dtype or empty if not found + */ +FL_API std::optional dtype_from_string(std::string_view str); + +/** + * Converts string to a Flashlight dtype + * + * @param[in] string type name as a string. 
+ * + * @return returns the corresponding Flashlight dtype + * @deprecated use @dtype_from_string(std::string_view) instead + */ +FL_API inline fl::dtype stringToDtype(std::string const& string) { return *dtype_from_string(string); } + + +/** + * Write a type's string representation to an output stream. + */ +FL_API inline std::ostream& operator<<(std::ostream& ostream, dtype const& s) { + ostream << to_string(s); + return ostream; +} + +} \ No newline at end of file diff --git a/flashlight/fl/tensor/Meta.h b/flashlight/fl/tensor/Meta.h new file mode 100644 index 0000000..3df829a --- /dev/null +++ b/flashlight/fl/tensor/Meta.h @@ -0,0 +1,79 @@ +/* + * SPDX-License-Identifier: MIT + * + * Copyright (c) 2026 Lukas Thomann (see LICENSE) + */ +#pragma once +#include + +namespace fl { +namespace dev { + + template + concept is_any_of = (std::same_as || ...); + + + template + struct unique_tuple_impl; + + template requires(is_any_of) + struct unique_tuple_impl : unique_tuple_impl {}; + + template requires(!is_any_of) + struct unique_tuple_impl, T, Ts...> : unique_tuple_impl, Ts...> {}; + + template + struct unique_tuple_impl { + using types = Tuple; + }; +} + + + +template +struct unique_tuple : dev::unique_tuple_impl, Ts...> {}; + +template +using unique_tuple_t = unique_tuple::types; +} + +namespace fl { +namespace dev { + template + struct resolution_node { + T operator()(T); + }; + + template + struct overload_set : resolution_node... 
{ + using resolution_node::operator()...; + }; +} + +/** + * Resolves the correct function overload type for T from Ts + * @tparam T resolve target + * @tparam Ts options to resolve from + */ +template +struct resolve_overload_from : std::invoke_result, T> {}; + + +/** + * shorthand for @ref resolve_overload_from + * @tparam T resolve target + * @tparam Ts options to resolve from + */ +template +using resolve_overload_from_t = resolve_overload_from::type; + +/** + * checks if T resolves to any of Ts via function overload resolution + * @tparam T resolve target + * @tparam Ts options to resolve from + */ +template +concept resolves_to_any_of = std::same_as, std::tuple> && requires { + typename resolve_overload_from::type; +}; +} diff --git a/flashlight/fl/tensor/Shape.cpp b/flashlight/fl/tensor/Shape.cpp index c0108ec..3d3f0f0 100644 --- a/flashlight/fl/tensor/Shape.cpp +++ b/flashlight/fl/tensor/Shape.cpp @@ -1,8 +1,8 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. + * SPDX-License-Identifier: MIT * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. + * Original code: Copyright (c) Meta Platforms, Inc. 
(see FLASHLIGHT_LICENSE) + * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE) */ #include "flashlight/fl/tensor/Shape.h" @@ -15,68 +15,58 @@ namespace fl { -Shape::Shape(std::vector d) : dims_(std::move(d)) {} +Shape::Shape(std::vector d) : _dims(std::move(d)) {} Shape::Shape(std::initializer_list d) : Shape(std::vector(d)) {} -const Dim kEmptyShapeNumberOfElements = 1; +Dim const kEmptyShapeNumberOfElements = 1; -void Shape::checkDimsOrThrow(const size_t dim) const { +void Shape::checkDimsOrThrow(size_t const dim) const { if(dim > ndim() - 1) { std::stringstream ss; ss << "Shape index " << std::to_string(dim) - << " out of bounds for shape with " << std::to_string(dims_.size()) - << " dimensions."; + << " out of bounds for shape with " << std::to_string(_dims.size()) + << " dimensions."; throw std::invalid_argument(ss.str()); } } Dim Shape::elements() const { - if(dims_.empty()) + if(_dims.empty()) return kEmptyShapeNumberOfElements; - return std::accumulate(dims_.begin(), dims_.end(), static_cast(1), std::multiplies()); + return std::accumulate(_dims.begin(), _dims.end(), static_cast(1), std::multiplies()); } -int Shape::ndim() const { - return dims_.size(); -} +size_t Shape::ndim() const { return _dims.size(); } -Dim Shape::dim(const size_t dim) const { +Dim Shape::dim(size_t const dim) const { checkDimsOrThrow(dim); - return dims_[dim]; + return _dims[dim]; } -Dim& Shape::operator[](const size_t dim) { +Dim& Shape::operator[](size_t const dim) { checkDimsOrThrow(dim); - return dims_[dim]; + return _dims[dim]; } -const Dim& Shape::operator[](const size_t dim) const { +Dim const& Shape::operator[](size_t const dim) const { checkDimsOrThrow(dim); - return dims_[dim]; + return _dims[dim]; } -bool Shape::operator==(const Shape& other) const { - return dims_ == other.dims_; -} +bool Shape::operator==(Shape const& other) const { return _dims == other._dims; } -bool Shape::operator!=(const Shape& other) const { - return !(this->operator==(other)); -} +bool 
Shape::operator!=(Shape const& other) const { return !(this->operator==(other)); } -bool Shape::operator==(const std::initializer_list& other) const { - return dims_.size() == other.size() - && std::equal(std::begin(dims_), std::end(dims_), std::begin(other)); +bool Shape::operator==(std::initializer_list const& other) const { + return _dims.size() == other.size() + && std::equal(std::begin(_dims), std::end(_dims), std::begin(other)); } -bool Shape::operator!=(const std::initializer_list& other) const { - return !(this->operator==(other)); -} +bool Shape::operator!=(std::initializer_list const& other) const { return !(this->operator==(other)); } -const std::vector& Shape::get() const { - return dims_; -} +std::vector const& Shape::get() const { return _dims; } -std::vector& Shape::get() { return dims_; }; +std::vector& Shape::get() { return _dims; }; std::string Shape::toString() const { std::stringstream ss; @@ -87,7 +77,7 @@ std::string Shape::toString() const { return ss.str(); } -std::ostream& operator<<(std::ostream& ostr, const Shape& s) { +std::ostream& operator<<(std::ostream& ostr, Shape const& s) { ostr << s.toString(); return ostr; } diff --git a/flashlight/fl/tensor/Shape.h b/flashlight/fl/tensor/Shape.h index 7f6ab41..5c02e7e 100644 --- a/flashlight/fl/tensor/Shape.h +++ b/flashlight/fl/tensor/Shape.h @@ -1,8 +1,8 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. + * SPDX-License-Identifier: MIT * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. + * Original code: Copyright (c) Meta Platforms, Inc. (see FLASHLIGHT_LICENSE) + * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE) */ #pragma once @@ -18,7 +18,8 @@ namespace fl { // The type of a dimension. -using Dim = long long; +using dim_t = int64_t; +using Dim = dim_t; /** * An object describing the dimensions of a tensor. 
@@ -44,13 +45,13 @@ using Dim = long long; class FL_API Shape { // Storage for the dimension values. Defaults to an empty Shape {0}, whereas // {} is a scalar shape. - std::vector dims_; + std::vector _dims; /** * Check if a dimension is valid (i.e. in bounds) given the current size of * the shape. If not valid, throws an exception. */ - void checkDimsOrThrow(const size_t dim) const; + void checkDimsOrThrow(size_t dim) const; public: Shape() = default; @@ -72,7 +73,8 @@ class FL_API Shape { /** * Initialize a Shape via an initializer list. */ - /* implicit */ Shape(std::initializer_list d); + /* implicit */ + Shape(std::initializer_list d); /** * @return the number of elements in a tensor that has the given shape. @@ -82,7 +84,7 @@ class FL_API Shape { /** * @return Number of dimensions in the shape. */ - int ndim() const; + size_t ndim() const; /** * Get the size of a given dimension in the number of arguments. Throws if the @@ -90,30 +92,30 @@ class FL_API Shape { * * @return the number of elements at the given dimension */ - Dim dim(const size_t dim) const; + Dim dim(size_t dim) const; /** * Returns a reference to the given index */ - Dim& operator[](const size_t dim); - const Dim& operator[](const size_t dim) const; + Dim& operator[](size_t dim); + Dim const& operator[](size_t dim) const; /** * Compares two shapes. Returns true if their dim vectors are equal. */ - bool operator==(const Shape& other) const; - bool operator!=(const Shape& other) const; + bool operator==(Shape const& other) const; + bool operator!=(Shape const& other) const; /** * Compare a shape to an initializer list. */ - bool operator==(const std::initializer_list& other) const; - bool operator!=(const std::initializer_list& other) const; + bool operator==(std::initializer_list const& other) const; + bool operator!=(std::initializer_list const& other) const; /** - * Gets a reference to the underying dims vector. + * Gets a reference to the underlying dims vector. 
*/ - const std::vector& get() const; + std::vector const& get() const; std::vector& get(); /** @@ -125,6 +127,52 @@ class FL_API Shape { /** * Write a shape representation to an output stream. */ -FL_API std::ostream& operator<<(std::ostream& ostr, const Shape& s); +FL_API std::ostream& operator<<(std::ostream& ostr, Shape const& s); + + +/** + * Composes two shapes with the given operation. + * @param first shape + * @param second shape + * @tparam Op to apply to elements + * @tparam ExtendVal shapes of unequal size will be implicitly extended with this + * @return element wise composition + */ +template +FL_API Shape element_compose_op(Shape const& first, Shape const& second) { + auto& large = first.ndim() < second.ndim() ? second : first; + auto const outDim = large.ndim(); + auto const sharedDims = std::min(first.ndim(), second.ndim()); + + std::vector resultData(outDim); + + + for(int i = 0; i < sharedDims; i++) + resultData[i] = Op(first[i], second[i]); + + for(int i = sharedDims; i < outDim; i++) + resultData[i] = Op(large[i], ExtendVal); + + return Shape{resultData}; +} + + +/** + * Performs element wise max. 
+ * @param first shape + * @param second shape + * @return element wise max composition + * @details shapes of unequal size will be extended with 0 + */ +FL_API inline Shape max(Shape const& first, Shape const& second) { + constexpr auto max_op = [](Dim x, Dim y) { return std::max(x, y); }; + + if(first.ndim() == 0) + return second; + if(second.ndim() == 0) + return first; + + return element_compose_op(first, second); +} } // namespace fl diff --git a/flashlight/fl/tensor/TensorAdapter.h b/flashlight/fl/tensor/TensorAdapter.h index 43aab5d..57c3015 100644 --- a/flashlight/fl/tensor/TensorAdapter.h +++ b/flashlight/fl/tensor/TensorAdapter.h @@ -182,7 +182,7 @@ class FL_API TensorAdapterBase { * @param[in] the type to which to cast the tensor * @return a tensor with element-wise cast to the new type */ - virtual Tensor astype(const dtype type) = 0; + virtual Tensor asType(const dtype type) = 0; /** * Index into a tensor with a variable number of indices. diff --git a/flashlight/fl/tensor/TensorBackend.cpp b/flashlight/fl/tensor/TensorBackend.cpp index 59dd817..5ba88e1 100644 --- a/flashlight/fl/tensor/TensorBackend.cpp +++ b/flashlight/fl/tensor/TensorBackend.cpp @@ -10,11 +10,11 @@ namespace fl { namespace detail { - bool areBackendsEqual(const Tensor& a, const Tensor& b) { return a.backendType() == b.backendType(); } + bool areBackendsEqual(Tensor const& a, Tensor const& b) { return a.backendType() == b.backendType(); } } // namespace detail -bool TensorBackend::isDataTypeSupported(const fl::dtype& dtype) const { +bool TensorBackend::isDataTypeSupported(fl::dtype const& dtype) const { bool supported = this->supportsDataType(dtype); for(auto& p : extensions_) supported &= p.second->isDataTypeSupported(dtype); @@ -22,75 +22,83 @@ bool TensorBackend::isDataTypeSupported(const fl::dtype& dtype) const { } Tensor TensorBackend::clip( - const Tensor& tensor, - const Tensor& low, - const double& high + Tensor const& tensor, + Tensor const& low, + double const& high ) { 
return clip( tensor, low, - full(tensor.shape(), high, dtype_traits::ctype) + full(tensor.shape(), high, tensor.type()) ); } Tensor TensorBackend::clip( - const Tensor& tensor, - const double& low, - const Tensor& high + Tensor const& tensor, + double const& low, + Tensor const& high ) { return clip( tensor, - full(tensor.shape(), low, dtype_traits::ctype), + // TODO review, truncated to float in original impl + full(tensor.shape(), low, tensor.type()), high ); } Tensor TensorBackend::clip( - const Tensor& tensor, - const double& low, - const double& high + Tensor const& tensor, + double const& low, + double const& high ) { return clip( tensor, - full(tensor.shape(), low, dtype_traits::ctype), - full(tensor.shape(), high, dtype_traits::ctype) + // TODO review, truncated to float in original impl + full(tensor.shape(), low, tensor.type()), + full(tensor.shape(), high, tensor.type()) ); } Tensor TensorBackend::where( - const Tensor& condition, - const Tensor& x, - const double& y + Tensor const& condition, + Tensor const& x, + double const& y ) { return where(condition, x, full(condition.shape(), y, x.type())); } Tensor TensorBackend::where( - const Tensor& condition, - const double& x, - const Tensor& y + Tensor const& condition, + double const& x, + Tensor const& y ) { return where(condition, full(condition.shape(), x, y.type()), y); } -Tensor TensorBackend::minimum(const Tensor& lhs, const double& rhs) { - return minimum(lhs, full(lhs.shape(), rhs, dtype_traits::ctype)); +Tensor TensorBackend::minimum(Tensor const& lhs, double const& rhs) { + // TODO review, truncated to float in original impl + return minimum(lhs, full(lhs.shape(), rhs, lhs.type())); } -Tensor TensorBackend::minimum(const double& lhs, const Tensor& rhs) { - return minimum(full(rhs.shape(), lhs, dtype_traits::ctype), rhs); +Tensor TensorBackend::minimum(double const& lhs, Tensor const& rhs) { + // TODO review, truncated to float in original impl + return minimum(full(rhs.shape(), lhs, rhs.type()), 
rhs); } -Tensor TensorBackend::maximum(const Tensor& lhs, const double& rhs) { - return maximum(lhs, full(lhs.shape(), rhs, dtype_traits::ctype)); +Tensor TensorBackend::maximum(Tensor const& lhs, double const& rhs) { + // TODO review, truncated to float in original impl + return maximum(lhs, full(lhs.shape(), rhs, lhs.type())); } -Tensor TensorBackend::maximum(const double& lhs, const Tensor& rhs) { - return maximum(full(rhs.shape(), lhs, dtype_traits::ctype), rhs); +Tensor TensorBackend::maximum(double const& lhs, Tensor const& rhs) { + // TODO review, truncated to float in original impl + return maximum(full(rhs.shape(), lhs, rhs.type()), rhs); } -Tensor TensorBackend::power(const Tensor& lhs, const double& rhs) { - return power(lhs, full(lhs.shape(), rhs, dtype_traits::ctype)); +Tensor TensorBackend::power(Tensor const& lhs, double const& rhs) { + // TODO review, truncated to float in original impl + return power(lhs, full(lhs.shape(), rhs, lhs.type())); } -Tensor TensorBackend::power(const double& lhs, const Tensor& rhs) { - return power(full(rhs.shape(), lhs, dtype_traits::ctype), rhs); +Tensor TensorBackend::power(double const& lhs, Tensor const& rhs) { + // TODO review, truncated to float in original impl + return power(full(rhs.shape(), lhs, rhs.type()), rhs); } } // namespace fl diff --git a/flashlight/fl/tensor/TensorBackend.h b/flashlight/fl/tensor/TensorBackend.h index cd7c250..a1e0db4 100644 --- a/flashlight/fl/tensor/TensorBackend.h +++ b/flashlight/fl/tensor/TensorBackend.h @@ -1,10 +1,9 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. + * SPDX-License-Identifier: MIT * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. + * Original code: Copyright (c) Meta Platforms, Inc. 
(see FLASHLIGHT_LICENSE) + * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE) */ - #pragma once #include @@ -40,18 +39,18 @@ class TensorBackend { virtual TensorBackendType backendType() const = 0; /* -------------------------- Compute Functions -------------------------- */ - virtual void eval(const Tensor& tensor) = 0; - virtual bool supportsDataType(const fl::dtype& dtype) const = 0; + virtual void eval(Tensor const& tensor) = 0; + virtual bool supportsDataType(fl::dtype const& dtype) const = 0; // Memory Management - virtual void getMemMgrInfo(const char* msg, const int deviceId, std::ostream* ostream) = 0; + virtual void getMemMgrInfo(char const* msg, int deviceId, std::ostream* ostream) = 0; virtual void setMemMgrLogStream(std::ostream* stream) = 0; - virtual void setMemMgrLoggingEnabled(const bool enabled) = 0; - virtual void setMemMgrFlushInterval(const size_t interval) = 0; + virtual void setMemMgrLoggingEnabled(bool enabled) = 0; + virtual void setMemMgrFlushInterval(size_t interval) = 0; /* -------------------------- Rand Functions -------------------------- */ - virtual void setSeed(const int seed) = 0; - virtual Tensor randn(const Shape& shape, dtype type) = 0; - virtual Tensor rand(const Shape& shape, dtype type) = 0; + virtual void setSeed(int seed) = 0; + virtual Tensor randn(Shape const& shape, dtype type) = 0; + virtual Tensor rand(Shape const& shape, dtype type) = 0; /* --------------------------- Tensor Operators --------------------------- * For operator documentation and expected behavior, see TensorBase.h. 
@@ -75,75 +74,76 @@ class TensorBackend { FL_CREATE_FUN_LITERAL_BACKEND_DECL(const unsigned short&); #undef FL_CREATE_FUN_LITERAL_BACKEND_DECL - virtual Tensor identity(const Dim dim, const dtype type) = 0; - virtual Tensor arange(const Shape& shape, const Dim seqDim, const dtype type) = 0; - virtual Tensor iota(const Shape& dims, const Shape& tileDims, const dtype type) = 0; + virtual Tensor identity(Dim dim, dtype type) = 0; + virtual Tensor arange(Shape const& shape, Dim seqDim, dtype type) = 0; + + virtual Tensor iota(Shape const& dims, Shape const& tileDims, dtype type) = 0; /************************ Shaping and Indexing *************************/ - virtual Tensor reshape(const Tensor& tensor, const Shape& shape) = 0; + virtual Tensor reshape(Tensor const& tensor, Shape const& shape) = 0; virtual Tensor transpose( - const Tensor& tensor, - const Shape& axes /* = {} */ + Tensor const& tensor, + Shape const& axes /* = {} */ ) = 0; - virtual Tensor tile(const Tensor& tensor, const Shape& shape) = 0; + virtual Tensor tile(Tensor const& tensor, Shape const& shape) = 0; virtual Tensor concatenate( - const std::vector& tensors, - const unsigned axis + std::vector const& tensors, + unsigned axis ) = 0; - virtual Tensor nonzero(const Tensor& tensor) = 0; + virtual Tensor nonzero(Tensor const& tensor) = 0; virtual Tensor pad( - const Tensor& input, - const std::vector>& padWidths, - const PadType type + Tensor const& input, + std::vector> const& padWidths, + PadType type ) = 0; /************************** Unary Operators ***************************/ - virtual Tensor exp(const Tensor& tensor) = 0; - virtual Tensor log(const Tensor& tensor) = 0; - virtual Tensor negative(const Tensor& tensor) = 0; - virtual Tensor logicalNot(const Tensor& tensor) = 0; - virtual Tensor log1p(const Tensor& tensor) = 0; - virtual Tensor sin(const Tensor& tensor) = 0; - virtual Tensor cos(const Tensor& tensor) = 0; - virtual Tensor sqrt(const Tensor& tensor) = 0; - virtual Tensor tanh(const 
Tensor& tensor) = 0; - virtual Tensor floor(const Tensor& tensor) = 0; - virtual Tensor ceil(const Tensor& tensor) = 0; - virtual Tensor rint(const Tensor& tensor) = 0; - virtual Tensor absolute(const Tensor& tensor) = 0; - virtual Tensor sigmoid(const Tensor& tensor) = 0; - virtual Tensor erf(const Tensor& tensor) = 0; - virtual Tensor flip(const Tensor& tensor, const unsigned dim) = 0; - virtual Tensor clip(const Tensor& tensor, const Tensor& low, const Tensor& high) = 0; - virtual Tensor clip(const Tensor& tensor, const Tensor& low, const double& high); - virtual Tensor clip(const Tensor& tensor, const double& low, const Tensor& high); - virtual Tensor clip(const Tensor& tensor, const double& low, const double& high); - virtual Tensor roll(const Tensor& tensor, const int shift, const unsigned axis) = 0; - virtual Tensor isnan(const Tensor& tensor) = 0; - virtual Tensor isinf(const Tensor& tensor) = 0; - virtual Tensor sign(const Tensor& tensor) = 0; - virtual Tensor tril(const Tensor& tensor) = 0; - virtual Tensor triu(const Tensor& tensor) = 0; - virtual Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y) = 0; - virtual Tensor where(const Tensor& condition, const Tensor& x, const double& y); - virtual Tensor where(const Tensor& condition, const double& x, const Tensor& y); + virtual Tensor exp(Tensor const& tensor) = 0; + virtual Tensor log(Tensor const& tensor) = 0; + virtual Tensor negative(Tensor const& tensor) = 0; + virtual Tensor logicalNot(Tensor const& tensor) = 0; + virtual Tensor log1p(Tensor const& tensor) = 0; + virtual Tensor sin(Tensor const& tensor) = 0; + virtual Tensor cos(Tensor const& tensor) = 0; + virtual Tensor sqrt(Tensor const& tensor) = 0; + virtual Tensor tanh(Tensor const& tensor) = 0; + virtual Tensor floor(Tensor const& tensor) = 0; + virtual Tensor ceil(Tensor const& tensor) = 0; + virtual Tensor rint(Tensor const& tensor) = 0; + virtual Tensor absolute(Tensor const& tensor) = 0; + virtual Tensor sigmoid(Tensor 
const& tensor) = 0; + virtual Tensor erf(Tensor const& tensor) = 0; + virtual Tensor flip(Tensor const& tensor, unsigned dim) = 0; + virtual Tensor clip(Tensor const& tensor, Tensor const& low, Tensor const& high) = 0; + virtual Tensor clip(Tensor const& tensor, Tensor const& low, double const& high); + virtual Tensor clip(Tensor const& tensor, double const& low, Tensor const& high); + virtual Tensor clip(Tensor const& tensor, double const& low, double const& high); + virtual Tensor roll(Tensor const& tensor, int shift, unsigned axis) = 0; + virtual Tensor isnan(Tensor const& tensor) = 0; + virtual Tensor isinf(Tensor const& tensor) = 0; + virtual Tensor sign(Tensor const& tensor) = 0; + virtual Tensor tril(Tensor const& tensor) = 0; + virtual Tensor triu(Tensor const& tensor) = 0; + virtual Tensor where(Tensor const& condition, Tensor const& x, Tensor const& y) = 0; + virtual Tensor where(Tensor const& condition, Tensor const& x, double const& y); + virtual Tensor where(Tensor const& condition, double const& x, Tensor const& y); virtual void topk( Tensor& values, Tensor& indices, - const Tensor& input, - const unsigned k, - const Dim axis, - const SortMode sortMode + Tensor const& input, + unsigned k, + Dim axis, + SortMode sortMode ) = 0; - virtual Tensor sort(const Tensor& input, const Dim axis, const SortMode sortMode) = 0; + virtual Tensor sort(Tensor const& input, Dim axis, SortMode sortMode) = 0; virtual void sort( Tensor& values, Tensor& indices, - const Tensor& input, - const Dim axis, - const SortMode sortMode + Tensor const& input, + Dim axis, + SortMode sortMode ) = 0; - virtual Tensor argsort(const Tensor& input, const Dim axis, const SortMode sortMode) = 0; + virtual Tensor argsort(Tensor const& input, Dim axis, SortMode sortMode) = 0; /************************** Binary Operators ***************************/ #define FL_BINARY_OP_TYPE_DECL(FUNC, TYPE) \ @@ -191,70 +191,70 @@ class TensorBackend { #undef FL_BINARY_OP_TYPE_DECL #undef 
FL_BINARY_OP_LITERALS_DECL - virtual Tensor minimum(const Tensor& lhs, const Tensor& rhs) = 0; - virtual Tensor minimum(const Tensor& lhs, const double& rhs); - virtual Tensor minimum(const double& lhs, const Tensor& rhs); - virtual Tensor maximum(const Tensor& lhs, const Tensor& rhs) = 0; - virtual Tensor maximum(const Tensor& lhs, const double& rhs); - virtual Tensor maximum(const double& lhs, const Tensor& rhs); - virtual Tensor power(const Tensor& lhs, const Tensor& rhs) = 0; - virtual Tensor power(const Tensor& lhs, const double& rhs); - virtual Tensor power(const double& lhs, const Tensor& rhs); + virtual Tensor minimum(Tensor const& lhs, Tensor const& rhs) = 0; + virtual Tensor minimum(Tensor const& lhs, double const& rhs); + virtual Tensor minimum(double const& lhs, Tensor const& rhs); + virtual Tensor maximum(Tensor const& lhs, Tensor const& rhs) = 0; + virtual Tensor maximum(Tensor const& lhs, double const& rhs); + virtual Tensor maximum(double const& lhs, Tensor const& rhs); + virtual Tensor power(Tensor const& lhs, Tensor const& rhs) = 0; + virtual Tensor power(Tensor const& lhs, double const& rhs); + virtual Tensor power(double const& lhs, Tensor const& rhs); /******************************* BLAS ********************************/ virtual Tensor matmul( - const Tensor& lhs, - const Tensor& rhs, + Tensor const& lhs, + Tensor const& rhs, MatrixProperty lhsProp, MatrixProperty rhsProp ) = 0; /************************** Reductions ***************************/ - virtual Tensor amin(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; - virtual Tensor amax(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; + virtual Tensor amin(Tensor const& input, std::vector const& axes, bool keepDims) = 0; + virtual Tensor amax(Tensor const& input, std::vector const& axes, bool keepDims) = 0; virtual void min( Tensor& values, Tensor& indices, - const Tensor& input, - const unsigned axis, - const bool keepDims + Tensor const& input, + 
unsigned axis, + bool keepDims ) = 0; virtual void max( Tensor& values, Tensor& indices, - const Tensor& input, - const unsigned axis, - const bool keepDims + Tensor const& input, + unsigned axis, + bool keepDims ) = 0; - virtual Tensor sum(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; - virtual Tensor cumsum(const Tensor& input, const unsigned axis) = 0; - virtual Tensor argmax(const Tensor& input, const unsigned axis, const bool keepDims) = 0; - virtual Tensor argmin(const Tensor& input, const unsigned axis, const bool keepDims) = 0; - virtual Tensor mean(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; - virtual Tensor median(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; + virtual Tensor sum(Tensor const& input, std::vector const& axes, bool keepDims) = 0; + virtual Tensor cumsum(Tensor const& input, unsigned axis) = 0; + virtual Tensor argmax(Tensor const& input, unsigned axis, bool keepDims) = 0; + virtual Tensor argmin(Tensor const& input, unsigned axis, bool keepDims) = 0; + virtual Tensor mean(Tensor const& input, std::vector const& axes, bool keepDims) = 0; + virtual Tensor median(Tensor const& input, std::vector const& axes, bool keepDims) = 0; virtual Tensor var( - const Tensor& input, - const std::vector& axes, + Tensor const& input, + std::vector const& axes, bool bias, - const bool keepDims + bool keepDims ) = 0; - virtual Tensor std(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; + virtual Tensor std(Tensor const& input, std::vector const& axes, bool keepDims) = 0; virtual Tensor norm( - const Tensor& input, - const std::vector& axes, + Tensor const& input, + std::vector const& axes, double p, - const bool keepDims + bool keepDims ) = 0; virtual Tensor countNonzero( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool keepDims ) = 0; - virtual Tensor any(const Tensor& input, 
const std::vector& axes, const bool keepDims) = 0; - virtual Tensor all(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; + virtual Tensor any(Tensor const& input, std::vector const& axes, bool keepDims) = 0; + virtual Tensor all(Tensor const& input, std::vector const& axes, bool keepDims) = 0; /************************** Utils ***************************/ - virtual void print(const Tensor& tensor) = 0; + virtual void print(Tensor const& tensor) = 0; /** * Checks if a datatype is supported by a TensorBackend and its registered @@ -264,7 +264,7 @@ class TensorBackend { * * @return true if the data type is supported, false otherwise */ - virtual bool isDataTypeSupported(const fl::dtype& dtype) const final; + virtual bool isDataTypeSupported(fl::dtype const& dtype) const final; /********************* Tensor Extensions **********************/ template @@ -320,7 +320,8 @@ Tensor toTensorType(Tensor&& in) { in.shape(), in.type(), // TODO: use the void specialization instead of a reinterpret cast - reinterpret_cast(in.device()), // expects contiguous memory + reinterpret_cast(in.device()), + // expects contiguous memory in.location() ) ); @@ -328,29 +329,29 @@ Tensor toTensorType(Tensor&& in) { namespace detail { -/** - * Compare the backends of two tensors. - * - * @return true if the backends of both tensors are the same, else false. - */ + /** + * Compare the backends of two tensors. + * + * @return true if the backends of both tensors are the same, else false. + */ bool areBackendsEqual(const Tensor& a, const Tensor& b); -/** - * Compare the backends of multiple tensors. - * - * @return true if all tensors' backends are the same, false otherwise. - */ + /** + * Compare the backends of multiple tensors. + * + * @return true if all tensors' backends are the same, false otherwise. + */ template bool areBackendsEqual(const Tensor& a, const Tensor& b, const Args&... args) { return areBackendsEqual(a, b) && areBackendsEqual(a, args...) 
- && areBackendsEqual(b, args...); + && areBackendsEqual(b, args...); } -/** - * - * @return a reference to a tensor backend instance descripting the default - backend. - */ + /** + * + * @return a reference to a tensor backend instance describing the default + backend. + */ TensorBackend& getDefaultBackend(); } // namespace detail diff --git a/flashlight/fl/tensor/TensorBase.cpp b/flashlight/fl/tensor/TensorBase.cpp index 7fde6f5..78f4829 100644 --- a/flashlight/fl/tensor/TensorBase.cpp +++ b/flashlight/fl/tensor/TensorBase.cpp @@ -1,15 +1,14 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. + * SPDX-License-Identifier: MIT * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. + * Original code: Copyright (c) Meta Platforms, Inc. (see FLASHLIGHT_LICENSE) + * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE) */ - #include "flashlight/fl/tensor/TensorBase.h" +#include #include #include -#include #include "flashlight/fl/tensor/DefaultTensorType.h" #include "flashlight/fl/tensor/TensorAdapter.h" @@ -31,181 +30,82 @@ std::unique_ptr Tensor::releaseAdapter() { return std::move(i Tensor::~Tensor() = default; -Tensor::Tensor(const Tensor& tensor) : impl_(tensor.impl_->clone()) {} +Tensor::Tensor(Tensor const& tensor) : impl_(tensor.impl_->clone()) {} Tensor::Tensor(Tensor&& other) noexcept : impl_(std::move(other.impl_)) {} Tensor::Tensor() : impl_(detail::getDefaultAdapter()) {} Tensor::Tensor( - const Shape& shape, + Shape const& shape, fl::dtype type, - const void* ptr, + void const* ptr, MemoryLocation memoryLocation ) : impl_(detail::getDefaultAdapter(shape, type, ptr, memoryLocation)) {} Tensor::Tensor( - const Dim nRows, - const Dim nCols, - const Tensor& values, - const Tensor& rowIdx, - const Tensor& colIdx, + Dim const nRows, + Dim const nCols, + Tensor const& values, + Tensor const& rowIdx, + Tensor const& colIdx, StorageType storageType -) : 
impl_(detail::getDefaultAdapter( - nRows, - nCols, - values, - rowIdx, - colIdx, - storageType)) {} +) : impl_( + detail::getDefaultAdapter( + nRows, + nCols, + values, + rowIdx, + colIdx, + storageType + ) +) {} Tensor::Tensor( - const Shape& shape, + Shape const& shape, fl::dtype type /* = fl::dtype::f32 */ -) : impl_(detail::getDefaultAdapter(shape, - type)) {} +) : impl_(detail::getDefaultAdapter(shape, type)) {} -Tensor::Tensor(fl::dtype type) : impl_(detail::getDefaultAdapter(Shape({ 0 }), type)) {} +Tensor::Tensor(fl::dtype type) : impl_(detail::getDefaultAdapter(Shape({0}), type)) {} -Tensor Tensor::copy() const { - return impl_->copy(); -} +Tensor Tensor::copy() const { return impl_->copy(); } -Tensor Tensor::shallowCopy() const { - return impl_->shallowCopy(); -} +Tensor Tensor::shallowCopy() const { return impl_->shallowCopy(); } -const Shape& Tensor::shape() const { - return impl_->shape(); -} +Shape const& Tensor::shape() const { return impl_->shape(); } -Location Tensor::location() const { - return impl_->location(); -} +Location Tensor::location() const { return impl_->location(); } -size_t Tensor::elements() const { - return impl_->shape().elements(); -} +size_t Tensor::elements() const { return impl_->shape().elements(); } -Dim Tensor::dim(const size_t dim) const { - return shape().dim(dim); -} +Dim Tensor::dim(size_t const dim) const { return shape().dim(dim); } -int Tensor::ndim() const { - return shape().ndim(); -} +size_t Tensor::ndim() const { return shape().ndim(); } -bool Tensor::isEmpty() const { - return elements() == 0; -} +bool Tensor::isEmpty() const { return elements() == 0; } -bool Tensor::hasAdapter() const { - return impl_.get() != nullptr; -} +bool Tensor::hasAdapter() const { return impl_.get() != nullptr; } -size_t Tensor::bytes() const { - return elements() * getTypeSize(type()); -} +size_t Tensor::bytes() const { return elements() * getTypeSize(type()); } -dtype Tensor::type() const { - return impl_->type(); -} +dtype 
Tensor::type() const { return impl_->type(); } -bool Tensor::isSparse() const { - return impl_->isSparse(); -} +bool Tensor::isSparse() const { return impl_->isSparse(); } -Tensor Tensor::astype(const dtype type) const { - return impl_->astype(type); -} +Tensor Tensor::asType(dtype const type) const { return impl_->asType(type); } -Tensor Tensor::operator()(const std::vector& indices) const { - return impl_->index(indices); -} +Tensor Tensor::operator()(std::vector const& indices) const { return impl_->index(indices); } -Tensor Tensor::flatten() const { - return impl_->flatten(); -} +Tensor Tensor::flatten() const { return impl_->flatten(); } -Tensor Tensor::flat(const Index& idx) const { - return impl_->flat(idx); -} +Tensor Tensor::flat(Index const& idx) const { return impl_->flat(idx); } -Tensor Tensor::asContiguousTensor() const { - return impl_->asContiguousTensor(); -} +Tensor Tensor::asContiguousTensor() const { return impl_->asContiguousTensor(); } -TensorBackendType Tensor::backendType() const { - return impl_->backendType(); -} +TensorBackendType Tensor::backendType() const { return impl_->backendType(); } -TensorBackend& Tensor::backend() const { - return impl_->backend(); -} +TensorBackend& Tensor::backend() const { return impl_->backend(); } -#define FL_CREATE_MEMORY_OPS(TYPE) \ - template<> FL_API TYPE Tensor::scalar() const { \ - if(isEmpty()) { \ - throw std::invalid_argument("Tensor::scalar called on empty tensor"); \ - } \ - if(type() != dtype_traits::fl_type) { \ - throw std::invalid_argument( \ - "Tensor::scalar: requested type of " + \ - std::string(dtype_traits::getName()) + \ - " doesn't match tensor type, which is " + dtypeToString(type()) \ - ); \ - } \ - TYPE out; \ - impl_->scalar(&out); \ - return out; \ - } \ - \ - template<> FL_API TYPE * Tensor::device() const { \ - if(isEmpty()) { \ - return nullptr; \ - } \ - TYPE* out; \ - void** addr = reinterpret_cast(&out); \ - impl_->device(addr); \ - return out; \ - } \ - \ - template<> \ - 
FL_API void Tensor::device(TYPE * *ptr) const { \ - if(isEmpty()) { \ - return; \ - } \ - impl_->device(reinterpret_cast(ptr)); \ - } \ - \ - template<> FL_API TYPE * Tensor::host() const { \ - if(isEmpty()) { \ - return nullptr; \ - } \ - TYPE* out = reinterpret_cast(new char[bytes()]); \ - impl_->host(out); \ - return out; \ - } \ - \ - template<> \ - FL_API void Tensor::host(TYPE * ptr) const { \ - if(!isEmpty()) { \ - impl_->host(ptr); \ - } \ - } -FL_CREATE_MEMORY_OPS(int); -FL_CREATE_MEMORY_OPS(unsigned); -FL_CREATE_MEMORY_OPS(char); -FL_CREATE_MEMORY_OPS(unsigned char); -FL_CREATE_MEMORY_OPS(long); -FL_CREATE_MEMORY_OPS(unsigned long); -FL_CREATE_MEMORY_OPS(long long); -FL_CREATE_MEMORY_OPS(unsigned long long); -FL_CREATE_MEMORY_OPS(double); -FL_CREATE_MEMORY_OPS(float); -FL_CREATE_MEMORY_OPS(short); -FL_CREATE_MEMORY_OPS(unsigned short); -// void specializations template<> FL_API void* Tensor::device() const { if(isEmpty()) @@ -215,61 +115,37 @@ FL_API void* Tensor::device() const { return out; } -template<> -FL_API void Tensor::device(void** ptr) const { - if(isEmpty()) - return; - impl_->device(ptr); -} - -template<> -FL_API void* Tensor::host() const { +void* Tensor::raw_host() const { if(isEmpty()) return nullptr; - void* out = reinterpret_cast(new char[bytes()]); + auto* out = reinterpret_cast(new char[bytes()]); impl_->host(out); return out; } -template<> -FL_API void Tensor::host(void* ptr) const { - impl_->host(ptr); +void Tensor::raw_host(void* dst) const { + if(!isEmpty()) + impl_->host(dst); } -#undef FL_CREATE_MEMORY_OPS -void Tensor::unlock() const { - impl_->unlock(); -} -bool Tensor::isLocked() const { - return impl_->isLocked(); -} +void Tensor::unlock() const { impl_->unlock(); } -bool Tensor::isContiguous() const { - return impl_->isContiguous(); -} +bool Tensor::isLocked() const { return impl_->isLocked(); } -Shape Tensor::strides() const { - return impl_->strides(); -} +bool Tensor::isContiguous() const { return impl_->isContiguous(); } 
-const Stream& Tensor::stream() const { - return impl_->stream(); -} +Shape Tensor::strides() const { return impl_->strides(); } -void Tensor::setContext(void* context) { impl_->setContext(context); } +Stream const& Tensor::stream() const { return impl_->stream(); } -void* Tensor::getContext() const { - return impl_->getContext(); -} +void Tensor::setContext(void* context) const { impl_->setContext(context); } -std::string Tensor::toString() const { - return impl_->toString(); -} +void* Tensor::getContext() const { return impl_->getContext(); } -std::ostream& Tensor::operator<<(std::ostream& ostr) const { - return impl_->operator<<(ostr); -} +std::string Tensor::toString() const { return impl_->toString(); } + +std::ostream& Tensor::operator<<(std::ostream& ostr) const { return impl_->operator<<(ostr); } /******************** Assignment Operators ********************/ #define FL_ASSIGN_OP_TYPE(OP, FUN, TYPE) \ @@ -311,6 +187,8 @@ FL_ASSIGN_OP(operator/=, inPlaceDivide); // Move assignment operator when `this` is a lvalue, e.g., `x = std::move(y)`. // In such cases, we let `this` take over the tensor data of `other`. Tensor& Tensor::operator=(Tensor&& other) & { + if(this == &other) + return *this; this->impl_ = std::move(other.impl_); return *this; } @@ -318,24 +196,41 @@ Tensor& Tensor::operator=(Tensor&& other) & { // Move assignment operator when `this` is a rvalue, e.g., `x(0) = // std::move(y)`. In such cases, we copy the data from `other` to `this`. Tensor& Tensor::operator=(Tensor&& other) && { + if(this == &other) + return *this; + this->impl_->assign(other); return *this; } // Copy assignment operator when `this` is a lvalue, e.g., `x = y`. // In such cases, we let `this` take over the _cloned_ data from `other`. 
-Tensor& Tensor::operator=(const Tensor& other) & { +Tensor& Tensor::operator=(Tensor const& other) & { + if(this == &other) + return *this; + this->impl_ = other.impl_->clone(); return *this; } // Copy assignment operator when `this` is a lvalue, e.g., `x(0) = y`. // In such cases, we copy the data from `other` to `this`. -Tensor& Tensor::operator=(const Tensor& other) && { +Tensor& Tensor::operator=(Tensor const& other) && { + if(this == &other) + return *this; + this->impl_->assign(other); return *this; } + +void Tensor::scalar_impl(void* out) const { impl_->scalar(&out); } + +} + + +namespace fl { + /* --------------------------- Tensor Operators --------------------------- */ /******************** Tensor Creation Functions ********************/ @@ -361,13 +256,25 @@ FL_CREATE_FUN_LITERAL_TYPE(const short&); FL_CREATE_FUN_LITERAL_TYPE(const unsigned short&); #undef FL_CREATE_FUN_LITERAL_TYPE -Tensor identity(const Dim dim, const dtype type) { return defaultTensorBackend().identity(dim, type); } +Tensor identity(Dim const dim, dtype const type) { return defaultTensorBackend().identity(dim, type); } + +namespace { + template + Tensor arange_impl(T start, T end, T step, dtype type) { + return fl::dispatch_dtype( + type, + [&]() { + return fl::arange({static_cast((end - start) / step)}, 0, type) + * static_cast(step) + + static_cast(start); + } + ); + } +} // namespace #define FL_ARANGE_FUN_DEF(TYPE) \ - template<> FL_API Tensor arange(TYPE start, TYPE end, TYPE step, const dtype type) { \ - return fl::arange({static_cast((end - start) / step)}, 0, type) * \ - step + \ - start; \ + template<> FL_API Tensor arange(TYPE start, TYPE end, TYPE step, dtype type) { \ + return arange_impl(start, end, step, type); \ } FL_ARANGE_FUN_DEF(const double&); FL_ARANGE_FUN_DEF(const float&); @@ -378,131 +285,128 @@ FL_ARANGE_FUN_DEF(const unsigned long&); FL_ARANGE_FUN_DEF(const long long&); FL_ARANGE_FUN_DEF(const unsigned long long&); -Tensor arange(const Shape& shape, const Dim 
seqDim, const dtype type) { +Tensor arange(Shape const& shape, Dim const seqDim, dtype const type) { return defaultTensorBackend().arange(shape, seqDim, type); } -Tensor iota(const Shape& dims, const Shape& tileDims, const dtype type) { +Tensor iota(Shape const& dims, Shape const& tileDims, dtype const type) { return defaultTensorBackend().iota(dims, tileDims, type); } /************************ Shaping and Indexing *************************/ -Tensor reshape(const Tensor& tensor, const Shape& shape) { return tensor.backend().reshape(tensor, shape); } +Tensor reshape(Tensor const& tensor, Shape const& shape) { return tensor.backend().reshape(tensor, shape); } -Tensor transpose(const Tensor& tensor, const Shape& axes /* = {} */) { +Tensor transpose(Tensor const& tensor, Shape const& axes /* = {} */) { return tensor.backend().transpose(tensor, axes); } -Tensor tile(const Tensor& tensor, const Shape& shape) { return tensor.backend().tile(tensor, shape); } +Tensor tile(Tensor const& tensor, Shape const& shape) { return tensor.backend().tile(tensor, shape); } -Tensor concatenate(const std::vector& tensors, const unsigned axis) { +Tensor concatenate(std::vector const& tensors, unsigned const axis) { if(tensors.empty()) - throw std::invalid_argument("concatenate: called on empty set of tensors"); + throw std::invalid_argument{"concatenate: called on empty set of tensors"}; // Check all backends match - const TensorBackendType b = tensors.front().backendType(); - const bool matches = - std::all_of( - tensors.begin(), - tensors.end(), - [b](const Tensor& t) { - return t.backendType() == b; - } + TensorBackendType const b = tensors.front().backendType(); + bool const matches = + std::ranges::all_of( + tensors, + [b](Tensor const& t) { return t.backendType() == b; } ); if(!matches) - throw std::invalid_argument( + throw std::invalid_argument{ "concatenate: tried to concatenate tensors of different backends" - ); + }; return tensors.front().backend().concatenate(tensors, axis); 
} -Tensor nonzero(const Tensor& tensor) { return tensor.backend().nonzero(tensor); } +Tensor nonzero(Tensor const& tensor) { return tensor.backend().nonzero(tensor); } Tensor pad( - const Tensor& input, - const std::vector>& padWidths, - const PadType type + Tensor const& input, + std::vector> const& padWidths, + PadType const type ) { return input.backend().pad(input, padWidths, type); } /************************** Unary Operators ***************************/ -Tensor exp(const Tensor& tensor) { return tensor.backend().exp(tensor); } +Tensor exp(Tensor const& tensor) { return tensor.backend().exp(tensor); } -Tensor log(const Tensor& tensor) { return tensor.backend().log(tensor); } +Tensor log(Tensor const& tensor) { return tensor.backend().log(tensor); } -Tensor negative(const Tensor& tensor) { return tensor.backend().negative(tensor); } +Tensor negative(Tensor const& tensor) { return tensor.backend().negative(tensor); } -Tensor logicalNot(const Tensor& tensor) { return tensor.backend().logicalNot(tensor); } +Tensor logicalNot(Tensor const& tensor) { return tensor.backend().logicalNot(tensor); } -Tensor log1p(const Tensor& tensor) { return tensor.backend().log1p(tensor); } +Tensor log1p(Tensor const& tensor) { return tensor.backend().log1p(tensor); } -Tensor sin(const Tensor& tensor) { return tensor.backend().sin(tensor); } +Tensor sin(Tensor const& tensor) { return tensor.backend().sin(tensor); } -Tensor cos(const Tensor& tensor) { return tensor.backend().cos(tensor); } +Tensor cos(Tensor const& tensor) { return tensor.backend().cos(tensor); } -Tensor sqrt(const Tensor& tensor) { return tensor.backend().sqrt(tensor); } +Tensor sqrt(Tensor const& tensor) { return tensor.backend().sqrt(tensor); } -Tensor tanh(const Tensor& tensor) { return tensor.backend().tanh(tensor); } +Tensor tanh(Tensor const& tensor) { return tensor.backend().tanh(tensor); } -Tensor floor(const Tensor& tensor) { return tensor.backend().floor(tensor); } +Tensor floor(Tensor const& tensor) { 
return tensor.backend().floor(tensor); } -Tensor ceil(const Tensor& tensor) { return tensor.backend().ceil(tensor); } +Tensor ceil(Tensor const& tensor) { return tensor.backend().ceil(tensor); } -Tensor rint(const Tensor& tensor) { return tensor.backend().rint(tensor); } +Tensor rint(Tensor const& tensor) { return tensor.backend().rint(tensor); } -Tensor absolute(const Tensor& tensor) { return tensor.backend().absolute(tensor); } +Tensor absolute(Tensor const& tensor) { return tensor.backend().absolute(tensor); } -Tensor sigmoid(const Tensor& tensor) { return tensor.backend().sigmoid(tensor); } +Tensor sigmoid(Tensor const& tensor) { return tensor.backend().sigmoid(tensor); } -Tensor erf(const Tensor& tensor) { return tensor.backend().erf(tensor); } +Tensor erf(Tensor const& tensor) { return tensor.backend().erf(tensor); } -Tensor flip(const Tensor& tensor, const unsigned dim) { return tensor.backend().flip(tensor, dim); } +Tensor flip(Tensor const& tensor, unsigned const dim) { return tensor.backend().flip(tensor, dim); } -Tensor clip(const Tensor& tensor, const Tensor& low, const Tensor& high) { +Tensor clip(Tensor const& tensor, Tensor const& low, Tensor const& high) { FL_TENSOR_BACKENDS_MATCH_CHECK(tensor, low, high); return tensor.backend().clip(tensor, low, high); } -Tensor clip(const Tensor& tensor, const Tensor& low, const double& high) { +Tensor clip(Tensor const& tensor, Tensor const& low, double const& high) { FL_TENSOR_BACKENDS_MATCH_CHECK(tensor, low); return tensor.backend().clip(tensor, low, high); } -Tensor clip(const Tensor& tensor, const double& low, const Tensor& high) { +Tensor clip(Tensor const& tensor, double const& low, Tensor const& high) { FL_TENSOR_BACKENDS_MATCH_CHECK(tensor, high); return tensor.backend().clip(tensor, low, high); } -Tensor clip(const Tensor& tensor, const double& low, const double& high) { +Tensor clip(Tensor const& tensor, double const& low, double const& high) { return tensor.backend().clip(tensor, low, high); } 
-Tensor roll(const Tensor& tensor, const int shift, const unsigned axis) { +Tensor roll(Tensor const& tensor, int const shift, unsigned const axis) { return tensor.backend().roll(tensor, shift, axis); } -Tensor isnan(const Tensor& tensor) { return tensor.backend().isnan(tensor); } +Tensor isnan(Tensor const& tensor) { return tensor.backend().isnan(tensor); } -Tensor isinf(const Tensor& tensor) { return tensor.backend().isinf(tensor); } +Tensor isinf(Tensor const& tensor) { return tensor.backend().isinf(tensor); } -Tensor sign(const Tensor& tensor) { return tensor.backend().sign(tensor); } +Tensor sign(Tensor const& tensor) { return tensor.backend().sign(tensor); } -Tensor tril(const Tensor& tensor) { return tensor.backend().tril(tensor); } +Tensor tril(Tensor const& tensor) { return tensor.backend().tril(tensor); } -Tensor triu(const Tensor& tensor) { return tensor.backend().triu(tensor); } +Tensor triu(Tensor const& tensor) { return tensor.backend().triu(tensor); } -Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y) { +Tensor where(Tensor const& condition, Tensor const& x, Tensor const& y) { FL_TENSOR_BACKENDS_MATCH_CHECK(condition, x, y); return condition.backend().where(condition, x, y); } -Tensor where(const Tensor& condition, const Tensor& x, const double& y) { +Tensor where(Tensor const& condition, Tensor const& x, double const& y) { FL_TENSOR_BACKENDS_MATCH_CHECK(condition, x); return condition.backend().where(condition, x, y); } -Tensor where(const Tensor& condition, const double& x, const Tensor& y) { +Tensor where(Tensor const& condition, double const& x, Tensor const& y) { FL_TENSOR_BACKENDS_MATCH_CHECK(condition, y); return condition.backend().where(condition, x, y); } @@ -510,28 +414,28 @@ Tensor where(const Tensor& condition, const double& x, const Tensor& y) { void topk( Tensor& values, Tensor& indices, - const Tensor& input, - const unsigned k, - const Dim axis, - const SortMode sortMode /* = SortMode::Descending */ + Tensor 
const& input, + unsigned const k, + Dim const axis, + SortMode const sortMode /* = SortMode::Descending */ ) { FL_TENSOR_BACKENDS_MATCH_CHECK(values, indices, input); input.backend().topk(values, indices, input, k, axis, sortMode); } -Tensor sort(const Tensor& input, const Dim axis, const SortMode sortMode) { +Tensor sort(Tensor const& input, Dim const axis, SortMode const sortMode) { return input.backend().sort(input, axis, sortMode); } void sort( Tensor& values, Tensor& indices, - const Tensor& input, - const Dim axis, - const SortMode sortMode /* = SortMode::Descending */ + Tensor const& input, + Dim const axis, + SortMode const sortMode /* = SortMode::Descending */ ) { return values.backend().sort(values, indices, input, axis, sortMode); } -Tensor argsort(const Tensor& input, const Dim axis, const SortMode sortMode) { +Tensor argsort(Tensor const& input, Dim const axis, SortMode const sortMode) { return input.backend().argsort(input, axis, sortMode); } @@ -598,37 +502,37 @@ FL_BINARY_OP_DEF(>>, rShift); #undef FL_BINARY_OP_LITERALS_DEF #undef FL_BINARY_OP_LITERAL_TYPE_DEF -Tensor minimum(const Tensor& lhs, const Tensor& rhs) { +Tensor minimum(Tensor const& lhs, Tensor const& rhs) { FL_TENSOR_BACKENDS_MATCH_CHECK(lhs, rhs); return lhs.backend().minimum(lhs, rhs); } -Tensor maximum(const Tensor& lhs, const Tensor& rhs) { +Tensor maximum(Tensor const& lhs, Tensor const& rhs) { FL_TENSOR_BACKENDS_MATCH_CHECK(lhs, rhs); return lhs.backend().maximum(lhs, rhs); } -Tensor minimum(const Tensor& lhs, const double& rhs) { return lhs.backend().minimum(lhs, rhs); } +Tensor minimum(Tensor const& lhs, double const& rhs) { return lhs.backend().minimum(lhs, rhs); } -Tensor minimum(const double& lhs, const Tensor& rhs) { return rhs.backend().minimum(lhs, rhs); } +Tensor minimum(double const& lhs, Tensor const& rhs) { return rhs.backend().minimum(lhs, rhs); } -Tensor maximum(const Tensor& lhs, const double& rhs) { return lhs.backend().maximum(lhs, rhs); } +Tensor maximum(Tensor 
const& lhs, double const& rhs) { return lhs.backend().maximum(lhs, rhs); } -Tensor maximum(const double& lhs, const Tensor& rhs) { return rhs.backend().maximum(lhs, rhs); } +Tensor maximum(double const& lhs, Tensor const& rhs) { return rhs.backend().maximum(lhs, rhs); } -Tensor power(const Tensor& lhs, const Tensor& rhs) { +Tensor power(Tensor const& lhs, Tensor const& rhs) { FL_TENSOR_BACKENDS_MATCH_CHECK(lhs, rhs); return lhs.backend().power(lhs, rhs); } -Tensor power(const Tensor& lhs, const double& rhs) { return lhs.backend().power(lhs, rhs); } +Tensor power(Tensor const& lhs, double const& rhs) { return lhs.backend().power(lhs, rhs); } -Tensor power(const double& lhs, const Tensor& rhs) { return rhs.backend().power(lhs, rhs); } +Tensor power(double const& lhs, Tensor const& rhs) { return rhs.backend().power(lhs, rhs); } /******************************* BLAS ********************************/ Tensor matmul( - const Tensor& lhs, - const Tensor& rhs, + Tensor const& lhs, + Tensor const& rhs, MatrixProperty lhsProp, MatrixProperty rhsProp ) { @@ -639,23 +543,23 @@ Tensor matmul( /************************** Reductions ***************************/ Tensor amin( - const Tensor& input, - const std::vector& axes /* = {} */, - const bool keepDims /* = false */ + Tensor const& input, + std::vector const& axes /* = {} */, + bool const keepDims /* = false */ ) { return input.backend().amin(input, axes, keepDims); } Tensor amax( - const Tensor& input, - const std::vector& axes /* = {} */, - const bool keepDims /* = false */ + Tensor const& input, + std::vector const& axes /* = {} */, + bool const keepDims /* = false */ ) { return input.backend().amax(input, axes, keepDims); } void min( Tensor& values, Tensor& indices, - const Tensor& input, - const unsigned axis, - const bool keepDims + Tensor const& input, + unsigned const axis, + bool const keepDims ) { FL_TENSOR_BACKENDS_MATCH_CHECK(values, indices, input); return input.backend().min(values, indices, input, axis, 
keepDims); @@ -664,97 +568,97 @@ void min( void max( Tensor& values, Tensor& indices, - const Tensor& input, - const unsigned axis, - const bool keepDims /* = false */ + Tensor const& input, + unsigned const axis, + bool const keepDims /* = false */ ) { FL_TENSOR_BACKENDS_MATCH_CHECK(values, indices, input); return input.backend().max(values, indices, input, axis, keepDims); } Tensor sum( - const Tensor& input, - const std::vector& axes /* = {} */, - const bool keepDims /* = false */ + Tensor const& input, + std::vector const& axes /* = {} */, + bool const keepDims /* = false */ ) { return input.backend().sum(input, axes, keepDims); } -Tensor cumsum(const Tensor& input, const unsigned axis) { return input.backend().cumsum(input, axis); } +Tensor cumsum(Tensor const& input, unsigned const axis) { return input.backend().cumsum(input, axis); } Tensor argmax( - const Tensor& input, - const unsigned axis, - const bool keepDims /* = false */ + Tensor const& input, + unsigned const axis, + bool const keepDims /* = false */ ) { return input.backend().argmax(input, axis, keepDims); } Tensor argmin( - const Tensor& input, - const unsigned axis, - const bool keepDims /* = false */ + Tensor const& input, + unsigned const axis, + bool const keepDims /* = false */ ) { return input.backend().argmin(input, axis, keepDims); } Tensor mean( - const Tensor& input, - const std::vector& axes /* = {} */, - const bool keepDims /* = false */ + Tensor const& input, + std::vector const& axes /* = {} */, + bool const keepDims /* = false */ ) { return input.backend().mean(input, axes, keepDims); } Tensor median( - const Tensor& input, - const std::vector& axes /* = {} */, - const bool keepDims /* = false */ + Tensor const& input, + std::vector const& axes /* = {} */, + bool const keepDims /* = false */ ) { return input.backend().median(input, axes, keepDims); } Tensor var( - const Tensor& input, - const std::vector& axes /* = {} */, - const bool bias, - const bool keepDims /* = false */ + 
Tensor const& input, + std::vector const& axes /* = {} */, + bool const bias, + bool const keepDims /* = false */ ) { return input.backend().var(input, axes, bias, keepDims); } Tensor std( - const Tensor& input, - const std::vector& axes /* = {} */, - const bool keepDims /* = false */ + Tensor const& input, + std::vector const& axes /* = {} */, + bool const keepDims /* = false */ ) { return input.backend().std(input, axes, keepDims); } Tensor norm( - const Tensor& input, - const std::vector& axes /* = {} */, + Tensor const& input, + std::vector const& axes /* = {} */, double p /* = 2 */, - const bool keepDims /* = false */ + bool const keepDims /* = false */ ) { return input.backend().norm(input, axes, p, keepDims); } Tensor countNonzero( - const Tensor& input, - const std::vector& axes /* = {} */, - const bool keepDims /* = false */ + Tensor const& input, + std::vector const& axes /* = {} */, + bool const keepDims /* = false */ ) { return input.backend().countNonzero(input, axes, keepDims); } -Tensor any( - const Tensor& input, - const std::vector& axes /* = {} */, - const bool keepDims /* = false */ +Tensor any_of( + Tensor const& input, + std::vector const& axes /* = {} */, + bool const keepDims /* = false */ ) { return input.backend().any(input, axes, keepDims); } -Tensor all( - const Tensor& input, - const std::vector& axes /* = {} */, - const bool keepDims /* = false */ +Tensor all_of( + Tensor const& input, + std::vector const& axes /* = {} */, + bool const keepDims /* = false */ ) { return input.backend().all(input, axes, keepDims); } /************************** Utilities ***************************/ -std::ostream& operator<<(std::ostream& ostr, const Tensor& t) { +std::ostream& operator<<(std::ostream& ostr, Tensor const& t) { t.operator<<(ostr); return ostr; } -void print(const Tensor& tensor) { tensor.backend().print(tensor); } +void print(Tensor const& tensor) { tensor.backend().print(tensor); } bool allClose( - const fl::Tensor& a, - const fl::Tensor& 
b, - const double absTolerance + fl::Tensor const& a, + fl::Tensor const& b, + double const absTolerance ) { if(a.type() != b.type()) return false; @@ -762,28 +666,27 @@ bool allClose( return false; if(a.elements() == 0 && b.elements() == 0) return true; - return fl::amax(fl::abs(a - b)).astype(dtype::f64).scalar() - < absTolerance; + + auto const diff = fl::amax(fl::abs(a - b)).asType(dtype::f64).scalar(); + + return diff < absTolerance; } -bool isInvalidArray(const Tensor& tensor) { - return fl::any(fl::isnan(tensor)).asScalar() - || fl::any(fl::isinf(tensor)).asScalar(); +bool isInvalidArray(Tensor const& tensor) { + return fl::any_of(fl::isnan(tensor)).asScalar() + || fl::any_of(fl::isinf(tensor)).asScalar(); } -std::string tensorBackendTypeToString(const TensorBackendType type) { +std::string tensorBackendTypeToString(TensorBackendType const type) { switch(type) { - case TensorBackendType::Stub: - return "Stub"; - case TensorBackendType::Tracer: - return "Tracer"; - case TensorBackendType::ArrayFire: - return "ArrayFire"; + case TensorBackendType::Stub: return "Stub"; + case TensorBackendType::Tracer: return "Tracer"; + case TensorBackendType::ArrayFire: return "ArrayFire"; } throw std::runtime_error("Unreachable -- unrecognized tensor backend type"); } -std::ostream& operator<<(std::ostream& os, const TensorBackendType type) { +std::ostream& operator<<(std::ostream& os, TensorBackendType const type) { os << tensorBackendTypeToString(type); return os; } @@ -794,7 +697,7 @@ namespace detail { std::unique_ptr releaseAdapterUnsafe(Tensor& t) { return t.releaseAdapter(); } - bool areTensorTypesEqual(const Tensor& a, const Tensor& b) { return a.type() == b.type(); } + bool areTensorTypesEqual(Tensor const& a, Tensor const& b) { return a.type() == b.type(); } } // namespace detail diff --git a/flashlight/fl/tensor/TensorBase.h b/flashlight/fl/tensor/TensorBase.h index 82ae03f..a24e4a7 100644 --- a/flashlight/fl/tensor/TensorBase.h +++ 
b/flashlight/fl/tensor/TensorBase.h @@ -1,8 +1,8 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. + * SPDX-License-Identifier: MIT * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. + * Original code: Copyright (c) Meta Platforms, Inc. (see FLASHLIGHT_LICENSE) + * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE) */ #pragma once @@ -20,6 +20,9 @@ #include "flashlight/fl/tensor/Shape.h" #include "flashlight/fl/tensor/Types.h" +#include +#include + namespace fl { class Tensor; @@ -30,7 +33,7 @@ class Tensor; */ /// Enum for various tensor backends. -enum class TensorBackendType {Stub, Tracer, ArrayFire}; +enum class TensorBackendType { Stub, Tracer, ArrayFire }; // See TensorAdapter.h class TensorAdapterBase; @@ -45,12 +48,12 @@ struct Index; class Stream; /// Location of memory or tensors. -enum class Location {Host, Device}; +enum class Location { Host, Device }; /// Alias to make it semantically clearer when referring to buffer location using MemoryLocation = Location; /// Tensor storage types. -enum class StorageType {Dense = 0, CSR = 1, CSC = 2, COO = 3}; +enum class StorageType { Dense = 0, CSR = 1, CSC = 2, COO = 3 }; /* @} */ @@ -86,9 +89,9 @@ class FL_API Tensor { * compliance with TensorAdapter and is intentionally private. */ Tensor( - const Shape& shape, + Shape const& shape, fl::dtype type, - const void* ptr, + void const* ptr, MemoryLocation memoryLocation ); @@ -132,13 +135,13 @@ class FL_API Tensor { * Copy constructor - calls the implementation-defined copy constructor for * the TensorAdapter. */ - Tensor(const Tensor& tensor); + Tensor(Tensor const& tensor); /** * Move constructor - moves the pointer to the TensorAdapter - performs no * other operations. */ - Tensor(Tensor&& tensor) noexcept; + Tensor(Tensor&& other) noexcept; /** * Construct an empty tensor with the default tensor backend's tensor adapter. 
@@ -152,7 +155,7 @@ class FL_API Tensor { * @param[in] shape the shape of the tensor * @param[in] type (optional) the type of the tensor */ - explicit Tensor(const Shape& shape, fl::dtype type = fl::dtype::f32); + explicit Tensor(Shape const& shape, fl::dtype type = fl::dtype::f32); /** * Construct an empty tensor of a given type. @@ -174,11 +177,11 @@ class FL_API Tensor { * \todo Expand this API with getters as needed. */ Tensor( - const Dim nRows, - const Dim nCols, - const Tensor& values, - const Tensor& rowIdx, - const Tensor& colIdx, + Dim nRows, + Dim nCols, + Tensor const& values, + Tensor const& rowIdx, + Tensor const& colIdx, StorageType storageType ); @@ -212,7 +215,7 @@ class FL_API Tensor { template static Tensor fromVector(std::vector v) { return Tensor( - {static_cast(v.size())}, + {static_cast(v.size())}, fl::dtype_traits::fl_type, v.data(), Location::Host @@ -222,7 +225,7 @@ class FL_API Tensor { template static Tensor fromArray(std::array a) { return Tensor( - {static_cast(a.size())}, + {static_cast(a.size())}, fl::dtype_traits::fl_type, a.data(), Location::Host @@ -239,7 +242,7 @@ class FL_API Tensor { * @return a tensor with values and shape as given. */ template - static Tensor fromBuffer(Shape s, const T* ptr, Location memoryLocation) { + static Tensor fromBuffer(Shape s, T const* ptr, Location memoryLocation) { return Tensor(s, fl::dtype_traits::fl_type, ptr, memoryLocation); } @@ -256,7 +259,7 @@ class FL_API Tensor { static Tensor fromBuffer( Shape s, fl::dtype t, - const uint8_t* ptr, + uint8_t const* ptr, Location memoryLocation ) { return Tensor(s, t, ptr, memoryLocation); } @@ -270,7 +273,7 @@ class FL_API Tensor { * * @return the shape of the tensor */ - const Shape& shape() const; + Shape const& shape() const; /** * Get a tensor's location, host or some device. 
@@ -293,14 +296,14 @@ class FL_API Tensor { * * @return the number of elements at the given dimension */ - Dim dim(const size_t dim) const; + Dim dim(size_t dim) const; /** * Get the number of directions of the tensor. * * @return the number of dimensions */ - int ndim() const; + size_t ndim() const; /** * Returns true if the tensor has zero elements, else false. @@ -354,7 +357,7 @@ class FL_API Tensor { * @return an immutable reference to the stream that contains(ed) the * computations which create this tensor. */ - virtual const Stream& stream() const; + virtual Stream const& stream() const; /** * Returns a tensor with elements cast as a particular type @@ -362,7 +365,15 @@ class FL_API Tensor { * @param[in] type the type to which to cast the tensor * @return a tensor with element-wise cast to the new type */ - Tensor astype(const dtype type) const; + Tensor asType(dtype type) const; + + + /** + * @deprecated use @ref Tensor::asType(dtype) const instead + */ + Tensor astype(dtype type) const { return asType(type); } + + /** * Index into a tensor using a vector of fl::Index references. @@ -370,7 +381,7 @@ class FL_API Tensor { * @param[in] indices a vector of fl::Index references with which to index. * @return an indexed tensor */ - Tensor operator()(const std::vector& indices) const; + Tensor operator()(std::vector const& indices) const; /** * Index into a tensor using a variable number of fl::Index. @@ -379,12 +390,12 @@ class FL_API Tensor { * @return an indexed tensor */ template - Tensor operator()(const Ts&... args) const { - // TODO: add this back if acceptable with C++ 17 ABIs and a nvcc - // static_assert( - // std::conjunction...>::value, - // "Tensor index operator can only take Index-compatible types - " - // "fl::range, fl::Tensor, fl::span, and integer types."); + Tensor operator()(Ts const&... 
args) const { + static_assert( + (std::constructible_from && ...), + "Tensor index operator can only take Index-compatible types - " + "fl::range, fl::Tensor, fl::span, and integer types." + ); std::vector indices{{args...}}; return this->operator()(indices); } @@ -402,7 +413,7 @@ class FL_API Tensor { * * @return an indexed, 1D version of this tensor. */ - Tensor flat(const Index& idx) const; + Tensor flat(Index const& idx) const; /** * Return a copy (depending on copy-on-write behavior of the underlying @@ -425,9 +436,7 @@ class FL_API Tensor { * @return the tensor adapter. */ template - T& getAdapter() const { - return *static_cast(impl_.get()); - } + T& getAdapter() const { return *static_cast(impl_.get()); } /** * Return the TensorBackend associated with this tensor. @@ -446,8 +455,8 @@ class FL_API Tensor { * * @return a scalar of the first element in the tensor. */ - template - T scalar() const; + template + [[nodiscard]] T scalar() const; /** * Return a scalar of the specified type of the tensor. If the specified type @@ -457,100 +466,62 @@ class FL_API Tensor { * @return a scalar of the first element in the tensor cast to the specified * type. */ - template - T asScalar() const { - // Implicitly cast to the requested return type - switch(type()) { - case dtype::f16: - return astype(dtype::f32).scalar(); - case dtype::f32: - return scalar(); - case dtype::f64: - return scalar(); - case dtype::s32: - return scalar(); - case dtype::u32: - return scalar(); - case dtype::b8: - return scalar(); - case dtype::u8: - return scalar(); - case dtype::s64: - return scalar(); - case dtype::u64: - return scalar(); - case dtype::s16: - return scalar(); - case dtype::u16: - return scalar(); - default: - throw std::invalid_argument( - "Tensor::asScaler - no castable type exists." - ); - } - } + template + [[nodiscard]] T asScalar() const; /** * Return a pointer to the tensor's underlying data per a certain type. This * pointer exists on the computation device. 
* - * \note The memory allocated here will not be freed until Tensor:unlock() is + * @return the requested pointer on the device or `nullptr` if the tensor is empty + * + * @attention The memory allocated here will not be freed until Tensor:unlock() is * called. - * - * @return the requested pointer on the device. */ - template - T* device() const; + template + [[nodiscard]] T* device() const; /** - * Populate a pointer value with the address of a Tensor's underlying buffer - * on the computation device. + * Returns a pointer to the tensor's underlying data on the host. Copies the data + * to the host if it's located on the device. * - * \note The memory allocated here will not be freed until Tensor:unlock() is - * called. - * - * @param[in] ptr the pointer to populate with the Tensor's buffer location on - * device. + * @return host data pointer or `nullptr` if tensor is empty + * @attention memory ownership is transferred to the caller */ - template - void device(T** ptr) const; + [[nodiscard]] void* raw_host() const; /** - * Returns a pointer to the tensor's underlying data, but on the host. If the - * tensor is located on a device, makes a copy of device memory and returns a - * buffer on the host containing the relevant memory. + * Returns the tensor's underlying data on the host, creates a copy. + * Users who want to avoid copies may use @ref device() + * + * @throws std::logic_error if `Tensor::bytes() % sizeof(T) != 0` * - * @return the requested pointer on the host. + * @return vector of data */ - template - T* host() const; + template + [[nodiscard]] std::vector host() const; + /** - * Populates an existing buffer with the tensor's underlying data, but on the - * host. If the tensor is located on a device, makes a copy of device memory - * and returns a buffer on the host containing the relevant memory. - * - * @param[in] ptr a pointer to the region of memory to populate with tensor - * values + * Populates an existing host buffer with the tensor data. 
+ * + * @param[in] dst span to write to + * @pre + * - span size >= @ref bytes() + * - @ref bytes() mod `sizeof(T)` == 0 */ - template - void host(T* ptr) const; + template + void host(std::span dst) const; /** - * Returns a vector on the host contaning a flat representation of the tensor. - * The resulting vector is a copy of the underlying tensor memory, even if on - * the host. + * Populates the existing host buffer with the tensor data. + * @param[in] dst to write to * - * @return a vector in host memory containing + * @pre buffer size must be >= @ref bytes() + * @post bytes 0 - @ref bytes written */ - template - std::vector toHostVector() const { - if(isEmpty()) - return std::vector(); - std::vector vec(this->elements()); - host(vec.data()); - return vec; - } + void raw_host(void* dst) const; + /** * Unlocks any device memory associated with the tensor that was acquired with @@ -578,9 +549,9 @@ class FL_API Tensor { * Stores arbitrary data on a tensor. For internal use/benchmarking only. This * may be a no-op for some backends. * - * @param[in] data a pointer to arbitrary data to pass to a tensor impl. + * @param[in] context a pointer to arbitrary data to pass to a tensor impl. */ - void setContext(void* data); + void setContext(void* context) const; /** * Gets arbitrary data stored on a tensor. For internal use/benchmarking only. 
@@ -643,8 +614,11 @@ class FL_API Tensor { */ Tensor& operator=(Tensor&& other) &; Tensor& operator=(Tensor&& other) &&; - Tensor& operator=(const Tensor& other) &; - Tensor& operator=(const Tensor& other) &&; + Tensor& operator=(Tensor const& other) &; + Tensor& operator=(Tensor const& other) &&; + +private: + void scalar_impl(void* out) const; }; /** @@ -664,7 +638,7 @@ class FL_API Tensor { * @return a tensor of the specified shape filled with the specified value */ template -FL_API Tensor fromScalar(const T& val, const dtype type = dtype_traits::ctype); +FL_API Tensor fromScalar(T const& val, dtype type = dtype_traits::fl_type); /** * Creates a new Tensor with a given Shape and filled with a particular value. @@ -677,9 +651,9 @@ FL_API Tensor fromScalar(const T& val, const dtype type = dtype_traits::ctype */ template FL_API Tensor full( - const Shape& dims, - const T& val, - const dtype type = dtype_traits::ctype + Shape const& dims, + T const& val, + dtype type = dtype_traits::fl_type ); /** @@ -688,14 +662,14 @@ FL_API Tensor full( * @param[in] dim the size of the dimension of the matrix (dim x dim) * @param[in] type the type of the resulting matrix */ -FL_API Tensor identity(const Dim dim, const dtype type = dtype::f32); +FL_API Tensor identity(Dim dim, dtype type = dtype::f32); /** * Return evenly-spaced values in a given interval. Generate values in the - * interval `[start, stop)` steppping each element by the passed step. + * interval `[begin, end)` stepping each element by the passed step. 
* - * @param[in] start the start of the range - * @param[in] end the end of the range, inclusive + * @param[in] start the start of the range (inclusive) + * @param[in] end the end of the range (exclusive) * @param[in] step the increment for each consecutive value in the range * @param[in] type the dtype of the resulting tensor * @@ -703,10 +677,10 @@ FL_API Tensor identity(const Dim dim, const dtype type = dtype::f32); */ template FL_API Tensor arange( - const T& start, - const T& end, - const T& step = 1, - const dtype type = dtype_traits::ctype + T const& start, + T const& end, + T const& step = 1, + dtype type = dtype_traits::fl_type ); /** @@ -721,7 +695,7 @@ FL_API Tensor arange( * @return a tensor with the given shape with the sequence along the given * dimension, tiled along other dimensions. */ -FL_API Tensor arange(const Shape& shape, const Dim seqDim = 0, const dtype type = dtype::f32); +FL_API Tensor arange(Shape const& shape, Dim seqDim = 0, dtype type = dtype::f32); /** * Creates a sequence with the range `[0, dims.elements())` sequentially in the @@ -738,9 +712,9 @@ FL_API Tensor arange(const Shape& shape, const Dim seqDim = 0, const dtype type * @return */ FL_API Tensor iota( - const Shape& dims, - const Shape& tileDims = { 1 }, - const dtype type = dtype::f32 + Shape const& dims, + Shape const& tileDims = {1}, + dtype type = dtype::f32 ); /************************ Shaping and Indexing *************************/ @@ -752,7 +726,7 @@ FL_API Tensor iota( * @param[in] shape the new shape for the tensor * @return the reshaped tensor */ -FL_API Tensor reshape(const Tensor& tensor, const Shape& shape); +FL_API Tensor reshape(Tensor const& tensor, Shape const& shape); /** * Permute the axes of a tensor. If no arguments are given, reverses the axes of @@ -764,7 +738,7 @@ FL_API Tensor reshape(const Tensor& tensor, const Shape& shape); * argument is not passed, the axes of the input tensor will be reversed. 
* @return the permuted tensor */ -FL_API Tensor transpose(const Tensor& tensor, const Shape& axes = {}); +FL_API Tensor transpose(Tensor const& tensor, Shape const& axes = {}); /** * Repeat the contents of a tensor a given number of times along specified @@ -775,7 +749,7 @@ FL_API Tensor transpose(const Tensor& tensor, const Shape& axes = {}); * tensor * @return the tiled tensor */ -FL_API Tensor tile(const Tensor& tensor, const Shape& shape); +FL_API Tensor tile(Tensor const& tensor, Shape const& shape); /** * Join or concatenate tensors together along a particular axis. @@ -784,7 +758,7 @@ FL_API Tensor tile(const Tensor& tensor, const Shape& shape); * @param[in] axis the axis along which to concatenate tensors * @return a concatenated tensor */ -FL_API Tensor concatenate(const std::vector& tensors, const unsigned axis = 0); +FL_API Tensor concatenate(std::vector const& tensors, unsigned axis = 0); /** * Join or concatenate tensors together along a particular axis. @@ -794,7 +768,7 @@ FL_API Tensor concatenate(const std::vector& tensors, const unsigned axi * @return a concatenated tensor */ template -Tensor concatenate(unsigned axis, const Ts&... args) { +Tensor concatenate(unsigned axis, Ts const&... args) { std::vector tensors{{args...}}; return concatenate(tensors, axis); } @@ -806,7 +780,7 @@ Tensor concatenate(unsigned axis, const Ts&... args) { * @param[in] tensor input tensor * @return a tensor containing the indices of the nonzero elements */ -FL_API Tensor nonzero(const Tensor& tensor); +FL_API Tensor nonzero(Tensor const& tensor); /// Padding types for the pad operator. enum class PadType { @@ -822,16 +796,15 @@ enum class PadType { * Pad a tensor with zeros. 
* * @param[in] input the input tensor to pad - * @param[in] padWidths a vector of tuples representing padding (before, after) - * tuples for each axis - * @param[in] type the padding mode with which to pad the tensor - see `PadType` + * @param[in] padWidths padding sizes for each axis with (prepended, appended) respectively + * @param[in] type the padding mode with which to pad the tensor - @ref PadType * * @return the padded tensor */ FL_API Tensor pad( - const Tensor& input, - const std::vector>& padWidths, - const PadType type = PadType::Constant + Tensor const& input, + std::vector> const& padWidths, + PadType type = PadType::Constant ); /************************** Unary Operators ***************************/ @@ -841,8 +814,8 @@ FL_API Tensor pad( * @param[in] tensor the input tensor to negate. * @return a tensor with elements negated. */ -FL_API Tensor negative(const Tensor& tensor); -inline Tensor operator-(const Tensor& tensor) { return negative(tensor); } +FL_API Tensor negative(Tensor const& tensor); +inline Tensor operator-(Tensor const& tensor) { return negative(tensor); } /** * Performs element-wise logical-not on the elements of a tensor @@ -850,8 +823,8 @@ inline Tensor operator-(const Tensor& tensor) { return negative(tensor); } * @param[in] tensor the tensor on which to perform logical not * @return a tensor with element-wise logical not of the input */ -FL_API Tensor logicalNot(const Tensor& tensor); -inline Tensor operator!(const Tensor& tensor) { return logicalNot(tensor); } +FL_API Tensor logicalNot(Tensor const& tensor); +inline Tensor operator!(Tensor const& tensor) { return logicalNot(tensor); } /** * Compute the element-wise exponential of a tensor @@ -859,7 +832,7 @@ inline Tensor operator!(const Tensor& tensor) { return logicalNot(tensor); } * @param[in] tensor the tensor to exponentiate * @return the exponentiated tensor */ -FL_API Tensor exp(const Tensor& tensor); +FL_API Tensor exp(Tensor const& tensor); /** * Compute the element-wise 
natural logarithm of a tensor @@ -867,7 +840,7 @@ FL_API Tensor exp(const Tensor& tensor); * @param[in] tensor the tensor on which to compute * @return the resulting tensor */ -FL_API Tensor log(const Tensor& tensor); +FL_API Tensor log(Tensor const& tensor); /** * Returns the natural logarithm of one plus the input, element-wise. @@ -875,7 +848,7 @@ FL_API Tensor log(const Tensor& tensor); * @param[in] tensor the tensor on which to compute * @return the resulting tensor */ -FL_API Tensor log1p(const Tensor& tensor); +FL_API Tensor log1p(Tensor const& tensor); /** * Returns the element-wise sine of the input. @@ -883,7 +856,7 @@ FL_API Tensor log1p(const Tensor& tensor); * @param[in] tensor the tensor on which to compute * @return the resulting tensor */ -FL_API Tensor sin(const Tensor& tensor); +FL_API Tensor sin(Tensor const& tensor); /** * Returns the element-wise cosine of the input. @@ -891,7 +864,7 @@ FL_API Tensor sin(const Tensor& tensor); * @param[in] tensor the tensor on which to compute * @return the resulting tensor */ -FL_API Tensor cos(const Tensor& tensor); +FL_API Tensor cos(Tensor const& tensor); /** * Returns the element-wise non-negative square root of the input. @@ -899,7 +872,7 @@ FL_API Tensor cos(const Tensor& tensor); * @param[in] tensor the tensor on which to compute * @return the resulting tensor */ -FL_API Tensor sqrt(const Tensor& tensor); +FL_API Tensor sqrt(Tensor const& tensor); /** * Returns the element-wise hyperbolic tangent of the input. @@ -907,7 +880,7 @@ FL_API Tensor sqrt(const Tensor& tensor); * @param[in] tensor the tensor on which to compute * @return the resulting tensor */ -FL_API Tensor tanh(const Tensor& tensor); +FL_API Tensor tanh(Tensor const& tensor); /** * Returns the element-wise floor of the input. 
@@ -915,7 +888,7 @@ FL_API Tensor tanh(const Tensor& tensor); * @param[in] tensor the tensor on which to compute the floor * @return the resulting tensor */ -FL_API Tensor floor(const Tensor& tensor); +FL_API Tensor floor(Tensor const& tensor); /** * Returns the element-wise ceiling of the input. @@ -923,7 +896,7 @@ FL_API Tensor floor(const Tensor& tensor); * @param[in] tensor the tensor on which to compute the ceiling * @return the resulting tensor */ -FL_API Tensor ceil(const Tensor& tensor); +FL_API Tensor ceil(Tensor const& tensor); /** * Returns the tensor with element-wise rounding to the nearest integer. @@ -931,7 +904,7 @@ FL_API Tensor ceil(const Tensor& tensor); * @param[in] tensor the input tensor * @return the resulting tensor */ -FL_API Tensor rint(const Tensor& tensor); +FL_API Tensor rint(Tensor const& tensor); /** * Returns the element-wise absolute value of the input. @@ -939,10 +912,10 @@ FL_API Tensor rint(const Tensor& tensor); * @param[in] tensor the tensor on which to compute * @return the resulting tensor */ -FL_API Tensor absolute(const Tensor& tensor); +FL_API Tensor absolute(Tensor const& tensor); // \copydoc absolute -inline Tensor abs(const Tensor& tensor) { return absolute(tensor); } +inline Tensor abs(Tensor const& tensor) { return absolute(tensor); } /** * Returns the element-wise sigmoid the input: @@ -951,7 +924,7 @@ inline Tensor abs(const Tensor& tensor) { return absolute(tensor); } * @param[in] tensor the tensor on which to compute * @return the resulting tensor */ -FL_API Tensor sigmoid(const Tensor& tensor); +FL_API Tensor sigmoid(Tensor const& tensor); /** * Computes the element-wise error function the input: see @@ -960,7 +933,7 @@ FL_API Tensor sigmoid(const Tensor& tensor); * @param[in] tensor the tensor on which to compute * @return ther resulting tensor */ -FL_API Tensor erf(const Tensor& tensor); +FL_API Tensor erf(Tensor const& tensor); /** * Flip a Tensor along a specified dimension. 
@@ -970,7 +943,7 @@ FL_API Tensor erf(const Tensor& tensor); * * @return the resulting flipped tensor */ -FL_API Tensor flip(const Tensor& tensor, const unsigned dim); +FL_API Tensor flip(Tensor const& tensor, unsigned dim); /** * Clip (limit) the values of a tensor. Given some interval of values, set @@ -988,7 +961,7 @@ FL_API Tensor flip(const Tensor& tensor, const unsigned dim); * clipping * @return a tensor with all values clipped between high and low */ -FL_API Tensor clip(const Tensor& tensor, const Tensor& low, const Tensor& high); +FL_API Tensor clip(Tensor const& tensor, Tensor const& low, Tensor const& high); /** * Clip (limit) the values of a tensor. Given some interval of values, set @@ -1003,7 +976,7 @@ FL_API Tensor clip(const Tensor& tensor, const Tensor& low, const Tensor& high); * @param[in] high a scalar to use as the maximum value in clipping * @return a tensor with all values clipped between high and low */ -FL_API Tensor clip(const Tensor& tensor, const Tensor& low, const double& high); +FL_API Tensor clip(Tensor const& tensor, Tensor const& low, double const& high); /** * Clip (limit) the values of a tensor. Given some interval of values, set @@ -1018,7 +991,7 @@ FL_API Tensor clip(const Tensor& tensor, const Tensor& low, const double& high); * clipping * @return a tensor with all values clipped between high and low */ -FL_API Tensor clip(const Tensor& tensor, const double& low, const Tensor& high); +FL_API Tensor clip(Tensor const& tensor, double const& low, Tensor const& high); /** * Clip (limit) the values of a tensor. 
Given some interval of values, set @@ -1031,7 +1004,7 @@ FL_API Tensor clip(const Tensor& tensor, const double& low, const Tensor& high); * @param[in] high a scalar to use as the maximum value in clipping * @return a tensor with all values clipped between high and low */ -FL_API Tensor clip(const Tensor& tensor, const double& low, const double& high); +FL_API Tensor clip(Tensor const& tensor, double const& low, double const& high); /** * Rolls (or shifts) a tensor by a certain amount along a given axis, moving @@ -1044,7 +1017,7 @@ FL_API Tensor clip(const Tensor& tensor, const double& low, const double& high); * @return a tensor with values shifted by the given amount in a circular * fashion */ -FL_API Tensor roll(const Tensor& tensor, const int shift, const unsigned axis); +FL_API Tensor roll(Tensor const& tensor, int shift, unsigned axis); /** * Returns a boolean tensor which is true where the input tensor was NaN, and @@ -1054,7 +1027,7 @@ FL_API Tensor roll(const Tensor& tensor, const int shift, const unsigned axis); * @return a boolean tensor with true in positions that contained NaN in the * input tensor */ -FL_API Tensor isnan(const Tensor& tensor); +FL_API Tensor isnan(Tensor const& tensor); /** * Returns a boolean tensor which is true where the input tensor was infinity, @@ -1064,7 +1037,7 @@ FL_API Tensor isnan(const Tensor& tensor); * @return a boolean tensor with true in positions that contained Inf in the * input tensor */ -FL_API Tensor isinf(const Tensor& tensor); +FL_API Tensor isinf(Tensor const& tensor); /** * Returns a tensor that contains -1 if an element is less than 0, 0 if an @@ -1074,7 +1047,7 @@ FL_API Tensor isinf(const Tensor& tensor); * @param[in] tensor the input tensor * @return a tensor containing element-wise sign values. */ -FL_API Tensor sign(const Tensor& tensor); +FL_API Tensor sign(Tensor const& tensor); /** * Returns an upper triangular version of the tensor. 
@@ -1087,7 +1060,7 @@ FL_API Tensor sign(const Tensor& tensor); * @return a copy of the input tensor with elements above the diagonal zeroed * out */ -FL_API Tensor tril(const Tensor& tensor); +FL_API Tensor tril(Tensor const& tensor); /** * Returns an upper triangular version of the tensor. @@ -1100,7 +1073,7 @@ FL_API Tensor tril(const Tensor& tensor); * @return a copy of the input tensor with elements below the diagonal zeroed * out */ -FL_API Tensor triu(const Tensor& tensor); +FL_API Tensor triu(Tensor const& tensor); /** * Conditionally return elements from one of two tensors based on a condition. @@ -1116,7 +1089,7 @@ FL_API Tensor triu(const Tensor& tensor); * @return the resulting tensor that contains elements of x where condition is * true and elements of y where condition is false. */ -FL_API Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y); +FL_API Tensor where(Tensor const& condition, Tensor const& x, Tensor const& y); /** * Conditionally return elements from a tensor or passed scalar based on a @@ -1132,7 +1105,7 @@ FL_API Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y); * @return the resulting tensor that contains elements of x where condition is * true and the scalar value y where the condition is false. */ -FL_API Tensor where(const Tensor& condition, const Tensor& x, const double& y); +FL_API Tensor where(Tensor const& condition, Tensor const& x, double const& y); /** * Conditionally return elements from a scalar or passed tensor based on a @@ -1148,12 +1121,12 @@ FL_API Tensor where(const Tensor& condition, const Tensor& x, const double& y); * @return the resulting tensor that contains elements of x where condition is * true and the scalar value y where the condition is false. */ -FL_API Tensor where(const Tensor& condition, const double& x, const Tensor& y); +FL_API Tensor where(Tensor const& condition, double const& x, Tensor const& y); /*! * Sorting mode for sorting-related functions. 
*/ -enum class SortMode {Descending = 0, Ascending = 1}; +enum class SortMode { Descending = 0, Ascending = 1 }; /** * Get the top-k values and indices from a Tensor. @@ -1170,10 +1143,10 @@ enum class SortMode {Descending = 0, Ascending = 1}; FL_API void topk( Tensor& values, Tensor& indices, - const Tensor& input, - const unsigned k, - const Dim axis, - const SortMode sortMode = SortMode::Descending + Tensor const& input, + unsigned k, + Dim axis, + SortMode sortMode = SortMode::Descending ); /** @@ -1184,9 +1157,9 @@ FL_API void topk( * @param[in] sortMode the ordering with which to sort. Defaults to ascending */ FL_API Tensor sort( - const Tensor& input, - const Dim axis, - const SortMode sortMode = SortMode::Ascending + Tensor const& input, + Dim axis, + SortMode sortMode = SortMode::Ascending ); /** @@ -1201,9 +1174,9 @@ FL_API Tensor sort( FL_API void sort( Tensor& values, Tensor& indices, - const Tensor& input, - const Dim axis, - const SortMode sortMode = SortMode::Ascending + Tensor const& input, + Dim axis, + SortMode sortMode = SortMode::Ascending ); /** @@ -1214,9 +1187,9 @@ FL_API void sort( * @param[in] sortMode the ordering with which to sort. Defaults to ascending */ FL_API Tensor argsort( - const Tensor& input, - const Dim axis, - const SortMode sortMode = SortMode::Ascending + Tensor const& input, + Dim axis, + SortMode sortMode = SortMode::Ascending ); /************************** Binary Operators ***************************/ @@ -1280,7 +1253,7 @@ FL_BINARY_OP_DECL(>>, rShift); * @param[in] rhs right hand side tensor for the minimum * @return a tensor containing the minimum values in each tensor */ -FL_API Tensor minimum(const Tensor& lhs, const Tensor& rhs); +FL_API Tensor minimum(Tensor const& lhs, Tensor const& rhs); /** * Returns the element-wise minimum of tensor elements with some scalar. 
@@ -1290,7 +1263,7 @@ FL_API Tensor minimum(const Tensor& lhs, const Tensor& rhs); * @return a tensor containing the minimum values element-wise with the tensor * and a scalar. */ -FL_API Tensor minimum(const Tensor& lhs, const double& rhs); +FL_API Tensor minimum(Tensor const& lhs, double const& rhs); /** * Returns the element-wise minimum of tensor elements with some scalar. @@ -1300,7 +1273,7 @@ FL_API Tensor minimum(const Tensor& lhs, const double& rhs); * @return a tensor containing the minimum values element-wise with the tensor * and a scalar. */ -FL_API Tensor minimum(const double& lhs, const Tensor& rhs); +FL_API Tensor minimum(double const& lhs, Tensor const& rhs); /** * Returns the element-wise maximum of tensor elements. @@ -1311,7 +1284,7 @@ FL_API Tensor minimum(const double& lhs, const Tensor& rhs); * @param[in] rhs right hand side tensor for the minimum * @return a tensor containing the maximum values in each tensor */ -FL_API Tensor maximum(const Tensor& lhs, const Tensor& rhs); +FL_API Tensor maximum(Tensor const& lhs, Tensor const& rhs); /** * Returns the element-wise maximum of tensor elements with some scalar. @@ -1321,7 +1294,7 @@ FL_API Tensor maximum(const Tensor& lhs, const Tensor& rhs); * @return a tensor containing the maximum values element-wise with the tensor * and a scalar. */ -FL_API Tensor maximum(const Tensor& lhs, const double& rhs); +FL_API Tensor maximum(Tensor const& lhs, double const& rhs); /** * Returns the element-wise maximum of tensor elements with some scalar. @@ -1331,7 +1304,7 @@ FL_API Tensor maximum(const Tensor& lhs, const double& rhs); * @return a tensor containing the maximum values element-wise with the tensor * and a scalar. 
*/ -FL_API Tensor maximum(const double& lhs, const Tensor& rhs); +FL_API Tensor maximum(double const& lhs, Tensor const& rhs); /** * Returns the element-wise exponentiation of tensors; the left hand tensor is @@ -1341,7 +1314,7 @@ FL_API Tensor maximum(const double& lhs, const Tensor& rhs); * @param[in] rhs the exponent tensor * @return a tensor containing the exponentiated values */ -FL_API Tensor power(const Tensor& lhs, const Tensor& rhs); +FL_API Tensor power(Tensor const& lhs, Tensor const& rhs); /** * Returns the element-wise exponentiation of tensors raised to some scalar @@ -1351,7 +1324,7 @@ FL_API Tensor power(const Tensor& lhs, const Tensor& rhs); * @param[in] rhs a scalar exponent * @return a tensor containing the exponentiated values */ -FL_API Tensor power(const Tensor& lhs, const double& rhs); +FL_API Tensor power(Tensor const& lhs, double const& rhs); /** * Returns the element-wise exponentiation of a scalar raised element-wise to @@ -1361,7 +1334,7 @@ FL_API Tensor power(const Tensor& lhs, const double& rhs); * @param[in] rhs the tensor containing exponent values * @return a tensor containing the exponentiated values */ -FL_API Tensor power(const double& lhs, const Tensor& rhs); +FL_API Tensor power(double const& lhs, Tensor const& rhs); /******************************* BLAS ********************************/ @@ -1369,7 +1342,7 @@ FL_API Tensor power(const double& lhs, const Tensor& rhs); * Transformations to apply to Tensors (i.e. matrices) before applying certain * operations (i.e. matmul). */ -enum class MatrixProperty {None = 0, Transpose = 1}; +enum class MatrixProperty { None = 0, Transpose = 1 }; /** * Perform matrix multiplication between two tensors. @@ -1384,8 +1357,8 @@ enum class MatrixProperty {None = 0, Transpose = 1}; * @return an output tensor containing the matrix product. 
*/ FL_API Tensor matmul( - const Tensor& lhs, - const Tensor& rhs, + Tensor const& lhs, + Tensor const& rhs, MatrixProperty lhsProp = MatrixProperty::None, MatrixProperty rhsProp = MatrixProperty::None ); @@ -1404,9 +1377,9 @@ FL_API Tensor matmul( * @return a tensor containing the max(es) */ FL_API Tensor amin( - const Tensor& input, - const std::vector& axes = {}, - const bool keepDims = false + Tensor const& input, + std::vector const& axes = {}, + bool keepDims = false ); /** @@ -1421,9 +1394,9 @@ FL_API Tensor amin( * @return a tensor containing the max(es) */ FL_API Tensor amax( - const Tensor& input, - const std::vector& axes = {}, - const bool keepDims = false + Tensor const& input, + std::vector const& axes = {}, + bool keepDims = false ); /** @@ -1442,9 +1415,9 @@ FL_API Tensor amax( FL_API void min( Tensor& values, Tensor& indices, - const Tensor& input, - const unsigned axis, - const bool keepDims = false + Tensor const& input, + unsigned axis, + bool keepDims = false ); /** @@ -1463,9 +1436,9 @@ FL_API void min( FL_API void max( Tensor& values, Tensor& indices, - const Tensor& input, - const unsigned axis, - const bool keepDims = false + Tensor const& input, + unsigned axis, + bool keepDims = false ); /** @@ -1477,7 +1450,7 @@ FL_API void max( * as singleton dimensions rather than collapsing them * @return a tensor containing the indices of the max values along each axis */ -FL_API Tensor argmax(const Tensor& input, const unsigned axis, const bool keepDims = false); +FL_API Tensor argmax(Tensor const& input, unsigned axis, bool keepDims = false); /** * Return the indices of the minimum values along an axis. 
@@ -1488,7 +1461,7 @@ FL_API Tensor argmax(const Tensor& input, const unsigned axis, const bool keepDi * as singleton dimensions rather than collapsing them * @return a tensor containing the indices of the max values along each axis */ -FL_API Tensor argmin(const Tensor& input, const unsigned axis, const bool keepDims = false); +FL_API Tensor argmin(Tensor const& input, unsigned axis, bool keepDims = false); /** * Sum of tensor over given axes. If axes is left empty, computes the sum along @@ -1502,9 +1475,9 @@ FL_API Tensor argmin(const Tensor& input, const unsigned axis, const bool keepDi * @return a tensor containing the sum(s) */ FL_API Tensor sum( - const Tensor& input, - const std::vector& axes = {}, - const bool keepDims = false + Tensor const& input, + std::vector const& axes = {}, + bool keepDims = false ); /** @@ -1515,7 +1488,7 @@ FL_API Tensor sum( * @param[in] axis the axis along which to accumulate * @return a tensor of the same shape containing the accumulated sum */ -FL_API Tensor cumsum(const Tensor& input, const unsigned axis); +FL_API Tensor cumsum(Tensor const& input, unsigned axis); /** * Mean of tensor over given axes. 
If axes is left empty, computes the mean @@ -1529,9 +1502,9 @@ FL_API Tensor cumsum(const Tensor& input, const unsigned axis); * @return a tensor containing the mean(s) */ FL_API Tensor mean( - const Tensor& input, - const std::vector& axes = {}, - const bool keepDims = false + Tensor const& input, + std::vector const& axes = {}, + bool keepDims = false ); /** @@ -1546,9 +1519,9 @@ FL_API Tensor mean( * @return a tensor containing the median(s) */ FL_API Tensor median( - const Tensor& input, - const std::vector& axes = {}, - const bool keepDims = false + Tensor const& input, + std::vector const& axes = {}, + bool keepDims = false ); /** @@ -1564,10 +1537,10 @@ FL_API Tensor median( * @return a tensor containing the variance(s) */ FL_API Tensor var( - const Tensor& input, - const std::vector& axes = {}, - const bool bias = false, - const bool keepDims = false + Tensor const& input, + std::vector const& axes = {}, + bool bias = false, + bool keepDims = false ); /** @@ -1582,9 +1555,9 @@ FL_API Tensor var( * @return a tensor containing the standard deviation(s) */ FL_API Tensor std( - const Tensor& input, - const std::vector& axes = {}, - const bool keepDims = false + Tensor const& input, + std::vector const& axes = {}, + bool keepDims = false ); /** @@ -1599,10 +1572,10 @@ FL_API Tensor std( * @return a tensor containing the norm(s) */ FL_API Tensor norm( - const Tensor& input, - const std::vector& axes = {}, + Tensor const& input, + std::vector const& axes = {}, double p = 2, - const bool keepDims = false + bool keepDims = false ); /** @@ -1619,9 +1592,9 @@ FL_API Tensor norm( * over the entire tensor. */ FL_API Tensor countNonzero( - const Tensor& input, - const std::vector& axes = {}, - const bool keepDims = false + Tensor const& input, + std::vector const& axes = {}, + bool keepDims = false ); /** @@ -1638,12 +1611,21 @@ FL_API Tensor countNonzero( * @return a bool tensor containing axis-wise values denoting truthy values * along that axis in the input tensor. 
*/ -FL_API Tensor any( - const Tensor& input, - const std::vector& axes = {}, - const bool keepDims = false +FL_API Tensor any_of( + Tensor const& input, + std::vector const& axes = {}, + bool keepDims = false ); +/** + * @deprecated use @ref fl::any_of(Tensor const&, std::vector const&, bool) + */ +FL_API inline Tensor any( + Tensor const& input, + std::vector const& axes = {}, + bool keepDims = false +) { return any_of(input, axes, keepDims); } + /** * Checks if all values are true in a tensor along one or more axes; returns * true if all are true and false otherwise. If k axes are passed, returns a @@ -1658,25 +1640,34 @@ FL_API Tensor any( * @return a bool tensor containing axis-wise values with true along * axes that contain only true values. */ -FL_API Tensor all( - const Tensor& input, - const std::vector& axes = {}, - const bool keepDims = false +FL_API Tensor all_of( + Tensor const& input, + std::vector const& axes = {}, + bool keepDims = false ); +/** + * @deprecated use @ref fl::all_of(Tensor const&, std::vector const&, bool) + */ +FL_API inline Tensor all( + Tensor const& input, + std::vector const& axes = {}, + bool keepDims = false +) { return all_of(input, axes, keepDims); } + /************************** Utilities ***************************/ /** * Write a string representation of a tensor to an output stream. */ -FL_API std::ostream& operator<<(std::ostream& ostr, const Tensor& t); +FL_API std::ostream& operator<<(std::ostream& ostr, Tensor const& t); /** * Print a string representation of a tensor to standard out. * * @param[in] tensor the tensor to print */ -FL_API void print(const Tensor& tensor); +FL_API void print(Tensor const& tensor); /** * Returns of two tensors are close. 
Checks: @@ -1691,15 +1682,15 @@ FL_API void print(const Tensor& tensor); * tensors */ FL_API bool allClose( - const fl::Tensor& a, - const fl::Tensor& b, - const double absTolerance = 1e-5 + fl::Tensor const& a, + fl::Tensor const& b, + double absTolerance = 1e-5 ); /** * @return if a Tensor contains any NaN or Inf values. */ -FL_API bool isInvalidArray(const Tensor& tensor); +FL_API bool isInvalidArray(Tensor const& tensor); /** * Get a string representation of a tensor backend type. @@ -1707,7 +1698,7 @@ FL_API bool isInvalidArray(const Tensor& tensor); * @param[in] type the tensor backend type. * @return a string representing the given tensor backend type. */ -FL_API std::string tensorBackendTypeToString(const TensorBackendType type); +FL_API std::string tensorBackendTypeToString(TensorBackendType type); /** * Write a string representation of a tensor backend type to an output stream. @@ -1716,8 +1707,11 @@ FL_API std::string tensorBackendTypeToString(const TensorBackendType type); * @param[in] type the tensor backend type. * @return the output stream. */ -FL_API std::ostream& operator<<(std::ostream& os, const TensorBackendType type); +FL_API std::ostream& operator<<(std::ostream& os, TensorBackendType type); +} + +namespace fl { /** * Convert a tensor from one type to another. Requires moving the input Tensor * - destroys the resulting tensor and creates another tensor of the desired @@ -1754,13 +1748,13 @@ Tensor to(Tensor&& t) { namespace detail { - bool areTensorTypesEqual(const Tensor& a, const Tensor& b); + bool areTensorTypesEqual(Tensor const& a, Tensor const& b); template bool areTensorTypesEqual( - const Tensor& a, - const Tensor& b, - const Args&... args + Tensor const& a, + Tensor const& b, + Args const&... 
args ) { return areTensorTypesEqual(a, b) && areTensorTypesEqual(a, args...); } } // namespace detail @@ -1776,3 +1770,67 @@ namespace detail { } } // namespace fl + + +namespace fl { + +template +T Tensor::scalar() const { + if(isEmpty()) + throw std::logic_error{"Tensor::scalar called on empty tensor"}; + + if(type() != dtype_traits::fl_type) + throw std::logic_error{ + std::format( + "Tensor::scalar: requested type of {} does not match tensor type {}", + dtype_traits::name(), + to_string(type()) + ) + }; + T out; + scalar_impl(&out); + return out; +} + +template +[[nodiscard]] T Tensor::asScalar() const { + if(type() == dtype::f16) + return fl::dispatch_dtype(dtype::f32, [t = asType(dtype::f32)] { return t.scalar(); }); + + + return fl::dispatch_dtype(type(), [&t = *this] { return t.scalar(); }); +} + + +template +T* Tensor::device() const { return static_cast(device()); } + +template +std::vector Tensor::host() const { + if(bytes() % sizeof(T) != 0) + throw std::logic_error{"Tensor data can't be mapped to an array of T"}; + + std::vector data(bytes() / sizeof(T)); + this->raw_host(data.data()); + + return data; +} + +template +void Tensor::host(std::span dst) const { + auto const size = bytes(); + auto const dstSize = dst.size_bytes(); + + if(size % sizeof(T) != 0) + throw std::logic_error{ + std::format("Tensor data ({} bytes) can't be mapped to an array of T (size {})", sizeof(T)) + }; + + if(size > dstSize) + throw std::logic_error{ + std::format("Tensor data ({} bytes) doesn't fit span of T with {} bytes", size, dstSize) + }; + + raw_host(dst.data()); +} +} diff --git a/flashlight/fl/tensor/Traits.h b/flashlight/fl/tensor/Traits.h new file mode 100644 index 0000000..0ad1ffa --- /dev/null +++ b/flashlight/fl/tensor/Traits.h @@ -0,0 +1,154 @@ +/* + * SPDX-License-Identifier: MIT + * + * Original code: Copyright (c) Meta Platforms, Inc. 
(see FLASHLIGHT_LICENSE) + * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE) + */ +#pragma once +#include "DTypes.h" +#include "Meta.h" + +namespace fl { + +namespace detail { + template + struct FL_API dtype_traits_base { + static constexpr dtype fl_type = [] { + auto group = dtype_group_from_type(); + + for(size_t i = begin_of(group); i < end_of(group); ++i) + if(auto type = static_cast(i); size_of(type) == sizeof(T)) + return type; + + throw std::logic_error{"unknown type size requested"}; + }(); + + using base_type = T; + }; +} + +template +struct dtype_traits; + + +#define FL_TYPE_TRAIT(T) \ + template<> \ + struct FL_API dtype_traits : detail::dtype_traits_base { \ + static constexpr std::string_view name() { \ + return #T; \ + } \ + /* deprecated, use @ref name() instead */ \ + static constexpr const char* getName() { \ + return #T; \ + } \ + }; + +// using fundamental types instead of fixed to avoid missing templates when multiple fundamentals are equal size + +FL_TYPE_TRAIT(float); +FL_TYPE_TRAIT(double); +FL_TYPE_TRAIT(int); +FL_TYPE_TRAIT(unsigned int); +FL_TYPE_TRAIT(char); +FL_TYPE_TRAIT(unsigned char); +FL_TYPE_TRAIT(long); +FL_TYPE_TRAIT(unsigned long); +FL_TYPE_TRAIT(long long); +FL_TYPE_TRAIT(unsigned long long); +FL_TYPE_TRAIT(bool); +FL_TYPE_TRAIT(short); +FL_TYPE_TRAIT(unsigned short); + +namespace detail { + //TODO add c++23 float16_t once version is bumped + using fundamental_types = std::tuple< + float, + double, + bool, + char, + unsigned char, + short, + unsigned short, + int, + unsigned int, + long, + unsigned long, + long long, + unsigned long long>; +} + + +} // namespace fl + +namespace fl { +/** + * @brief Checks if T is any of @ref fl::fundamental_types + * @tparam T type to check + */ +template +concept fundamental_type = std::apply( + [](Ts...) { return dev::is_any_of; }, + detail::fundamental_types{} +); + + +/** + * @brief Accepts if the type would resolve to any @ref fl::fundamental_types in overload resolution. 
+ * + * Let `f(X x)` be instantiated functions for `X` in @ref fl::fundamental_types . + * The concept is satisfied if and only if `f(std::declval())` is well-formed and unambiguous. + * + * @tparam T type to check resolution for + */ +template +concept fundamental_type_compatible = fundamental_type || std::apply( + [](Ts...) { return resolves_to_any_of; }, + detail::fundamental_types{} +); + + + +} + + +namespace fl { +//TODO not really happy with this, return type deduction should be possible + +/** + * @brief Runtime matches dtype with a type from the list and calls the templated function + * @tparam R function return type + * @tparam TypeList type list to apply + * @tparam Func templated function type + * @param type dtype to runtime dispatch + * @param func templated function to call + * @return result of func() where T corresponds to the dtype passed + */ +template +R dispatch_dtype(fl::dtype type, Func&& func) { + std::conditional_t, int, R> result{}; + bool found = false; + + auto try_dispatch = [&found, &result, type, &func]() { + if(!found && fl::dtype_traits::fl_type == type) { + if constexpr(std::is_void_v) + func.template operator()(); + else + result = func.template operator()(); + found = true; + } + }; + + [&](std::tuple) { (try_dispatch.template operator()(), ...); }( + TypeList{} + ); + + if(!found) + throw std::invalid_argument("Unsupported dtype for dispatch"); + + // C++17 feature: only return if R isn't void + if constexpr(!std::is_void_v) + return result; + else + return; +} +} diff --git a/flashlight/fl/tensor/Types.cpp b/flashlight/fl/tensor/Types.cpp deleted file mode 100644 index 8625593..0000000 --- a/flashlight/fl/tensor/Types.cpp +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include "flashlight/fl/tensor/Types.h" - -#include -#include - -namespace fl { - -const std::unordered_map kTypeToString = { - {dtype::f16, "f16"}, - {dtype::f32, "f32"}, - {dtype::f64, "f64"}, - {dtype::b8, "b8"}, - {dtype::s16, "s16"}, - {dtype::s32, "s32"}, - {dtype::s64, "s64"}, - {dtype::u8, "u8"}, - {dtype::u16, "u16"}, - {dtype::u32, "u32"}, - {dtype::u64, "u64"}, -}; - -const std::unordered_map kStringToType = { - {"f16", dtype::f16}, - {"f32", dtype::f32}, - {"f64", dtype::f64}, - {"b8", dtype::b8}, - {"s16", dtype::s16}, - {"s32", dtype::s32}, - {"s64", dtype::s64}, - {"u8", dtype::u8}, - {"u16", dtype::u16}, - {"u32", dtype::u32}, - {"u64", dtype::u64}, -}; - -size_t getTypeSize(dtype type) { - switch(type) { - case dtype::f16: - return sizeof(float) / 2; - case dtype::f32: - return sizeof(float); - case dtype::f64: - return sizeof(double); - case dtype::b8: - return sizeof(unsigned char); - case dtype::s16: - return sizeof(short); - case dtype::s64: - return sizeof(long long); - case dtype::s32: - return sizeof(int); - case dtype::u8: - return sizeof(unsigned char); - case dtype::u16: - return sizeof(unsigned short); - case dtype::u32: - return sizeof(unsigned); - case dtype::u64: - return sizeof(unsigned long long); - default: - throw std::invalid_argument("getTypeSize - invalid type queried."); - } -} - -const std::string& dtypeToString(dtype type) { return kTypeToString.at(type); } - -fl::dtype stringToDtype(const std::string& string) { - if(kStringToType.find(string) != kStringToType.end()) - return kStringToType.at(string); - throw std::invalid_argument("stringToDtype: Invalid input type: " + string); -} - -std::ostream& operator<<(std::ostream& ostr, const dtype& s) { - ostr << dtypeToString(s); - return ostr; -} - -} // namespace fl diff --git a/flashlight/fl/tensor/Types.h b/flashlight/fl/tensor/Types.h index 16f201a..41e1f31 100644 --- a/flashlight/fl/tensor/Types.h +++ b/flashlight/fl/tensor/Types.h @@ -1,86 +1,10 @@ /* - * Copyright 
(c) Facebook, Inc. 6and its affiliates. + * SPDX-License-Identifier: MIT * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. + * Original code: Copyright (c) Meta Platforms, Inc. (see FLASHLIGHT_LICENSE) + * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE) */ - #pragma once -#include -#include - -#include "flashlight/fl/common/Defines.h" - -namespace fl { - -enum class dtype { - f16 = 0, // 16-bit float - f32 = 1, // 32-bit float - f64 = 2, // 64-bit float - b8 = 3, // 8-bit boolean - s16 = 4, // 16-bit signed integer - s32 = 5, // 32-bit signed integer - s64 = 6, // 64-bit signed integer - u8 = 7, // 8-bit unsigned integer - u16 = 8, // 16-bit unsigned integer - u32 = 9, // 32-bit unsigned integer - u64 = 10 // 64-bit unsigned integer - // TODO: add support for complex-valued tensors? (AF) -}; - -/** - * Returns the size of the type in bytes. - * - * @param[in] type the input type to query. - */ -FL_API size_t getTypeSize(dtype type); - -/** - * Convert a dtype to its string representation. - */ -FL_API const std::string& dtypeToString(dtype type); - -/** - * Converts string to a Flashlight dtype - * - * @param[in] string type name as a string. - * - * @return returns the corresponding Flashlight dtype - */ -FL_API fl::dtype stringToDtype(const std::string& string); - -/** - * Write a type's string representation to an output stream. 
- */ -FL_API std::ostream& operator<<(std::ostream& ostr, const dtype& s); - -template -struct dtype_traits; - -#define FL_TYPE_TRAIT(BASE_TYPE, DTYPE, CONSTANT_TYPE, STRING_NAME) \ - template<> \ - struct FL_API dtype_traits { \ - static const dtype fl_type = DTYPE; /* corresponding dtype */ \ - static const dtype ctype = CONSTANT_TYPE; /* constant init type */ \ - typedef BASE_TYPE base_type; \ - static const char* getName() { \ - return STRING_NAME; \ - } \ - } - -FL_TYPE_TRAIT(float, dtype::f32, dtype::f32, "float"); -FL_TYPE_TRAIT(double, dtype::f64, dtype::f32, "double"); -FL_TYPE_TRAIT(int, dtype::s32, dtype::s32, "int"); -FL_TYPE_TRAIT(unsigned, dtype::u32, dtype::u32, "unsigned int"); -FL_TYPE_TRAIT(char, dtype::b8, dtype::s32, "char"); -FL_TYPE_TRAIT(unsigned char, dtype::u8, dtype::u32, "unsigned char"); -FL_TYPE_TRAIT(long, dtype::s64, dtype::s32, "long int"); -FL_TYPE_TRAIT(unsigned long, dtype::u64, dtype::u32, "unsigned long"); -FL_TYPE_TRAIT(long long, dtype::s64, dtype::s64, "long long"); -FL_TYPE_TRAIT(unsigned long long, dtype::u64, dtype::u64, "unsigned long long"); -FL_TYPE_TRAIT(bool, dtype::u8, dtype::u8, "bool"); -FL_TYPE_TRAIT(short, dtype::s16, dtype::s16, "short"); -FL_TYPE_TRAIT(unsigned short, dtype::u16, dtype::u16, "short"); - -} // namespace fl +#include "DTypes.h" +#include "Traits.h" diff --git a/flashlight/fl/tensor/backend/af/AdvancedIndex.cpp b/flashlight/fl/tensor/backend/af/AdvancedIndex.cpp index c3c5cea..76895e3 100644 --- a/flashlight/fl/tensor/backend/af/AdvancedIndex.cpp +++ b/flashlight/fl/tensor/backend/af/AdvancedIndex.cpp @@ -14,11 +14,11 @@ namespace fl { namespace detail { void advancedIndex( - const af::array& inp, - const af::dim4& idxStart, - const af::dim4& idxEnd, - const af::dim4& outDims, - const std::vector& idxArr, + af::array const& inp, + af::dim4 const& idxStart, + af::dim4 const& idxEnd, + af::dim4 const& outDims, + std::vector const& idxArr, af::array& out ) { throw 
std::runtime_error("gradAdvancedIndex not implemented for cpu"); } diff --git a/flashlight/fl/tensor/backend/af/AdvancedIndex.cu b/flashlight/fl/tensor/backend/af/AdvancedIndex.cu index 171b528..1c4ca71 100644 --- a/flashlight/fl/tensor/backend/af/AdvancedIndex.cu +++ b/flashlight/fl/tensor/backend/af/AdvancedIndex.cu @@ -19,7 +19,7 @@ #define GRID_SIZE 32 #define BLOCK_SIZE 256 -const std::unordered_set validIndexTypes { +std::unordered_set const validIndexTypes { af::dtype::s32, af::dtype::s64, af::dtype::u32, @@ -29,17 +29,17 @@ const std::unordered_set validIndexTypes { template __global__ void advancedIndexKernel( - const Float* inp, - const dim_t* idxStart, - const dim_t* idxEnd, - const dim_t* outDims, - const dim_t* idxArr, + Float const* inp, + ::dim_t const* idxStart, + ::dim_t const* idxEnd, + ::dim_t const* outDims, + ::dim_t const* idxArr, Float* out ) { // Compute striding information for // the input and output tensors - dim_t dims[4], strides[4]; - dim_t outStrides[4]; + ::dim_t dims[4], strides[4]; + ::dim_t outStrides[4]; for(int i = 0; i < 4; i++) dims[i] = idxEnd[i] - idxStart[i]; strides[0] = 1; @@ -53,20 +53,20 @@ __global__ void advancedIndexKernel( // Map CUDA thread to an element in the input array for( - dim_t tid = threadIdx.x + blockIdx.x * BLOCK_SIZE; + ::dim_t tid = threadIdx.x + blockIdx.x * BLOCK_SIZE; tid < (strides[3] * dims[3]); tid += (GRID_SIZE * BLOCK_SIZE) ) { // Compute input array index for CUDA thread - dim_t index[4]; - dim_t cursor = tid; + ::dim_t index[4]; + ::dim_t cursor = tid; for(int i = 3; i >= 0; i--) { index[i] = cursor / strides[i]; cursor = cursor % strides[i]; } - dim_t inpIdx = tid; - dim_t outIdx = 0; + ::dim_t inpIdx = tid; + ::dim_t outIdx = 0; for(int i = 0; i < 4; i++) { // If indexing array specified, use it if(idxArr[i]) { @@ -82,25 +82,27 @@ __global__ void advancedIndexKernel( } namespace fl { +using af_dim_t = ::dim_t; + namespace detail { void advancedIndex( - const af::array& inp, - const 
af::dim4& idxStart, - const af::dim4& idxEnd, - const af::dim4& outDims, - const std::vector& idxArr, + af::array const& inp, + af::dim4 const& idxStart, + af::dim4 const& idxEnd, + af::dim4 const& outDims, + std::vector const& idxArr, af::array& out ) { auto inpType = inp.type(); auto outType = out.type(); if((inpType != af::dtype::f32) && (inpType != af::dtype::f16)) - throw std::invalid_argument("Input type must be f16/f32"); + throw std::invalid_argument{"Input type must be f16/f32"}; if((outType != af::dtype::f32) && (outType != af::dtype::f16)) - throw std::invalid_argument("Output type must be f16/f32"); + throw std::invalid_argument{"Output type must be f16/f32"}; if(idxArr.size() != 4) - throw std::invalid_argument("Index array vector must be length 4"); + throw std::invalid_argument{"Index array vector must be length 4"}; af::dim4 idxPtr; // Extract raw device pointers for dimensions @@ -114,18 +116,18 @@ namespace detail { continue; } if(validIndexTypes.find(idxArr[i].type()) == validIndexTypes.end()) - throw std::invalid_argument( + throw std::invalid_argument{ "Index type must be one of s32/s64/u32/u64, observed type is " + std::to_string(idxArr[i].type()) - ); + }; idxTypes.push_back(idxArr[i].type()); - idxPtr[i] = (dim_t) (idxArr[i].device()); + idxPtr[i] = reinterpret_cast(idxArr[i].device()); } for(int i = 0; i + 1 < idxTypes.size(); i++) if(idxTypes[i] != idxTypes[i + 1]) - throw std::invalid_argument( + throw std::invalid_argument{ "Index type must be the same across all dimensions" - ); + }; af::array inpCast = inp; af::array outCast = out; @@ -136,10 +138,10 @@ namespace detail { void* inpRawPtr = inpCast.device(); void* outRawPtr = outCast.device(); - af::array arrIdxPtr(4, idxPtr.get()); - af::array arrIdxEnd(4, idxEnd.get()); - af::array arrIdxStart(4, idxStart.get()); - af::array arrOutDims(4, outDims.get()); + af::array arrIdxPtr{4, idxPtr.get()}; + af::array arrIdxEnd{4, idxEnd.get()}; + af::array arrIdxStart{4, idxStart.get()}; + 
af::array arrOutDims{4, outDims.get()}; void* arrIdxStartDev = arrIdxStart.device(); void* arrIdxEndDev = arrIdxEnd.device(); void* arrOutDimsDev = arrOutDims.device(); @@ -148,35 +150,35 @@ namespace detail { cudaStream_t stream = afcu::getStream(af::getDevice()); if(idxTypes.size() == 0 || idxTypes[0] == af::dtype::s32) advancedIndexKernel << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( - static_cast(inpRawPtr), - static_cast(arrIdxStartDev), - static_cast(arrIdxEndDev), - static_cast(arrOutDimsDev), - static_cast(arrIdxPtrDev), + static_cast(inpRawPtr), + static_cast(arrIdxStartDev), + static_cast(arrIdxEndDev), + static_cast(arrOutDimsDev), + static_cast(arrIdxPtrDev), static_cast(outRawPtr)); else if(idxTypes[0] == af::dtype::s64) advancedIndexKernel << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( - static_cast(inpRawPtr), - static_cast(arrIdxStartDev), - static_cast(arrIdxEndDev), - static_cast(arrOutDimsDev), - static_cast(arrIdxPtrDev), + static_cast(inpRawPtr), + static_cast(arrIdxStartDev), + static_cast(arrIdxEndDev), + static_cast(arrOutDimsDev), + static_cast(arrIdxPtrDev), static_cast(outRawPtr)); else if(idxTypes[0] == af::dtype::u32) advancedIndexKernel << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( - static_cast(inpRawPtr), - static_cast(arrIdxStartDev), - static_cast(arrIdxEndDev), - static_cast(arrOutDimsDev), - static_cast(arrIdxPtrDev), + static_cast(inpRawPtr), + static_cast(arrIdxStartDev), + static_cast(arrIdxEndDev), + static_cast(arrOutDimsDev), + static_cast(arrIdxPtrDev), static_cast(outRawPtr)); else if(idxTypes[0] == af::dtype::u64) advancedIndexKernel << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( - static_cast(inpRawPtr), - static_cast(arrIdxStartDev), - static_cast(arrIdxEndDev), - static_cast(arrOutDimsDev), - static_cast(arrIdxPtrDev), + static_cast(inpRawPtr), + static_cast(arrIdxStartDev), + static_cast(arrIdxEndDev), + static_cast(arrOutDimsDev), + static_cast(arrIdxPtrDev), static_cast(outRawPtr)); else throw 
std::invalid_argument("Index type must be one of s32/s64/u32/u64"); @@ -191,7 +193,7 @@ namespace detail { arrIdxEnd.unlock(); arrOutDims.unlock(); arrIdxPtr.unlock(); - for(const auto& arr : idxArr) + for(auto const& arr : idxArr) arr.unlock(); out = outCast; diff --git a/flashlight/fl/tensor/backend/af/AdvancedIndex.h b/flashlight/fl/tensor/backend/af/AdvancedIndex.h index 5155df1..0fb9138 100644 --- a/flashlight/fl/tensor/backend/af/AdvancedIndex.h +++ b/flashlight/fl/tensor/backend/af/AdvancedIndex.h @@ -31,11 +31,11 @@ namespace detail { * operator */ void advancedIndex( - const af::array& inp, - const af::dim4& idxStart, - const af::dim4& idxEnd, - const af::dim4& outDims, - const std::vector& idxArr, + af::array const& inp, + af::dim4 const& idxStart, + af::dim4 const& idxEnd, + af::dim4 const& outDims, + std::vector const& idxArr, af::array& out ); diff --git a/flashlight/fl/tensor/backend/af/ArrayFireBLAS.cpp b/flashlight/fl/tensor/backend/af/ArrayFireBLAS.cpp index 610004b..4404850 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireBLAS.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireBLAS.cpp @@ -15,12 +15,12 @@ namespace fl { Tensor ArrayFireBackend::matmul( - const Tensor& lhs, - const Tensor& rhs, + Tensor const& lhs, + Tensor const& rhs, MatrixProperty lhsProp, MatrixProperty rhsProp ) { - unsigned numDims = std::max(lhs.ndim(), rhs.ndim()); + auto numDims = std::max(lhs.ndim(), rhs.ndim()); if((lhs.ndim() == 1 || rhs.ndim() == 1) && numDims > 1) numDims -= 1; diff --git a/flashlight/fl/tensor/backend/af/ArrayFireBackend.cpp b/flashlight/fl/tensor/backend/af/ArrayFireBackend.cpp index 968bf92..9f3bcc6 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireBackend.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireBackend.cpp @@ -36,12 +36,12 @@ namespace fl { namespace { -// Get the stream associated with given device in the given map; if it's not in -// the map, initialize it (by wrapping or creating) and put it into the map. 
- const Stream& getOrWrapAfDeviceStream( - const int afId, - const int nativeId, - std::unordered_map>& afIdToStream + // Get the stream associated with given device in the given map; if it's not in + // the map, initialize it (by wrapping or creating) and put it into the map. + Stream const& getOrWrapAfDeviceStream( + int const afId, + int const nativeId, + std::unordered_map>& afIdToStream ) { auto iter = afIdToStream.find(afId); if(iter != afIdToStream.end()) @@ -51,7 +51,7 @@ namespace { auto resIter = afIdToStream.emplace(afId, ArrayFireCPUStream::create()); return *resIter.first->second; #elif FL_ARRAYFIRE_USE_CUDA - const cudaStream_t cudaNativeStream = afcu::getStream(afId); + cudaStream_t const cudaNativeStream = afcu::getStream(afId); auto resIter = afIdToStream.emplace( afId, CUDAStream::wrapUnmanaged(nativeId, cudaNativeStream) @@ -92,22 +92,22 @@ ArrayFireBackend::ArrayFireBackend() { idToNativeId_[id] = nativeId; } - const auto& manager = DeviceManager::getInstance(); + auto const& manager = DeviceManager::getInstance(); // This callback ensures consistency of AF internal state on active device. // Capturing by value to avoid destructor race hazard for static objects. - const auto setActiveCallback = [nativeIdToId = nativeIdToId_, + auto const setActiveCallback = [nativeIdToId = nativeIdToId_, afIdToStream = afIdToStream_](int nativeId) { - auto afId = nativeIdToId.at(nativeId); - af::setDevice(afId); - // this is the latest point we can lazily wrap the AF stream, which may get - // lazily intialized anytime in AF internally, e.g., via tensor computation. - getOrWrapAfDeviceStream(afId, nativeId, *afIdToStream); - }; + auto afId = nativeIdToId.at(nativeId); + af::setDevice(afId); + // this is the latest point we can lazily wrap the AF stream, which may get + // lazily intialized anytime in AF internally, e.g., via tensor computation. 
+ getOrWrapAfDeviceStream(afId, nativeId, *afIdToStream); + }; #if FL_ARRAYFIRE_USE_CPU auto& device = manager.getActiveDevice(DeviceType::x64); device.addSetActiveCallback(setActiveCallback); #elif FL_ARRAYFIRE_USE_CUDA - const auto deviceCount = manager.getDeviceCount(DeviceType::CUDA); + auto const deviceCount = manager.getDeviceCount(DeviceType::CUDA); for(unsigned nativeId = 0; nativeId < deviceCount; nativeId++) { auto& device = manager.getDevice(DeviceType::CUDA, nativeId); device.addSetActiveCallback(setActiveCallback); @@ -127,16 +127,14 @@ ArrayFireBackend& ArrayFireBackend::getInstance() { return instance; } -TensorBackendType ArrayFireBackend::backendType() const { - return TensorBackendType::ArrayFire; -} +TensorBackendType ArrayFireBackend::backendType() const { return TensorBackendType::ArrayFire; } /* -------------------------- Compute Functions -------------------------- */ -void ArrayFireBackend::eval(const Tensor& tensor) { af::eval(toArray(tensor)); } +void ArrayFireBackend::eval(Tensor const& tensor) { af::eval(toArray(tensor)); } -const Stream& ArrayFireBackend::getStreamOfArray( - const af::array& arr +Stream const& ArrayFireBackend::getStreamOfArray( + af::array const& arr ) { // TODO once we enforce integrate Device::setDevice into fl::setDevice, each // array's stream should always be wrapped already (via setDevice callback). 
@@ -148,21 +146,19 @@ const Stream& ArrayFireBackend::getStreamOfArray( return getOrWrapAfDeviceStream(afId, nativeId, *afIdToStream_); } -bool ArrayFireBackend::supportsDataType(const fl::dtype& dtype) const { +bool ArrayFireBackend::supportsDataType(fl::dtype const& dtype) const { switch(dtype) { - case fl::dtype::f16: - return af::isHalfAvailable(af::getDevice()) - && // f16 isn't [yet] supported with the CPU backend per onednn - // limitations - !FL_BACKEND_CPU; - default: - return true; + case fl::dtype::f16: return af::isHalfAvailable(af::getDevice()) + && // f16 isn't [yet] supported with the CPU backend per onednn + // limitations + !FL_BACKEND_CPU; + default: return true; } } void ArrayFireBackend::getMemMgrInfo( - const char* msg, - const int nativeDeviceId, + char const* msg, + int const nativeDeviceId, std::ostream* ostream ) { int deviceId = nativeIdToId_.at(nativeDeviceId); @@ -187,14 +183,14 @@ void ArrayFireBackend::setMemMgrLogStream(std::ostream* stream) { curMemMgr->setLogStream(stream); } -void ArrayFireBackend::setMemMgrLoggingEnabled(const bool enabled) { +void ArrayFireBackend::setMemMgrLoggingEnabled(bool const enabled) { auto* curMemMgr = fl::MemoryManagerInstaller::currentlyInstalledMemoryManager(); if(curMemMgr) curMemMgr->setLoggingEnabled(enabled); } -void ArrayFireBackend::setMemMgrFlushInterval(const size_t interval) { +void ArrayFireBackend::setMemMgrFlushInterval(size_t const interval) { auto* curMemMgr = fl::MemoryManagerInstaller::currentlyInstalledMemoryManager(); if(curMemMgr) @@ -203,16 +199,16 @@ void ArrayFireBackend::setMemMgrFlushInterval(const size_t interval) { /* -------------------------- Rand Functions -------------------------- */ -void ArrayFireBackend::setSeed(const int seed) { af::setSeed(seed); } +void ArrayFireBackend::setSeed(int const seed) { af::setSeed(seed); } -Tensor ArrayFireBackend::randn(const Shape& shape, dtype type) { +Tensor ArrayFireBackend::randn(Shape const& shape, dtype type) { return toTensor( 
af::randn(detail::flToAfDims(shape), detail::flToAfType(type)), shape.ndim() ); } -Tensor ArrayFireBackend::rand(const Shape& shape, dtype type) { +Tensor ArrayFireBackend::rand(Shape const& shape, dtype type) { return toTensor( af::randu(detail::flToAfDims(shape), detail::flToAfType(type)), shape.ndim() @@ -257,28 +253,31 @@ AF_BACKEND_CREATE_FUN_LITERAL_DEF(const bool&); AF_BACKEND_CREATE_FUN_LITERAL_DEF(const short&); AF_BACKEND_CREATE_FUN_LITERAL_DEF(const unsigned short&); -Tensor ArrayFireBackend::identity(const Dim dim, const dtype type) { +Tensor ArrayFireBackend::identity(Dim const dim, dtype const type) { return toTensor( - af::identity({dim, dim}, detail::flToAfType(type)), /* numDims = */ + af::identity({dim, dim}, detail::flToAfType(type)), + /* numDims = */ 2 ); } Tensor ArrayFireBackend::arange( - const Shape& shape, - const Dim seqDim, - const dtype type + Shape const& shape, + Dim const seqDim, + dtype const type ) { + auto const afType = detail::flToAfType(type); + return toTensor( - af::range(detail::flToAfDims(shape), seqDim, detail::flToAfType(type)), + af::range(detail::flToAfDims(shape), seqDim, afType), shape.ndim() ); } Tensor ArrayFireBackend::iota( - const Shape& dims, - const Shape& tileDims, - const dtype type + Shape const& dims, + Shape const& tileDims, + dtype const type ) { return toTensor( af::iota( @@ -286,14 +285,15 @@ Tensor ArrayFireBackend::iota( detail::flToAfDims(tileDims), detail::flToAfType(type) ), - /* numDims = */ std::max(dims.ndim(), tileDims.ndim()) + /* numDims = */ + std::max(dims.ndim(), tileDims.ndim()) ); } Tensor ArrayFireBackend::where( - const Tensor& condition, - const Tensor& x, - const Tensor& y + Tensor const& condition, + Tensor const& x, + Tensor const& y ) { Tensor orig = x; af::replace(toArray(orig), toArray(condition), toArray(y)); @@ -303,10 +303,10 @@ Tensor ArrayFireBackend::where( void ArrayFireBackend::topk( Tensor& values, Tensor& indices, - const Tensor& input, - const unsigned k, - const Dim 
axis, - const SortMode sortMode + Tensor const& input, + unsigned const k, + Dim const axis, + SortMode const sortMode ) { if(axis != 0) throw std::invalid_argument( @@ -327,9 +327,9 @@ void ArrayFireBackend::topk( } Tensor ArrayFireBackend::sort( - const Tensor& input, - const Dim axis, - const SortMode sortMode + Tensor const& input, + Dim const axis, + SortMode const sortMode ) { if(sortMode != SortMode::Descending && sortMode != SortMode::Ascending) throw std::invalid_argument( @@ -351,9 +351,9 @@ Tensor ArrayFireBackend::sort( void ArrayFireBackend::sort( Tensor& values, Tensor& indices, - const Tensor& input, - const Dim axis, - const SortMode sortMode + Tensor const& input, + Dim const axis, + SortMode const sortMode ) { if(sortMode != SortMode::Descending && sortMode != SortMode::Ascending) throw std::invalid_argument( @@ -374,9 +374,9 @@ void ArrayFireBackend::sort( } Tensor ArrayFireBackend::argsort( - const Tensor& input, - const Dim axis, - const SortMode sortMode + Tensor const& input, + Dim const axis, + SortMode const sortMode ) { if(sortMode != SortMode::Descending && sortMode != SortMode::Ascending) throw std::invalid_argument( @@ -395,5 +395,5 @@ Tensor ArrayFireBackend::argsort( return toTensor(std::move(indices), input.ndim()); } -void ArrayFireBackend::print(const Tensor& tensor) { af::print("ArrayFireTensor", toArray(tensor)); } +void ArrayFireBackend::print(Tensor const& tensor) { af::print("ArrayFireTensor", toArray(tensor)); } } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/ArrayFireBackend.h b/flashlight/fl/tensor/backend/af/ArrayFireBackend.h index 698fe33..034883a 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireBackend.h +++ b/flashlight/fl/tensor/backend/af/ArrayFireBackend.h @@ -36,10 +36,10 @@ class ArrayFireBackend : public TensorBackend { // keep track of the individual active stream on each ArrayFire device // NOTE using a `shared_ptr` to allow its capture in setActive callback; // see constructor for details. 
- std::shared_ptr>> + std::shared_ptr>> afIdToStream_{ std::make_shared< - std::unordered_map>>() + std::unordered_map>>() }; // Intentionally private. Only one instance should exist/it should be accessed @@ -53,31 +53,31 @@ class ArrayFireBackend : public TensorBackend { // No copy or move construction or assignment ArrayFireBackend(ArrayFireBackend&&) = delete; - ArrayFireBackend(const ArrayFireBackend&) = delete; + ArrayFireBackend(ArrayFireBackend const&) = delete; ArrayFireBackend& operator=(ArrayFireBackend&&) = delete; - ArrayFireBackend& operator=(const ArrayFireBackend&) = delete; + ArrayFireBackend& operator=(ArrayFireBackend const&) = delete; /* -------------------------- Compute Functions -------------------------- */ - void eval(const Tensor& tensor) override; + void eval(Tensor const& tensor) override; /** * Return the stream from which the given array was created. * * @return an immutable reference to the stream from which `arr` was created. */ - const Stream& getStreamOfArray(const af::array& arr); - bool supportsDataType(const fl::dtype& dtype) const override; + Stream const& getStreamOfArray(af::array const& arr); + bool supportsDataType(fl::dtype const& dtype) const override; // Memory management - void getMemMgrInfo(const char* msg, const int nativeDeviceId, std::ostream* ostream) + void getMemMgrInfo(char const* msg, int nativeDeviceId, std::ostream* ostream) override; void setMemMgrLogStream(std::ostream* stream) override; - void setMemMgrLoggingEnabled(const bool enabled) override; - void setMemMgrFlushInterval(const size_t interval) override; + void setMemMgrLoggingEnabled(bool enabled) override; + void setMemMgrFlushInterval(size_t interval) override; /* -------------------------- Rand Functions -------------------------- */ - void setSeed(const int seed) override; - Tensor randn(const Shape& shape, dtype type) override; - Tensor rand(const Shape& shape, dtype type) override; + void setSeed(int seed) override; + Tensor randn(Shape const& 
shape, dtype type) override; + Tensor rand(Shape const& shape, dtype type) override; /* --------------------------- Tensor Operators --------------------------- */ /******************** Tensor Creation Functions ********************/ @@ -99,71 +99,71 @@ class ArrayFireBackend : public TensorBackend { AF_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned short&); #undef AF_BACKEND_CREATE_FUN_LITERAL_DECL - Tensor identity(const Dim dim, const dtype type) override; - Tensor arange(const Shape& shape, const Dim seqDim, const dtype type) + Tensor identity(Dim dim, dtype type) override; + Tensor arange(Shape const& shape, Dim seqDim, dtype type) override; - Tensor iota(const Shape& dims, const Shape& tileDims, const dtype type) + Tensor iota(Shape const& dims, Shape const& tileDims, dtype type) override; /************************ Shaping and Indexing *************************/ - Tensor reshape(const Tensor& tensor, const Shape& shape) override; - Tensor transpose(const Tensor& tensor, const Shape& axes /* = {} */) override; - Tensor tile(const Tensor& tensor, const Shape& shape) override; - Tensor concatenate(const std::vector& tensors, const unsigned axis) + Tensor reshape(Tensor const& tensor, Shape const& shape) override; + Tensor transpose(Tensor const& tensor, Shape const& axes /* = {} */) override; + Tensor tile(Tensor const& tensor, Shape const& shape) override; + Tensor concatenate(std::vector const& tensors, unsigned axis) override; - Tensor nonzero(const Tensor& tensor) override; + Tensor nonzero(Tensor const& tensor) override; Tensor pad( - const Tensor& input, - const std::vector>& padWidths, - const PadType type + Tensor const& input, + std::vector> const& padWidths, + PadType type ) override; /************************** Unary Operators ***************************/ - Tensor exp(const Tensor& tensor) override; - Tensor log(const Tensor& tensor) override; - Tensor negative(const Tensor& tensor) override; - Tensor logicalNot(const Tensor& tensor) override; - 
Tensor log1p(const Tensor& tensor) override; - Tensor sin(const Tensor& tensor) override; - Tensor cos(const Tensor& tensor) override; - Tensor sqrt(const Tensor& tensor) override; - Tensor tanh(const Tensor& tensor) override; - Tensor floor(const Tensor& tensor) override; - Tensor ceil(const Tensor& tensor) override; - Tensor rint(const Tensor& tensor) override; - Tensor absolute(const Tensor& tensor) override; - Tensor sigmoid(const Tensor& tensor) override; - Tensor erf(const Tensor& tensor) override; - Tensor flip(const Tensor& tensor, const unsigned dim) override; - Tensor clip(const Tensor& tensor, const Tensor& low, const Tensor& high) + Tensor exp(Tensor const& tensor) override; + Tensor log(Tensor const& tensor) override; + Tensor negative(Tensor const& tensor) override; + Tensor logicalNot(Tensor const& tensor) override; + Tensor log1p(Tensor const& tensor) override; + Tensor sin(Tensor const& tensor) override; + Tensor cos(Tensor const& tensor) override; + Tensor sqrt(Tensor const& tensor) override; + Tensor tanh(Tensor const& tensor) override; + Tensor floor(Tensor const& tensor) override; + Tensor ceil(Tensor const& tensor) override; + Tensor rint(Tensor const& tensor) override; + Tensor absolute(Tensor const& tensor) override; + Tensor sigmoid(Tensor const& tensor) override; + Tensor erf(Tensor const& tensor) override; + Tensor flip(Tensor const& tensor, unsigned dim) override; + Tensor clip(Tensor const& tensor, Tensor const& low, Tensor const& high) override; - Tensor roll(const Tensor& tensor, const int shift, const unsigned axis) + Tensor roll(Tensor const& tensor, int shift, unsigned axis) override; - Tensor isnan(const Tensor& tensor) override; - Tensor isinf(const Tensor& tensor) override; - Tensor sign(const Tensor& tensor) override; - Tensor tril(const Tensor& tensor) override; - Tensor triu(const Tensor& tensor) override; - Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y) + Tensor isnan(Tensor const& tensor) override; 
+ Tensor isinf(Tensor const& tensor) override; + Tensor sign(Tensor const& tensor) override; + Tensor tril(Tensor const& tensor) override; + Tensor triu(Tensor const& tensor) override; + Tensor where(Tensor const& condition, Tensor const& x, Tensor const& y) override; void topk( Tensor& values, Tensor& indices, - const Tensor& input, - const unsigned k, - const Dim axis, - const SortMode sortMode + Tensor const& input, + unsigned k, + Dim axis, + SortMode sortMode ) override; - Tensor sort(const Tensor& input, const Dim axis, const SortMode sortMode) + Tensor sort(Tensor const& input, Dim axis, SortMode sortMode) override; void sort( Tensor& values, Tensor& indices, - const Tensor& input, - const Dim axis, - const SortMode sortMode + Tensor const& input, + Dim axis, + SortMode sortMode ) override; - Tensor argsort(const Tensor& input, const Dim axis, const SortMode sortMode) + Tensor argsort(Tensor const& input, Dim axis, SortMode sortMode) override; /************************** Binary Operators ***************************/ @@ -212,77 +212,77 @@ class ArrayFireBackend : public TensorBackend { #undef FL_AF_BINARY_OP_TYPE_DECL #undef FL_AF_BINARY_OP_LITERALS_DECL - Tensor minimum(const Tensor& lhs, const Tensor& rhs) override; - Tensor maximum(const Tensor& lhs, const Tensor& rhs) override; - Tensor power(const Tensor& lhs, const Tensor& rhs) override; + Tensor minimum(Tensor const& lhs, Tensor const& rhs) override; + Tensor maximum(Tensor const& lhs, Tensor const& rhs) override; + Tensor power(Tensor const& lhs, Tensor const& rhs) override; /******************************* BLAS ********************************/ Tensor matmul( - const Tensor& lhs, - const Tensor& rhs, + Tensor const& lhs, + Tensor const& rhs, MatrixProperty lhsProp, MatrixProperty rhsProp ) override; /************************** Reductions ***************************/ - Tensor amin(const Tensor& input, const std::vector& axes, const bool keepDims) + Tensor amin(Tensor const& input, std::vector const& 
axes, bool keepDims) override; - Tensor amax(const Tensor& input, const std::vector& axes, const bool keepDims) + Tensor amax(Tensor const& input, std::vector const& axes, bool keepDims) override; void min( Tensor& values, Tensor& indices, - const Tensor& input, - const unsigned axis, - const bool keepDims + Tensor const& input, + unsigned axis, + bool keepDims ) override; void max( Tensor& values, Tensor& indices, - const Tensor& input, - const unsigned axis, - const bool keepDims + Tensor const& input, + unsigned axis, + bool keepDims ) override; - Tensor sum(const Tensor& input, const std::vector& axes, const bool keepDims) + Tensor sum(Tensor const& input, std::vector const& axes, bool keepDims) override; - Tensor cumsum(const Tensor& input, const unsigned axis) override; - Tensor argmax(const Tensor& input, const unsigned axis, const bool keepDims) + Tensor cumsum(Tensor const& input, unsigned axis) override; + Tensor argmax(Tensor const& input, unsigned axis, bool keepDims) override; - Tensor argmin(const Tensor& input, const unsigned axis, const bool keepDims) + Tensor argmin(Tensor const& input, unsigned axis, bool keepDims) override; - Tensor mean(const Tensor& input, const std::vector& axes, const bool keepDims) + Tensor mean(Tensor const& input, std::vector const& axes, bool keepDims) override; Tensor median( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool keepDims ) override; Tensor var( - const Tensor& input, - const std::vector& axes, - const bool bias, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool bias, + bool keepDims ) override; - Tensor std(const Tensor& input, const std::vector& axes, const bool keepDims) + Tensor std(Tensor const& input, std::vector const& axes, bool keepDims) override; Tensor norm( - const Tensor& input, - const std::vector& axes, + Tensor const& input, + std::vector const& axes, double p, - const bool keepDims + 
bool keepDims ) override; Tensor countNonzero( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool keepDims ) override; - Tensor any(const Tensor& input, const std::vector& axes, const bool keepDims) + Tensor any(Tensor const& input, std::vector const& axes, bool keepDims) override; - Tensor all(const Tensor& input, const std::vector& axes, const bool keepDims) + Tensor all(Tensor const& input, std::vector const& axes, bool keepDims) override; /************************** Utils ***************************/ - void print(const Tensor& tensor) override; + void print(Tensor const& tensor) override; }; } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/ArrayFireBinaryOps.cpp b/flashlight/fl/tensor/backend/af/ArrayFireBinaryOps.cpp index 5c3fb57..9408c5e 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireBinaryOps.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireBinaryOps.cpp @@ -15,7 +15,7 @@ namespace fl { namespace { - bool canBroadcast(const Shape& lhs, const Shape& rhs) { + bool canBroadcast(Shape const& lhs, Shape const& rhs) { unsigned nDim = std::max(lhs.ndim(), rhs.ndim()); for(unsigned i = 0; i < nDim; ++i) { @@ -31,11 +31,11 @@ namespace { // A binary operation on two ArrayFire arrays using binaryOpFunc_t = - af::array (*)(const af::array& lhs, const af::array& rhs); + af::array (*)(af::array const& lhs, af::array const& rhs); Tensor doBinaryOpOrBroadcast( - const Tensor& lhs, - const Tensor& rhs, + Tensor const& lhs, + Tensor const& rhs, binaryOpFunc_t func ) { // Dims are the same or scalar <> 1-el tensor - no broadcasting @@ -122,15 +122,15 @@ FL_AF_BINARY_OP_DEF(>>, rShift); #undef FL_AF_BINARY_OP_TYPE_DEF #undef FL_AF_BINARY_OP_LITERALS_DEF -Tensor ArrayFireBackend::minimum(const Tensor& lhs, const Tensor& rhs) { +Tensor ArrayFireBackend::minimum(Tensor const& lhs, Tensor const& rhs) { return doBinaryOpOrBroadcast(lhs, rhs, af::min); } -Tensor 
ArrayFireBackend::maximum(const Tensor& lhs, const Tensor& rhs) { +Tensor ArrayFireBackend::maximum(Tensor const& lhs, Tensor const& rhs) { return doBinaryOpOrBroadcast(lhs, rhs, af::max); } -Tensor ArrayFireBackend::power(const Tensor& lhs, const Tensor& rhs) { +Tensor ArrayFireBackend::power(Tensor const& lhs, Tensor const& rhs) { return doBinaryOpOrBroadcast(lhs, rhs, af::pow); } } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/ArrayFireCPUStream.cpp b/flashlight/fl/tensor/backend/af/ArrayFireCPUStream.cpp index c1098a1..39715ca 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireCPUStream.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireCPUStream.cpp @@ -15,8 +15,8 @@ std::shared_ptr ArrayFireCPUStream::create() { // TODO `std::make_shared` requires a public constructor, which could be // abused and lead to unregistered stream. However, it has one internal // allocation and is more cache-friendly than `std::shared_ptr`. - const auto rawStreamPtr = new ArrayFireCPUStream(); - const auto stream = std::shared_ptr(rawStreamPtr); + auto const rawStreamPtr = new ArrayFireCPUStream(); + auto const stream = std::shared_ptr(rawStreamPtr); rawStreamPtr->device_.addStream(stream); return stream; } diff --git a/flashlight/fl/tensor/backend/af/ArrayFireReductions.cpp b/flashlight/fl/tensor/backend/af/ArrayFireReductions.cpp index bd00f5f..d280bed 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireReductions.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireReductions.cpp @@ -16,16 +16,18 @@ #include "flashlight/fl/tensor/backend/af/Utils.h" namespace fl { +using af_dim_t = ::dim_t; + namespace { - using reduceFunc_t = af::array (*)(const af::array&, const int); + using reduceFunc_t = af::array (*)(af::array const&, int const); template af::array afReduceAxes( - const af::array& input, - const std::vector& axes, + af::array const& input, + std::vector const& axes, T func, - const bool keepDims = false + bool const keepDims = false ) { auto arr = input; 
for(int dim : axes) @@ -33,7 +35,7 @@ namespace { return fl::detail::condenseIndices(arr, keepDims); } - unsigned getReducedNumDims(unsigned inSize, unsigned axisSize, const bool keepDims) { + unsigned getReducedNumDims(unsigned inSize, unsigned axisSize, bool const keepDims) { if(keepDims) return inSize; else { @@ -44,7 +46,7 @@ namespace { } } - bool isAllAxisReduction(const Tensor& input, const std::vector& axes) { + bool isAllAxisReduction(Tensor const& input, std::vector const& axes) { if(input.ndim() == 0 || axes.empty()) return true; if(input.ndim() != axes.size()) @@ -60,9 +62,9 @@ namespace { } // namespace Tensor ArrayFireBackend::amin( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const keepDims ) { if(isAllAxisReduction(input, axes)) // Reduce along all axes returning a singleton tensor @@ -72,7 +74,8 @@ Tensor ArrayFireBackend::amin( detail::condenseIndices( af::min(af::min(af::min(af::min(toArray(input))))) ), - /* numDims = */ 0 + /* numDims = */ + 0 ); else return toTensor( @@ -82,9 +85,9 @@ Tensor ArrayFireBackend::amin( } Tensor ArrayFireBackend::amax( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const keepDims ) { if(isAllAxisReduction(input, axes)) // Reduce along all axes returning a singleton tensor @@ -94,7 +97,8 @@ Tensor ArrayFireBackend::amax( detail::condenseIndices( af::max(af::max(af::max(af::max(toArray(input))))) ), - /* numDims = */ 0 + /* numDims = */ + 0 ); else return toTensor( @@ -106,9 +110,9 @@ Tensor ArrayFireBackend::amax( void ArrayFireBackend::min( Tensor& values, Tensor& indices, - const Tensor& input, - const unsigned axis, - const bool keepDims + Tensor const& input, + unsigned const axis, + bool const keepDims ) { af::min(toArray(values), toArray(indices), toArray(input), axis); values = toTensor( @@ -124,9 +128,9 @@ void ArrayFireBackend::min( void 
ArrayFireBackend::max( Tensor& values, Tensor& indices, - const Tensor& input, - const unsigned axis, - const bool keepDims + Tensor const& input, + unsigned const axis, + bool const keepDims ) { af::max(toArray(values), toArray(indices), toArray(input), axis); values = toTensor( @@ -140,9 +144,9 @@ void ArrayFireBackend::max( } Tensor ArrayFireBackend::sum( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const keepDims ) { if(isAllAxisReduction(input, axes)) // Reduce along all axes returning a singleton tensor @@ -152,7 +156,8 @@ Tensor ArrayFireBackend::sum( detail::condenseIndices( af::sum(af::sum(af::sum(af::sum(toArray(input))))) ), - /* numDims = */ 0 + /* numDims = */ + 0 ); else return toTensor( @@ -161,17 +166,18 @@ Tensor ArrayFireBackend::sum( ); } -Tensor ArrayFireBackend::cumsum(const Tensor& input, const unsigned axis) { +Tensor ArrayFireBackend::cumsum(Tensor const& input, unsigned const axis) { return toTensor( - af::accum(toArray(input), axis), /* numDims = */ + af::accum(toArray(input), axis), + /* numDims = */ input.ndim() ); } Tensor ArrayFireBackend::argmax( - const Tensor& input, - const unsigned axis, - const bool keepDims + Tensor const& input, + unsigned const axis, + bool const keepDims ) { af::array tmpVal, indices; af::max(tmpVal, indices, toArray(input), axis); @@ -182,9 +188,9 @@ Tensor ArrayFireBackend::argmax( } Tensor ArrayFireBackend::argmin( - const Tensor& input, - const unsigned axis, - const bool keepDims + Tensor const& input, + unsigned const axis, + bool const keepDims ) { af::array tmpVal, indices; af::min(tmpVal, indices, toArray(input), axis); @@ -195,9 +201,9 @@ Tensor ArrayFireBackend::argmin( } Tensor ArrayFireBackend::mean( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const keepDims ) { if(isAllAxisReduction(input, axes)) // Reduce along all axes 
returning a singleton tensor @@ -207,14 +213,15 @@ Tensor ArrayFireBackend::mean( detail::condenseIndices( af::mean(af::mean(af::mean(af::mean(toArray(input))))) ), - /* numDims = */ 0 + /* numDims = */ + 0 ); else return toTensor( - afReduceAxes( + afReduceAxes( toArray(input), axes, - af::mean, + [](af::array const& arr, int dim) { return af::mean(arr, dim); }, keepDims ), getReducedNumDims(input.ndim(), axes.size(), keepDims) @@ -222,9 +229,9 @@ Tensor ArrayFireBackend::mean( } Tensor ArrayFireBackend::median( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const keepDims ) { if(isAllAxisReduction(input, axes)) { // Reduce along all axes returning a singleton tensor @@ -233,14 +240,16 @@ Tensor ArrayFireBackend::median( double median = af::median(toArray(input)); return toTensor( af::constant(median, 1), - /* numDims = */ 0 + /* numDims = */ + 0 ); - } else + } + else return toTensor( - afReduceAxes( + afReduceAxes( toArray(input), axes, - af::median, + [](af::array const& arr, int dim) { return af::median(arr, static_cast(dim)); }, keepDims ), getReducedNumDims(input.ndim(), axes.size(), keepDims) @@ -248,10 +257,10 @@ Tensor ArrayFireBackend::median( } Tensor ArrayFireBackend::var( - const Tensor& input, - const std::vector& axes, - const bool bias, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const bias, + bool const keepDims ) { af_var_bias biasMode = bias ? 
AF_VARIANCE_SAMPLE : AF_VARIANCE_POPULATION; // Use ArrayFire default for one dimension which may be optimized @@ -262,7 +271,8 @@ Tensor ArrayFireBackend::var( if(isAllAxisReduction(input, axes)) { double out = af::var(toArray(input), biasMode); return toTensor(af::constant(out, 1), /* numDims = */ 0); - } else if(axes.size() == 1) + } + else if(axes.size() == 1) return toTensor( detail::condenseIndices(af::var(arr, biasMode, axes[0]), keepDims), getReducedNumDims(input.ndim(), axes.size(), keepDims) @@ -290,17 +300,18 @@ Tensor ArrayFireBackend::var( } Tensor ArrayFireBackend::std( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const keepDims ) { - const bool bias = false; // TODO: make this configurable + bool const bias = false; // TODO: make this configurable af_var_bias biasMode = bias ? AF_VARIANCE_SAMPLE : AF_VARIANCE_POPULATION; if(isAllAxisReduction(input, axes)) { // TODO: update to af::stdev once specialization is available double out = af::stdev(toArray(input), biasMode); return toTensor(af::constant(out, 1), /* numDims = */ 0); - } else if(axes.size() == 1) + } + else if(axes.size() == 1) // Use arrayfire default for one dimension which may be optimized // TODO: update this? stddev is deprecated. 
return toTensor( @@ -314,10 +325,10 @@ Tensor ArrayFireBackend::std( } Tensor ArrayFireBackend::norm( - const Tensor& input, - const std::vector& axes, + Tensor const& input, + std::vector const& axes, double p /* = 2 */, - const bool keepDims + bool const keepDims ) { if(isAllAxisReduction(input, axes)) { // TODO: update to af::norm if device-side specialization is @@ -328,10 +339,12 @@ Tensor ArrayFireBackend::norm( result = af::sum(af::sum(af::sum(result))); result = af::pow(result, 1 / p); return toTensor( - detail::condenseIndices(result), /* numDims = */ + detail::condenseIndices(result), + /* numDims = */ 0 ); - } else { + } + else { auto result = af::pow(af::abs(toArray(input)), p); result = afReduceAxes(result, axes, af::sum, keepDims); result = af::pow(result, 1 / p); @@ -343,9 +356,9 @@ Tensor ArrayFireBackend::norm( } Tensor ArrayFireBackend::countNonzero( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const keepDims ) { auto& arr = toArray(input); unsigned numDims; @@ -356,10 +369,12 @@ Tensor ArrayFireBackend::countNonzero( keepDims ); numDims = 0; - } else if(axes.size() == 1) { + } + else if(axes.size() == 1) { out = af::count(arr, axes.front()); numDims = getReducedNumDims(input.ndim(), axes.size(), keepDims); - } else { + } + else { out = afReduceAxes( af::count(arr, axes.front()), std::vector(axes.begin() + 1, axes.end()), @@ -375,9 +390,9 @@ Tensor ArrayFireBackend::countNonzero( } Tensor ArrayFireBackend::any( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const keepDims ) { if(isAllAxisReduction(input, axes)) // Reduce along all axes returning a singleton tensor @@ -387,7 +402,8 @@ Tensor ArrayFireBackend::any( detail::condenseIndices( af::anyTrue(af::anyTrue(af::anyTrue(af::anyTrue(toArray(input))))) ), - /* numDims = */ 0 + /* numDims = */ + 0 ); else return toTensor( @@ -397,9 
+413,9 @@ Tensor ArrayFireBackend::any( } Tensor ArrayFireBackend::all( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const keepDims ) { if(isAllAxisReduction(input, axes)) // Reduce along all axes returning a singleton tensor @@ -409,7 +425,8 @@ Tensor ArrayFireBackend::all( detail::condenseIndices( af::allTrue(af::allTrue(af::allTrue(af::allTrue(toArray(input))))) ), - /* numDims = */ 0 + /* numDims = */ + 0 ); else return toTensor( diff --git a/flashlight/fl/tensor/backend/af/ArrayFireShapeAndIndex.cpp b/flashlight/fl/tensor/backend/af/ArrayFireShapeAndIndex.cpp index 397cd75..98bba4b 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireShapeAndIndex.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireShapeAndIndex.cpp @@ -13,10 +13,11 @@ #include "flashlight/fl/tensor/backend/af/ArrayFireTensor.h" #include "flashlight/fl/tensor/backend/af/Utils.h" +#include namespace fl { -Tensor ArrayFireBackend::reshape(const Tensor& tensor, const Shape& shape) { +Tensor ArrayFireBackend::reshape(Tensor const& tensor, Shape const& shape) { return toTensor( af::moddims(toArray(tensor), detail::flToAfDims(shape)), shape.ndim() @@ -24,8 +25,8 @@ Tensor ArrayFireBackend::reshape(const Tensor& tensor, const Shape& shape) { } Tensor ArrayFireBackend::transpose( - const Tensor& tensor, - const Shape& axes /* = {} */ + Tensor const& tensor, + Shape const& axes /* = {} */ ) { if(tensor.ndim() == 1) return tensor; @@ -48,7 +49,8 @@ Tensor ArrayFireBackend::transpose( af::reorder(toArray(tensor), dims[0], dims[1], dims[2], dims[3]), tensor.ndim() ); - } else { + } + else { if(axes.ndim() > AF_MAX_DIMS) throw std::invalid_argument( "ArrayFire tensor transpose was given " @@ -79,7 +81,7 @@ Tensor ArrayFireBackend::transpose( } } -Tensor ArrayFireBackend::tile(const Tensor& tensor, const Shape& shape) { +Tensor ArrayFireBackend::tile(Tensor const& tensor, Shape const& shape) { return toTensor( 
af::tile(toArray(tensor), detail::flToAfDims(shape)), // TODO: check @@ -87,62 +89,81 @@ Tensor ArrayFireBackend::tile(const Tensor& tensor, const Shape& shape) { ); } + +namespace { + + af::array join_chunk(std::span chunk, unsigned const axis) { + switch(chunk.size()) { + case 0: throw std::invalid_argument{"Cannot concatenate empty chunk"}; + case 1: return chunk[0]; + case 2: return af::join(axis, chunk[0], chunk[1]); + case 3: return af::join(axis, chunk[0], chunk[1], chunk[2]); + case 4: return af::join(axis, chunk[0], chunk[1], chunk[2], chunk[3]); + default: { + std::vector handles{}; + handles.reserve(chunk.size()); + for(auto const& arr : chunk) + handles.push_back(arr.get()); + + af_array outHandle = nullptr; + AF_CHECK(af_join_many(&outHandle, axis, chunk.size(), handles.data())); + return af::array{outHandle}; + } + } + } + +} // namespace + Tensor ArrayFireBackend::concatenate( - const std::vector& tensors, - const unsigned axis + std::vector const& tensors, + unsigned axis ) { - af::array out; - switch(tensors.size()) { - case 0: - return toTensor(ArrayFireTensor()); // empty tensor - case 1: - return tensors.front(); - case 2: - out = af::join(axis, toArray(tensors[0]), toArray(tensors[1])); - break; - case 3: - out = af::join( - axis, - toArray(tensors[0]), - toArray(tensors[1]), - toArray(tensors[2]) - ); - break; - case 4: - out = af::join( - axis, - toArray(tensors[0]), - toArray(tensors[1]), - toArray(tensors[2]), - toArray(tensors[3]) - ); - break; - default: - // TODO: iteratively concat to remove this limitation - throw std::invalid_argument( - "ArrayFire concatenate doesn't support > 4 tensors" - ); + if(tensors.empty()) + return toTensor(); // empty tensor + + //TODO use std::from_range and views::transform once c++23 + std::vector arrays{}; + arrays.reserve(tensors.size()); + + + for(auto const& t : tensors){ + arrays.push_back(toArray(t)); + } + constexpr size_t maxChunkSize = 10; //https://arrayfire.org/docs/group__manip__func__join.htm 
+ + //greedy chunk and join + while(arrays.size() > 1) { + size_t const chunks = (arrays.size() + maxChunkSize - 1) / maxChunkSize; + + for(size_t i = 0; i < chunks; i++) { + auto const begin = i * maxChunkSize; + auto const size = std::min(maxChunkSize, arrays.size() - begin); + + arrays[i] = join_chunk({&arrays[begin], size}, axis); + } + arrays.resize(chunks); } unsigned numDims = tensors[0].ndim(); - if(axis > std::max(numDims - 1, 0u)) + if(axis >= numDims) numDims = axis + 1; // All tensors have the same numdims else AF would throw - return toTensor(std::move(out), numDims); + return toTensor(std::move(arrays[0]), numDims); } -Tensor ArrayFireBackend::nonzero(const Tensor& tensor) { +Tensor ArrayFireBackend::nonzero(Tensor const& tensor) { return toTensor( - af::where(toArray(tensor)), /* numDims = */ + af::where(toArray(tensor)), + /* numDims = */ 1 ); } Tensor ArrayFireBackend::pad( - const Tensor& input, - const std::vector>& padWidths, - const PadType type + Tensor const& input, + std::vector> const& padWidths, + PadType const type ) { if(padWidths.size() > AF_MAX_DIMS) throw std::invalid_argument( @@ -165,8 +186,9 @@ Tensor ArrayFireBackend::pad( endPadding, detail::flToAfPadType(type) ), - /* numDims = */ // TODO: check - std::max(input.ndim(), static_cast(padWidths.size())) + /* numDims = */ + // TODO: check + std::max(input.ndim(), padWidths.size()) ); } } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp b/flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp index ed581ae..311121d 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp @@ -25,8 +25,10 @@ #include namespace fl { +using af_dim_t = ::dim_t; -const af::array& toArray(const Tensor& tensor) { + +af::array const& toArray(Tensor const& tensor) { if(tensor.backendType() != TensorBackendType::ArrayFire) throw std::invalid_argument("toArray: tensor is not ArrayFire-backed"); return 
tensor.getAdapter().getHandle(); @@ -40,7 +42,7 @@ af::array& toArray(Tensor& tensor) { ArrayFireTensor::ArrayFireTensor( af::array&& array, - const unsigned numDims + size_t const numDims ) : arrayHandle_(std::make_shared(std::move(array))), numDims_(numDims) {} @@ -48,8 +50,8 @@ ArrayFireTensor::ArrayFireTensor( std::shared_ptr arr, std::vector&& afIndices, std::vector&& indexTypes, - const unsigned numDims, - const bool isFlat + size_t const numDims, + bool const isFlat ) : arrayHandle_(arr), indices_(std::move(afIndices)), indexTypes_(std::move(indexTypes)), @@ -58,79 +60,75 @@ ArrayFireTensor::ArrayFireTensor( ArrayFireTensor::ArrayFireTensor( std::shared_ptr arr, - unsigned numDims + size_t numDims ) : arrayHandle_(arr), numDims_(numDims) {} ArrayFireTensor::ArrayFireTensor() : arrayHandle_(std::make_shared()), - handle_(ArrayComponent()) {} + handle_(ArrayComponent()) {} ArrayFireTensor::ArrayFireTensor( - const Shape& shape, + Shape const& shape, fl::dtype type, - const void* ptr, + void const* ptr, Location memoryLocation -) : arrayHandle_(std::make_shared( - detail::fromFlData(shape, ptr, type, memoryLocation) - )), +) : arrayHandle_( + std::make_shared( + detail::fromFlData(shape, ptr, type, memoryLocation) + ) + ), handle_(ArrayComponent()), numDims_(shape.ndim()) {} ArrayFireTensor::ArrayFireTensor( - const Dim nRows, - const Dim nCols, - const Tensor& values, - const Tensor& rowIdx, - const Tensor& colIdx, + Dim const nRows, + Dim const nCols, + Tensor const& values, + Tensor const& rowIdx, + Tensor const& colIdx, StorageType storageType -) : arrayHandle_(std::make_shared( - af::sparse( - nRows, - nCols, - toArray(values), - toArray(rowIdx), - toArray(colIdx), - detail::flToAfStorageType(storageType)) - )), +) : arrayHandle_( + std::make_shared( + af::sparse( + nRows, + nCols, + toArray(values), + toArray(rowIdx), + toArray(colIdx), + detail::flToAfStorageType(storageType) + ) + ) + ), handle_(ArrayComponent()), // ArrayFire only supports 2D sparsity 
numDims_(2) {} -unsigned ArrayFireTensor::numDims() const { - return numDims_; -} +size_t ArrayFireTensor::numDims() const { return numDims_; } ArrayFireTensor::IndexedArrayComponent::IndexedArrayComponent( - const bool _isFlat /* = false */ + bool const _isFlat /* = false */ ) : isFlat(_isFlat) {} af::array::array_proxy ArrayFireTensor::IndexedArrayComponent::get( - const ArrayFireTensor& inst + ArrayFireTensor const& inst ) { auto& i = inst.indices_.value(); auto& a = *(inst.arrayHandle_); switch(i.size()) { - case 1: - return a(i[0]); - case 2: - return a(i[0], i[1]); - case 3: - return a(i[0], i[1], i[2]); - case 4: - return a(i[0], i[1], i[2], i[3]); - default: - throw std::invalid_argument( + case 1: return a(i[0]); + case 2: return a(i[0], i[1]); + case 3: return a(i[0], i[1], i[2]); + case 4: return a(i[0], i[1], i[2], i[3]); + default: throw std::invalid_argument( "ArrayFireTensor::IndexedArrayComponent::get - " "given invalid number of index components." ); } } -af::array& ArrayFireTensor::ArrayComponent::get(const ArrayFireTensor& inst) { return *(inst.arrayHandle_); } +af::array& ArrayFireTensor::ArrayComponent::get(ArrayFireTensor const& inst) { return *(inst.arrayHandle_); } -const af::array& ArrayFireTensor::getHandle() const { - return const_cast(this)->getHandle(); -} +af::array const& ArrayFireTensor::getHandle() const { return const_cast(this)->getHandle(); } af::array& ArrayFireTensor::getHandle() { // If the handle currently requires indexing, perform the indexing, change the @@ -144,13 +142,15 @@ af::array& ArrayFireTensor::getHandle() { arrayHandle_ = std::make_shared( detail::condenseIndices( idxComp.get(*this), - /* keepDims = */ false, + /* keepDims = */ + false, indexTypes_, - /* isFlat = */ idxComp.isFlat + /* isFlat = */ + idxComp.isFlat ) ); // Clear state - handle_ = ArrayComponent(); // set to passthrough + handle_ = ArrayComponent{}; // set to passthrough indices_ = {}; // remove indices indexTypes_ = {}; // remove IndexTypes } @@ 
-174,21 +174,19 @@ Tensor ArrayFireTensor::shallowCopy() { getHandle(); // if this tensor was a view, run indexing and promote return Tensor( std::unique_ptr( - new ArrayFireTensor(arrayHandle_, numDims()) + new ArrayFireTensor{arrayHandle_, numDims()} ) ); } -TensorBackendType ArrayFireTensor::backendType() const { - return TensorBackendType::ArrayFire; -} +TensorBackendType ArrayFireTensor::backendType() const { return TensorBackendType::ArrayFire; } TensorBackend& ArrayFireTensor::backend() const { // The ArrayFire backend has a single ArrayFireBackend instance per process. return ::fl::ArrayFireBackend::getInstance(); } -const Shape& ArrayFireTensor::shape() { +Shape const& ArrayFireTensor::shape() { // Update the Shape in-place. Doesn't change any underlying data; only the // mirrored Shape metadata. detail::afToFlDims(getHandle().dims(), numDims(), shape_); @@ -204,12 +202,9 @@ af::dtype ArrayFireTensor::afHandleType() { return arrayHandle_->type(); } Location ArrayFireTensor::location() { switch(af::getBackendId(getHandle())) { case AF_BACKEND_CUDA: - case AF_BACKEND_OPENCL: - return Location::Device; - case AF_BACKEND_CPU: - return Location::Host; - default: - throw std::logic_error( + case AF_BACKEND_OPENCL: return Location::Device; + case AF_BACKEND_CPU: return Location::Host; + default: throw std::logic_error( "ArrayFireTensor::location got an unmatched location" ); } @@ -219,7 +214,11 @@ void ArrayFireTensor::scalar(void* out) { AF_CHECK(af_get_scalar(out, getHandle( void ArrayFireTensor::device(void** out) { AF_CHECK(af_get_device_ptr(out, getHandle().get())); } -void ArrayFireTensor::host(void* out) { AF_CHECK(af_get_data_ptr(out, getHandle().get())); } +void ArrayFireTensor::host(void* out) { + + + AF_CHECK(af_get_data_ptr(out, getHandle().get())); +} void ArrayFireTensor::unlock() { AF_CHECK(af_unlock_array(getHandle().get())); } @@ -238,18 +237,18 @@ bool ArrayFireTensor::isContiguous() { return af::isLinear(getHandle()); } Shape 
ArrayFireTensor::strides() { return detail::afToFlDims(af::getStrides(getHandle()), numDims()); } -const Stream& ArrayFireTensor::stream() const { +Stream const& ArrayFireTensor::stream() const { // TODO indexing is unlikely to change the stream associated with a tensor. // But if it can, we need to call `getHandle()` here. return ArrayFireBackend::getInstance().getStreamOfArray(*arrayHandle_); } -Tensor ArrayFireTensor::astype(const dtype type) { +Tensor ArrayFireTensor::asType(dtype const type) { auto a = getHandle().as(detail::flToAfType(type)); return toTensor(std::move(a), numDims()); } -Tensor ArrayFireTensor::index(const std::vector& indices) { +Tensor ArrayFireTensor::index(std::vector const& indices) { if(indices.size() > AF_MAX_DIMS) throw std::invalid_argument( "ArrayFire-backed tensor was indexed with > 4 elements:" @@ -299,7 +298,7 @@ Tensor ArrayFireTensor::index(const std::vector& indices) { // tensor(s) newNumDims = 1; else - for(const auto& type : indexTypes) + for(auto const& type : indexTypes) if(type == detail::IndexType::Literal) newNumDims--; newNumDims = std::max(newNumDims, 1u); // can never index to a 0 dim tensor @@ -311,17 +310,16 @@ Tensor ArrayFireTensor::index(const std::vector& indices) { std::move(afIndices), std::move(indexTypes), newNumDims, - /* isFlat = */ false + /* isFlat = */ + false ) ) ); } -Tensor ArrayFireTensor::flatten() const { - return toTensor(af::flat(getHandle()), /* numDims = */ 1); -} +Tensor ArrayFireTensor::flatten() const { return toTensor(af::flat(getHandle()), /* numDims = */ 1); } -Tensor ArrayFireTensor::flat(const Index& idx) const { +Tensor ArrayFireTensor::flat(Index const& idx) const { getHandle(); // if this tensor was a view, run indexing and promote // Return a lazy indexing operation. 
Indexing with a single index on an // ArrayFire tensor (with a type that is not an af::array) ends up doing @@ -332,8 +330,10 @@ Tensor ArrayFireTensor::flat(const Index& idx) const { arrayHandle_, {detail::flToAfIndex(idx)}, {idx.type()}, - /* numDims = */ 1, - /* isFlat = */ true + /* numDims = */ + 1, + /* isFlat = */ + true ) ) ); @@ -345,7 +345,7 @@ Tensor ArrayFireTensor::asContiguousTensor() { return toTensor(std::move(other), numDims()); } - const af::array& array = getHandle(); + af::array const& array = getHandle(); auto linearArray = af::array(array.dims(), array.type()); af::copy(linearArray, array, af::span); return toTensor(std::move(linearArray), numDims()); @@ -358,9 +358,9 @@ void* ArrayFireTensor::getContext() { } std::string ArrayFireTensor::toString() { - const char* afStr = af::toString("ArrayFireTensor", getHandle()); + char const* afStr = af::toString("ArrayFireTensor", getHandle()); // std::string copies `afStr` content into its own buffer - const std::string str(afStr); + std::string const str(afStr); af::freeHost(afStr); return str; } @@ -394,14 +394,14 @@ std::ostream& ArrayFireTensor::operator<<(std::ostream& ostr) { ASSIGN_OP_TYPE(FUN, AF_OP, long long); \ ASSIGN_OP_TYPE(FUN, AF_OP, unsigned long long); -af::array ArrayFireTensor::adjustInPlaceOperandDims(const Tensor& operand) { +af::array ArrayFireTensor::adjustInPlaceOperandDims(Tensor const& operand) { // optimstically try to moddims the operand's singleton dims - const af::dim4& preIdxDims = arrayHandle_->dims(); - const af::array& operandArr = toArray(operand); + af::dim4 const& preIdxDims = arrayHandle_->dims(); + af::array const& operandArr = toArray(operand); // dims to which to try to modify the input if doing indexing af::dim4 newDims; - const af::dim4 operandDims = operandArr.dims(); + af::dim4 const operandDims = operandArr.dims(); using detail::IndexType; if(indices_ && indices_.value().size() == 1) { @@ -412,10 +412,11 @@ af::array 
ArrayFireTensor::adjustInPlaceOperandDims(const Tensor& operand) { "ArrayFireTensor::adjustInPlaceOperandDims " "index size was 1 but tensor has greater than 1 dimension." ); - } else if(indices_ && !indices_.value().empty()) { + } + else if(indices_ && !indices_.value().empty()) { // All other indexing operations - const auto& indices = indices_.value(); - const auto& indexTypes = indexTypes_.value(); + auto const& indices = indices_.value(); + auto const& indexTypes = indexTypes_.value(); if(indices.size() != indexTypes.size()) throw std::invalid_argument( "ArrayFireTensor adjustInPlaceOperandDims - passed indices" @@ -446,16 +447,18 @@ af::array ArrayFireTensor::adjustInPlaceOperandDims(const Tensor& operand) { ) { compressIdx++; postIdxDims[i] = 1; - } else { + } + else { // Use the size of the dim post-indexing. Span uses the preIdx dim // and literals are pushed to 1. if(i < indexTypes.size()) { if(indexTypes[i] == IndexType::Tensor) { - dim_t size; + af_dim_t size; AF_CHECK(af_get_elements(&size, indices[i].get().idx.arr)); postIdxDims[i] = size; - } else if(indexTypes[i] == IndexType::Range) - postIdxDims[i] = af::seq(indices[i].get().idx.seq).size; + } + else if(indexTypes[i] == IndexType::Range) + postIdxDims[i] = static_cast(af::seq(indices[i].get().idx.seq).size); else if(indexTypes[i] == IndexType::Literal) postIdxDims[i] = 1; } @@ -473,8 +476,9 @@ af::array ArrayFireTensor::adjustInPlaceOperandDims(const Tensor& operand) { "ArrayFireTensor adjustInPlaceOperandDims: can't apply operation " "in-place to indexed ArrayFireTensor - dimensions don't match." ); - } else - // No indexing so no change in dimensions required + } + else + // No indexing so no change in dimensions required newDims = operandDims; // af::moddims involves an eval. This will be fixed in AF 3.8.1/3.8.2 @@ -497,18 +501,18 @@ af::array ArrayFireTensor::adjustInPlaceOperandDims(const Tensor& operand) { ASSIGN_OP_LITERALS(FUN, AF_OP) // (function name, AF op). Use build-in AF operators. 
-ASSIGN_OP(inPlaceSubtract, -= ); -ASSIGN_OP(inPlaceMultiply, *= ); -ASSIGN_OP(inPlaceDivide, /= ); +ASSIGN_OP(inPlaceSubtract, -=); +ASSIGN_OP(inPlaceMultiply, *=); +ASSIGN_OP(inPlaceDivide, /=); // Instantiate definitions for type literals - those remain unchanged: -ASSIGN_OP_LITERALS(assign, = ); -void ArrayFireTensor::assign(const Tensor& tensor) { +ASSIGN_OP_LITERALS(assign, =); +void ArrayFireTensor::assign(Tensor const& tensor) { std::visit( [&tensor, this](auto&& arr) { if(indices_) - // If this is an indexing op, do as other in-place ops with lvalue - // temporaries as a result of indexing do + // If this is an indexing op, do as other in-place ops with lvalue + // temporaries as a result of indexing do arr.get(*this) = this->adjustInPlaceOperandDims(tensor); else { // Not an indexing op - just assign the tensor, but make sure to @@ -528,9 +532,9 @@ void ArrayFireTensor::assign(const Tensor& tensor) { * it properly-handles the case of repeated indices. */ // Instantiate definitions for type literals - those remain unchanged: -ASSIGN_OP_LITERALS(inPlaceAdd, += ); +ASSIGN_OP_LITERALS(inPlaceAdd, +=); // Special tensor op: -void ArrayFireTensor::inPlaceAdd(const Tensor& tensor) { +void ArrayFireTensor::inPlaceAdd(Tensor const& tensor) { // First, check if this a tensor that's going to be lazily indexed. Don't // implicitly cast to an array, else that will trigger indexing. 
// Carefully get the handle types without calling type(), which will lazily @@ -550,37 +554,40 @@ void ArrayFireTensor::inPlaceAdd(const Tensor& tensor) { ) { // Call the regular af::array::operator+= std::visit( - [&tensor, this](auto&& arr) { - arr.get(*this) += this->adjustInPlaceOperandDims(tensor); - }, + [&tensor, this](auto&& arr) { arr.get(*this) += this->adjustInPlaceOperandDims(tensor); }, handle_ ); return; - } else { + } + else { af::dim4 inDims = arrayHandle_->dims(); af::dim4 idxStart; af::dim4 idxEnd; std::vector idxArr(4); auto idxFunc = [&idxStart, &idxEnd, &idxArr, &inDims]( - const af::index& index, int pos) { - if(index.isspan()) { + af::index const& index, + int pos + ) { + if(index.isspan()) { + idxStart[pos] = 0; + idxEnd[pos] = inDims[pos]; + } + else { + auto const& idxSeq = index.get(); + if(idxSeq.isSeq) { + // arrayfire uses inclusive last dimension, we use exclusive + idxStart[pos] = idxSeq.idx.seq.begin; + idxEnd[pos] = idxSeq.idx.seq.end + 1; + } + else { + af_array arr; + af_retain_array(&arr, idxSeq.idx.arr); + idxArr[pos] = af::array(arr); idxStart[pos] = 0; - idxEnd[pos] = inDims[pos]; - } else { - const auto& idxSeq = index.get(); - if(idxSeq.isSeq) { - // arrayfire uses inclusive last dimension, we use exclusive - idxStart[pos] = idxSeq.idx.seq.begin; - idxEnd[pos] = idxSeq.idx.seq.end + 1; - } else { - af_array arr; - af_retain_array(&arr, idxSeq.idx.arr); - idxArr[pos] = af::array(arr); - idxStart[pos] = 0; - idxEnd[pos] = idxArr[pos].dims(0); - } + idxEnd[pos] = idxArr[pos].dims(0); } - }; + } + }; unsigned i = 0; for(; i < indices_.value().size(); ++i) diff --git a/flashlight/fl/tensor/backend/af/ArrayFireTensor.h b/flashlight/fl/tensor/backend/af/ArrayFireTensor.h index faeda13..93c6f56 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireTensor.h +++ b/flashlight/fl/tensor/backend/af/ArrayFireTensor.h @@ -41,18 +41,18 @@ class ArrayFireTensor : public TensorAdapterBase { // To be visited when this tensor is to be indexed. 
Indexes the underlying // af::array, and returns the proxy to be used as a temporary lvalue. struct IndexedArrayComponent { - explicit IndexedArrayComponent(const bool _isFlat = false); - af::array::array_proxy get(const ArrayFireTensor& inst); + explicit IndexedArrayComponent(bool _isFlat = false); + af::array::array_proxy get(ArrayFireTensor const& inst); bool isFlat; }; // To be visited when this tensor is holding an array without needing // indexing. Passthrough - returns the array directly. struct ArrayComponent { - af::array& get(const ArrayFireTensor& inst); + af::array& get(ArrayFireTensor const& inst); }; // An interface to visit when getting an array handle. Indexes lazily // because we can't store an af::array::proxy as an lvalue. See getHandle(). - std::variant handle_{ArrayComponent()}; + std::variant handle_{ArrayComponent{}}; /** * Constructs an ArrayFireTensor that will be lazily indexed. @@ -68,7 +68,7 @@ class ArrayFireTensor : public TensorAdapterBase { * a full af::array on which operations can be performed. * * @param[in] handle a pointer to the ArrayFire array - * @param[in] indices a vector of ArrayFire indices to lazily index. + * @param[in] afIndices a vector of ArrayFire indices to lazily index. * @param[in] indexTypes a vector of index types to lazily index. Needed to * determine singleton dimension condensation * @param[in] isFlat if the indexing op is flat (condense all dims) @@ -77,15 +77,15 @@ class ArrayFireTensor : public TensorAdapterBase { std::shared_ptr handle, std::vector&& afIndices, std::vector&& indexTypes, - const unsigned numDims, - const bool isFlat + size_t numDims, + bool isFlat ); /** * Construct an ArrayFireTensor from an ArrayFire array handle without copying * the handle. Used for creating guaranteed-shallow copies. */ - explicit ArrayFireTensor(std::shared_ptr arr, unsigned numDims); + explicit ArrayFireTensor(std::shared_ptr arr, size_t numDims); /* * A Flashlight Shape that mirrors ArrayFire dims. 
@@ -108,7 +108,7 @@ class ArrayFireTensor : public TensorAdapterBase { * The fl::Tensor default Tensor shape is {0} - the default number of numDims * is thus 1. Scalars have numDims == 0; */ - unsigned numDims_{1}; + size_t numDims_{1}; public: constexpr static TensorBackendType tensorBackendType = TensorBackendType::ArrayFire; @@ -126,7 +126,7 @@ class ArrayFireTensor : public TensorAdapterBase { * @param[in] array construct a tensor from an ArrayFire array rvalue * reference. */ - explicit ArrayFireTensor(af::array&& array, const unsigned numDims); + explicit ArrayFireTensor(af::array&& array, size_t numDims); /** * Default initialization - empty ArrayFire array and empty shape. @@ -142,18 +142,18 @@ class ArrayFireTensor : public TensorAdapterBase { * @param[in] memoryLocation the location of the buffer */ ArrayFireTensor( - const Shape& shape, + Shape const& shape, fl::dtype type, - const void* ptr, + void const* ptr, Location memoryLocation ); ArrayFireTensor( - const Dim nRows, - const Dim nCols, - const Tensor& values, - const Tensor& rowIdx, - const Tensor& colIdx, + Dim nRows, + Dim nCols, + Tensor const& values, + Tensor const& rowIdx, + Tensor const& colIdx, StorageType storageType ); @@ -163,7 +163,7 @@ class ArrayFireTensor : public TensorAdapterBase { * Throws if this tensor represents an array_proxy, since it precludes * promotion to an array. */ - const af::array& getHandle() const; + af::array const& getHandle() const; /** * Gets an ArrayFire Array from this impl. 
If the underlying handle is an @@ -173,14 +173,14 @@ class ArrayFireTensor : public TensorAdapterBase { af::array& getHandle(); ~ArrayFireTensor() override = default; - unsigned numDims() const; + size_t numDims() const; // Used with the fl::Tensor copy constructor std::unique_ptr clone() const override; TensorBackendType backendType() const override; TensorBackend& backend() const override; Tensor copy() override; Tensor shallowCopy() override; - const Shape& shape() override; + Shape const& shape() override; dtype type() override; bool isSparse() override; af::dtype afHandleType(); // for internal use only @@ -192,11 +192,11 @@ class ArrayFireTensor : public TensorAdapterBase { bool isLocked() override; bool isContiguous() override; Shape strides() override; - const Stream& stream() const override; - Tensor astype(const dtype type) override; - Tensor index(const std::vector& indices) override; + Stream const& stream() const override; + Tensor asType(dtype type) override; + Tensor index(std::vector const& indices) override; Tensor flatten() const override; - Tensor flat(const Index& idx) const override; + Tensor flat(Index const& idx) const override; Tensor asContiguousTensor() override; void setContext(void* context) override; // noop void* getContext() override; // noop @@ -222,7 +222,7 @@ class ArrayFireTensor : public TensorAdapterBase { * @param[in] operand the tensor operand * @param[in] newNumDims the number of dims of the resulting tensor */ - af::array adjustInPlaceOperandDims(const Tensor& operand); + af::array adjustInPlaceOperandDims(Tensor const& operand); #define ASSIGN_OP(OP) \ ASSIGN_OP_TYPE(OP, Tensor); \ @@ -256,7 +256,7 @@ class ArrayFireTensor : public TensorAdapterBase { * @param[in] tensor the input tensor * @return the array underying the Tensor */ -const af::array& toArray(const Tensor& tensor); +af::array const& toArray(Tensor const& tensor); af::array& toArray(Tensor& tensor); } // namespace fl diff --git 
a/flashlight/fl/tensor/backend/af/ArrayFireUnaryOps.cpp b/flashlight/fl/tensor/backend/af/ArrayFireUnaryOps.cpp index c12e1a3..6ff7c8c 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireUnaryOps.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireUnaryOps.cpp @@ -14,67 +14,67 @@ namespace fl { -Tensor ArrayFireBackend::exp(const Tensor& tensor) { +Tensor ArrayFireBackend::exp(Tensor const& tensor) { return toTensor(af::exp(toArray(tensor)), tensor.ndim()); } -Tensor ArrayFireBackend::log(const Tensor& tensor) { +Tensor ArrayFireBackend::log(Tensor const& tensor) { return toTensor(af::log(toArray(tensor)), tensor.ndim()); } -Tensor ArrayFireBackend::negative(const Tensor& tensor) { +Tensor ArrayFireBackend::negative(Tensor const& tensor) { return toTensor(-toArray(tensor), tensor.ndim()); } -Tensor ArrayFireBackend::logicalNot(const Tensor& tensor) { +Tensor ArrayFireBackend::logicalNot(Tensor const& tensor) { return toTensor(!toArray(tensor), tensor.ndim()); } -Tensor ArrayFireBackend::log1p(const Tensor& tensor) { +Tensor ArrayFireBackend::log1p(Tensor const& tensor) { return toTensor(af::log1p(toArray(tensor)), tensor.ndim()); } -Tensor ArrayFireBackend::sin(const Tensor& tensor) { +Tensor ArrayFireBackend::sin(Tensor const& tensor) { return toTensor(af::sin(toArray(tensor)), tensor.ndim()); } -Tensor ArrayFireBackend::cos(const Tensor& tensor) { +Tensor ArrayFireBackend::cos(Tensor const& tensor) { return toTensor(af::cos(toArray(tensor)), tensor.ndim()); } -Tensor ArrayFireBackend::sqrt(const Tensor& tensor) { +Tensor ArrayFireBackend::sqrt(Tensor const& tensor) { return toTensor(af::sqrt(toArray(tensor)), tensor.ndim()); } -Tensor ArrayFireBackend::tanh(const Tensor& tensor) { +Tensor ArrayFireBackend::tanh(Tensor const& tensor) { return toTensor(af::tanh(toArray(tensor)), tensor.ndim()); } -Tensor ArrayFireBackend::floor(const Tensor& tensor) { +Tensor ArrayFireBackend::floor(Tensor const& tensor) { return toTensor(af::floor(toArray(tensor)), 
tensor.ndim()); } -Tensor ArrayFireBackend::ceil(const Tensor& tensor) { +Tensor ArrayFireBackend::ceil(Tensor const& tensor) { return toTensor(af::ceil(toArray(tensor)), tensor.ndim()); } -Tensor ArrayFireBackend::rint(const Tensor& tensor) { +Tensor ArrayFireBackend::rint(Tensor const& tensor) { return toTensor(af::round(toArray(tensor)), tensor.ndim()); } -Tensor ArrayFireBackend::absolute(const Tensor& tensor) { +Tensor ArrayFireBackend::absolute(Tensor const& tensor) { return toTensor(af::abs(toArray(tensor)), tensor.ndim()); } -Tensor ArrayFireBackend::sigmoid(const Tensor& tensor) { +Tensor ArrayFireBackend::sigmoid(Tensor const& tensor) { return toTensor(af::sigmoid(toArray(tensor)), tensor.ndim()); } -Tensor ArrayFireBackend::erf(const Tensor& tensor) { +Tensor ArrayFireBackend::erf(Tensor const& tensor) { return toTensor(af::erf(toArray(tensor)), tensor.ndim()); } -Tensor ArrayFireBackend::flip(const Tensor& tensor, const unsigned dim) { +Tensor ArrayFireBackend::flip(Tensor const& tensor, unsigned const dim) { return toTensor( af::flip(toArray(tensor), dim), tensor.ndim() @@ -82,9 +82,9 @@ Tensor ArrayFireBackend::flip(const Tensor& tensor, const unsigned dim) { } Tensor ArrayFireBackend::clip( - const Tensor& tensor, - const Tensor& low, - const Tensor& high + Tensor const& tensor, + Tensor const& low, + Tensor const& high ) { return toTensor( af::clamp(toArray(tensor), toArray(low), toArray(high)), @@ -93,9 +93,9 @@ Tensor ArrayFireBackend::clip( } Tensor ArrayFireBackend::roll( - const Tensor& tensor, - const int shift, - const unsigned axis + Tensor const& tensor, + int const shift, + unsigned const axis ) { if(axis > AF_MAX_DIMS) throw std::invalid_argument( @@ -109,28 +109,28 @@ Tensor ArrayFireBackend::roll( ); } -Tensor ArrayFireBackend::isnan(const Tensor& tensor) { +Tensor ArrayFireBackend::isnan(Tensor const& tensor) { return toTensor(af::isNaN(toArray(tensor)), tensor.ndim()); } -Tensor ArrayFireBackend::isinf(const Tensor& tensor) { +Tensor 
ArrayFireBackend::isinf(Tensor const& tensor) { return toTensor(af::isInf(toArray(tensor)), tensor.ndim()); } -Tensor ArrayFireBackend::sign(const Tensor& tensor) { +Tensor ArrayFireBackend::sign(Tensor const& tensor) { auto wSigned = 1 - 2 * af::sign(toArray(tensor)); wSigned(toArray(tensor) == 0) = 0; return toTensor(std::move(wSigned), tensor.ndim()); } -Tensor ArrayFireBackend::tril(const Tensor& tensor) { +Tensor ArrayFireBackend::tril(Tensor const& tensor) { return toTensor( af::lower(toArray(tensor), /* is_unit_diag = */ false), tensor.ndim() ); } -Tensor ArrayFireBackend::triu(const Tensor& tensor) { +Tensor ArrayFireBackend::triu(Tensor const& tensor) { return toTensor( af::upper(toArray(tensor), /* is_unit_diag = */ false), tensor.ndim() diff --git a/flashlight/fl/tensor/backend/af/CMakeLists.txt b/flashlight/fl/tensor/backend/af/CMakeLists.txt index da4fbdb..9d0c826 100644 --- a/flashlight/fl/tensor/backend/af/CMakeLists.txt +++ b/flashlight/fl/tensor/backend/af/CMakeLists.txt @@ -1,11 +1,9 @@ cmake_minimum_required(VERSION 3.16) # ----------------------------- ArrayFire ----------------------------- -find_package(ArrayFire 3.7.3 REQUIRED) -if (ArrayFire_FOUND AND ArrayFire_VERSION VERSION_LESS 3.7.1) - message(FATAL_ERROR "ArrayFire versions < 3.8.1 are no longer supported " - "with flashlight. 
To build flashlight with a version of ArrayFire " - "< 3.7.1, use commit <= 5518d91b7f4fd5b400cbc802cfbecc0df57836bd.") +find_package(ArrayFire 3.10.0 REQUIRED) +if (ArrayFire_FOUND AND ArrayFire_VERSION VERSION_LESS 3.10.0) + message(FATAL_ERROR "ArrayFire < 3.10 is not supported") endif() if (ArrayFire_FOUND) diff --git a/flashlight/fl/tensor/backend/af/Utils.cpp b/flashlight/fl/tensor/backend/af/Utils.cpp index 8b11f58..4353a66 100644 --- a/flashlight/fl/tensor/backend/af/Utils.cpp +++ b/flashlight/fl/tensor/backend/af/Utils.cpp @@ -16,53 +16,52 @@ namespace fl::detail { af::dtype flToAfType(fl::dtype type) { - static const std::unordered_map - kFlashlightTypeToArrayFire = { - {fl::dtype::f16, af::dtype::f16}, - {fl::dtype::f32, af::dtype::f32}, - {fl::dtype::f64, af::dtype::f64}, - {fl::dtype::b8, af::dtype::b8}, - {fl::dtype::s16, af::dtype::s16}, - {fl::dtype::s32, af::dtype::s32}, - {fl::dtype::s64, af::dtype::s64}, - {fl::dtype::u8, af::dtype::u8}, - {fl::dtype::u16, af::dtype::u16}, - {fl::dtype::u32, af::dtype::u32}, - { - fl::dtype::u64, af::dtype::u64 - } - }; + static std::unordered_map const + kFlashlightTypeToArrayFire = { + {fl::dtype::f16, af::dtype::f16}, + {fl::dtype::f32, af::dtype::f32}, + {fl::dtype::f64, af::dtype::f64}, + {fl::dtype::b8, af::dtype::b8}, + {fl::dtype::s16, af::dtype::s16}, + {fl::dtype::s32, af::dtype::s32}, + {fl::dtype::s64, af::dtype::s64}, + {fl::dtype::u8, af::dtype::u8}, + {fl::dtype::u16, af::dtype::u16}, + {fl::dtype::u32, af::dtype::u32}, + { + fl::dtype::u64, + af::dtype::u64 + } + }; return kFlashlightTypeToArrayFire.at(type); } fl::dtype afToFlType(af::dtype type) { - static const std::unordered_map - kArrayFireTypeToFlashlight = { - {af::dtype::f16, fl::dtype::f16}, - {af::dtype::f32, fl::dtype::f32}, - {af::dtype::f64, fl::dtype::f64}, - {af::dtype::b8, fl::dtype::b8}, - {af::dtype::s16, fl::dtype::s16}, - {af::dtype::s32, fl::dtype::s32}, - {af::dtype::s64, fl::dtype::s64}, - {af::dtype::u8, fl::dtype::u8}, - 
{af::dtype::u16, fl::dtype::u16}, - {af::dtype::u32, fl::dtype::u32}, - { - af::dtype::u64, fl::dtype::u64 - } - }; + static std::unordered_map const + kArrayFireTypeToFlashlight = { + {af::dtype::f16, fl::dtype::f16}, + {af::dtype::f32, fl::dtype::f32}, + {af::dtype::f64, fl::dtype::f64}, + {af::dtype::b8, fl::dtype::b8}, + {af::dtype::s16, fl::dtype::s16}, + {af::dtype::s32, fl::dtype::s32}, + {af::dtype::s64, fl::dtype::s64}, + {af::dtype::u8, fl::dtype::u8}, + {af::dtype::u16, fl::dtype::u16}, + {af::dtype::u32, fl::dtype::u32}, + { + af::dtype::u64, + fl::dtype::u64 + } + }; return kArrayFireTypeToFlashlight.at(type); } af_mat_prop flToAfMatrixProperty(MatrixProperty property) { switch(property) { - case MatrixProperty::None: - return AF_MAT_NONE; - case MatrixProperty::Transpose: - return AF_MAT_TRANS; - default: - throw std::invalid_argument( + case MatrixProperty::None: return AF_MAT_NONE; + case MatrixProperty::Transpose: return AF_MAT_TRANS; + default: throw std::invalid_argument( "flToAfMatrixProperty: invalid property specified" ); } @@ -70,16 +69,11 @@ af_mat_prop flToAfMatrixProperty(MatrixProperty property) { af_storage flToAfStorageType(StorageType storageType) { switch(storageType) { - case StorageType::Dense: - return AF_STORAGE_DENSE; - case StorageType::CSR: - return AF_STORAGE_CSR; - case StorageType::CSC: - return AF_STORAGE_CSC; - case StorageType::COO: - return AF_STORAGE_COO; - default: - throw std::invalid_argument( + case StorageType::Dense: return AF_STORAGE_DENSE; + case StorageType::CSR: return AF_STORAGE_CSR; + case StorageType::CSC: return AF_STORAGE_CSC; + case StorageType::COO: return AF_STORAGE_COO; + default: throw std::invalid_argument( "flToAfStorageType: Flashlight storage type " "doesn't have an ArrayFire analog" ); @@ -88,18 +82,15 @@ af_storage flToAfStorageType(StorageType storageType) { af_topk_function flToAfTopKSortMode(SortMode sortMode) { switch(sortMode) { - case SortMode::Descending: - return AF_TOPK_MAX; - case 
SortMode::Ascending: - return AF_TOPK_MIN; - default: - throw std::invalid_argument( + case SortMode::Descending: return AF_TOPK_MAX; + case SortMode::Ascending: return AF_TOPK_MIN; + default: throw std::invalid_argument( "flToAfTopKSortMode: sort mode with no ArrayFire analog specified" ); } } -af::dim4 flToAfDims(const Shape& shape) { +af::dim4 flToAfDims(Shape const& shape) { if(shape.ndim() > 4) throw std::invalid_argument( "flToAfDims: ArrayFire shapes can't be more than 4 dimensions" @@ -111,7 +102,7 @@ af::dim4 flToAfDims(const Shape& shape) { return out; } -void afToFlDims(const af::dim4& d, const unsigned numDims, Shape& s) { +void afToFlDims(af::dim4 const& d, unsigned const numDims, Shape& s) { if(numDims > AF_MAX_DIMS) throw std::invalid_argument("afToFlDims - numDims > AF_MAX_DIMS"); @@ -136,16 +127,16 @@ void afToFlDims(const af::dim4& d, const unsigned numDims, Shape& s) { s[i] = d[i]; } -Shape afToFlDims(const af::dim4& d, const unsigned numDims) { +Shape afToFlDims(af::dim4 const& d, unsigned const numDims) { Shape s; afToFlDims(d, numDims, s); return s; } -af::seq flRangeToAfSeq(const fl::range& range) { - const int start = range.start(); - const auto& optEnd = range.end(); - const int end = optEnd.has_value() ? optEnd.value() - 1 : af::end; +af::seq flRangeToAfSeq(fl::range const& range) { + int const start = range.start(); + auto const& optEnd = range.end(); + int const end = optEnd.has_value() ? optEnd.value() - 1 : af::end; // There could be have other empty sequence representations, e.g., (0, -1) // for axis with 1 element. In those cases, AF will throw internally -- // we can't throw here because these cases axis-size dependent. 
@@ -156,24 +147,19 @@ af::seq flRangeToAfSeq(const fl::range& range) { return af::seq(start, end, range.stride()); } -af::index flToAfIndex(const fl::Index& idx) { +af::index flToAfIndex(fl::Index const& idx) { switch(idx.type()) { - case IndexType::Tensor: - return af::index(toArray(idx.get())); - case IndexType::Span: - return af::index(af::span); - case IndexType::Range: - return af::index(flRangeToAfSeq(idx.get())); - case IndexType::Literal: - return af::index(idx.get()); - default: - throw std::invalid_argument( + case IndexType::Tensor: return af::index(toArray(idx.get())); + case IndexType::Span: return af::index(af::span); + case IndexType::Range: return af::index(flRangeToAfSeq(idx.get())); + case IndexType::Literal: return af::index(idx.get()); + default: throw std::invalid_argument( "flToAfIndex: fl::Index has unknown or invalid type." ); } } -af::dim4 condenseDims(const af::dim4& dims) { +af::dim4 condenseDims(af::dim4 const& dims) { if(dims.elements() == 0) return af::dim4(0); @@ -190,10 +176,10 @@ af::dim4 condenseDims(const af::dim4& dims) { } af::array condenseIndices( - const af::array& arr, - const bool keepDims /* = false */, - const std::optional>& indexTypes /* = {} */, - const bool isFlat /* = false */ + af::array const& arr, + bool const keepDims /* = false */, + std::optional> const& indexTypes /* = {} */, + bool const isFlat /* = false */ ) { // Fast path - return the Array as is if keepDims - don't consolidate if(keepDims) @@ -202,8 +188,8 @@ af::array condenseIndices( if(arr.elements() == 0) return arr; - const af::dim4& dims = arr.dims(); - af::dim4 newDims(1, 1, 1, 1); + af::dim4 const& dims = arr.dims(); + af::dim4 newDims{1, 1, 1, 1}; unsigned newDimIdx = 0; for(unsigned i = 0; i < AF_MAX_DIMS; ++i) { // If we're doing an index op (indexTypes is non-empty), then only collapse @@ -215,7 +201,8 @@ af::array condenseIndices( ) { newDims[newDimIdx] = 1; newDimIdx++; - } else if(dims[i] != 1) { + } + else if(dims[i] != 1) { // found a 
non-1 dim size - populate newDims. newDims[newDimIdx] = dims[i]; newDimIdx++; @@ -225,81 +212,61 @@ af::array condenseIndices( // Only change dims if condensing is possible if(newDims != arr.dims()) return af::moddims(arr, newDims); - else - return arr; + + + return arr; } af_source flToAfLocation(Location location) { switch(location) { - case Location::Host: - return afHost; - case Location::Device: - return afDevice; - default: - throw std::invalid_argument( + case Location::Host: return afHost; + case Location::Device: return afDevice; + default: throw std::invalid_argument{ "flToAfLocation: no valid ArrayFire location exists " " for given Flashlight location." - ); + }; } } af::array fromFlData( - const Shape& shape, - const void* ptr, + Shape const& shape, + void const* ptr, fl::dtype type, fl::Location memoryLocation ) { - af::dim4 dims = detail::flToAfDims(shape); - af::dtype afType = detail::flToAfType(type); + auto dims = detail::flToAfDims(shape); + auto const afType = detail::flToAfType(type); af_source loc = detail::flToAfLocation(memoryLocation); // No or null buffer if(!ptr) - return af::array(dims, afType); - - using af::dtype; - switch(afType) { - case f32: - return af::array(dims, reinterpret_cast(ptr), loc); - case f64: - return af::array(dims, reinterpret_cast(ptr), loc); - case s32: - return af::array(dims, reinterpret_cast(ptr), loc); - case u32: - return af::array(dims, reinterpret_cast(ptr), loc); - case s64: - return af::array(dims, reinterpret_cast(ptr), loc); - case u64: - return af::array( - dims, - reinterpret_cast(ptr), - loc - ); - case s16: - return af::array(dims, reinterpret_cast(ptr), loc); - case u16: - return af::array(dims, reinterpret_cast(ptr), loc); - case b8: - return af::array(dims, reinterpret_cast(ptr), loc); - case u8: - return af::array(dims, reinterpret_cast(ptr), loc); - default: - throw std::invalid_argument( - "fromFlData: can't construct ArrayFire array from given type." 
- ); - } + return af::array{dims, afType}; + + + using af_fundamental_types = std::tuple< + float, + double, + char, + signed short, + int, + long long, + unsigned char, + unsigned short, + unsigned int, + unsigned long long>; + + return dispatch_dtype( + type, + [&]() { return af::array{dims, static_cast(ptr), loc}; } + ); } af_border_type flToAfPadType(PadType type) { switch(type) { - case PadType::Constant: - return AF_PAD_ZERO; // constant padding --> zero padding in AF - case PadType::Edge: - return AF_PAD_CLAMP_TO_EDGE; - case PadType::Symmetric: - return AF_PAD_SYM; - default: - throw std::invalid_argument( + case PadType::Constant: return AF_PAD_ZERO; // constant padding --> zero padding in AF + case PadType::Edge: return AF_PAD_CLAMP_TO_EDGE; + case PadType::Symmetric: return AF_PAD_SYM; + default: throw std::invalid_argument( "flToAfPadType: Flashlight padding " "type not supported by ArrayFire" ); diff --git a/flashlight/fl/tensor/backend/af/Utils.h b/flashlight/fl/tensor/backend/af/Utils.h index 0b299ba..3c073a0 100644 --- a/flashlight/fl/tensor/backend/af/Utils.h +++ b/flashlight/fl/tensor/backend/af/Utils.h @@ -65,34 +65,34 @@ namespace detail { /** * Convert an fl::Shape into an ArrayFire af::dim4 */ - af::dim4 flToAfDims(const Shape& shape); + af::dim4 flToAfDims(Shape const& shape); /** * Convert an ArrayFire af::dim4 into an fl::Shape */ - Shape afToFlDims(const af::dim4& d, const unsigned numDims); + Shape afToFlDims(af::dim4 const& d, unsigned const numDims); /** * Convert an ArrayFire af::dim4 into an fl::Shape, in-place */ - void afToFlDims(const af::dim4& d, const unsigned numDims, Shape& s); + void afToFlDims(af::dim4 const& d, unsigned const numDims, Shape& s); /** * Convert an fl::range into an af::seq. */ - af::seq flRangeToAfSeq(const fl::range& range); + af::seq flRangeToAfSeq(fl::range const& range); /** * Convert an fl::Index into an af::index. 
*/ - af::index flToAfIndex(const fl::Index& idx); + af::index flToAfIndex(fl::Index const& idx); - std::vector flToAfIndices(const std::vector& flIndices); + std::vector flToAfIndices(std::vector const& flIndices); /** * Strip leading 1 indices from an ArrayFire dim4. */ - af::dim4 condenseDims(const af::dim4& dims); + af::dim4 condenseDims(af::dim4 const& dims); /** * Modify the dimensions (in place via af::moddims) or an Array to have no 1 @@ -104,10 +104,10 @@ namespace detail { * If keepDims is true, this is a noop, and the array is returned as is. */ af::array condenseIndices( - const af::array& arr, - const bool keepDims = false, - const std::optional>& indexTypes = {}, - const bool isFlat = false + af::array const& arr, + bool const keepDims = false, + std::optional> const& indexTypes = {}, + bool const isFlat = false ); /** @@ -119,8 +119,8 @@ namespace detail { * Construct an ArrayFire array from a buffer and Flashlight details. */ af::array fromFlData( - const Shape& shape, - const void* ptr, + Shape const& shape, + void const* ptr, fl::dtype type, fl::Location memoryLocation ); diff --git a/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp b/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp index 2014b77..18b7528 100644 --- a/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp +++ b/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp @@ -33,9 +33,9 @@ namespace { 10485760; // allocations between 1 and 10 MiB may use kLargeBuffer constexpr size_t kRoundLarge = 2097152; // round up large allocs to 2 MiB -// Environment variables names, specifying number of mega bytes as floats. - constexpr const char* kMemRecyclingSize = "FL_MEM_RECYCLING_SIZE_MB"; - constexpr const char* kMemSplitSize = "FL_MEM_SPLIT_SIZE_MB"; + // Environment variables names, specifying number of mega bytes as floats. 
+ constexpr auto kMemRecyclingSize = "FL_MEM_RECYCLING_SIZE_MB"; + constexpr auto kMemSplitSize = "FL_MEM_SPLIT_SIZE_MB"; constexpr double kMB = static_cast(1UL << 20); size_t roundSize(size_t size) { @@ -55,16 +55,16 @@ namespace { } static bool BlockComparator( - const CachingMemoryManager::Block* a, - const CachingMemoryManager::Block* b + CachingMemoryManager::Block const* a, + CachingMemoryManager::Block const* b ) { if(a->size_ != b->size_) return a->size_ < b->size_; - return (uintptr_t) a->ptr_ < (uintptr_t) b->ptr_; + return reinterpret_cast(a->ptr_) < reinterpret_cast(b->ptr_); } std::string formatMemory(size_t bytes) { - const std::vector units = {"B", "KiB", "MiB", "GiB", "TiB"}; + std::vector const units = {"B", "KiB", "MiB", "GiB", "TiB"}; size_t unitId = bytes == 0 ? 0 : std::floor(std::log(bytes) / std::log(1024.0)); unitId = std::min(unitId, units.size() - 1); @@ -73,20 +73,21 @@ namespace { return bytesStr + " " + units[unitId]; } -/** - * Returns number of bytes as represented by the named environment variable. The - * variable is interperested as a float string specifying value in MBs. Returns - * defaultVal on failure to read the variable or parse its value. - */ - size_t getEnvAsBytesFromFloatMb(const char* name, size_t defaultVal) { - const char* env = std::getenv(name); + /** + * Returns number of bytes as represented by the named environment variable. The + * variable is interpreted as a float string specifying value in MBs. Returns
+ */ + size_t getEnvAsBytesFromFloatMb(char const* name, size_t defaultVal) { + char const* env = std::getenv(name); if(env) { try { - const double mb = std::stod(env); + double const mb = std::stod(env); return std::round(mb * kMB); - } catch(std::exception& ex) { + } + catch(std::exception& ex) { std::cerr << "getEnvAsBytesFromFloatMb: Invalid environment " - << "variable value: name=" << name << " value=" << env; + << "variable value: name=" << name << " value=" << env; throw ex; } } @@ -96,8 +97,8 @@ namespace { } // namespace CachingMemoryManager::DeviceMemoryInfo::DeviceMemoryInfo(int id) : deviceId_(id), - largeBlocks_(BlockComparator), - smallBlocks_(BlockComparator) {} + largeBlocks_(BlockComparator), + smallBlocks_(BlockComparator) {} CachingMemoryManager::CachingMemoryManager( int numDevices, @@ -139,19 +140,19 @@ void CachingMemoryManager::removeMemoryManagement(int device) { void* CachingMemoryManager::alloc( bool userLock, - const unsigned ndims, - dim_t* dims, - const unsigned elementSize + unsigned const ndims, + ::dim_t* dims, + unsigned const elementSize ) { auto& memoryInfo = getDeviceMemoryInfo(); - std::lock_guard lock(memoryInfo.mutexAll_); + std::scoped_lock lock(memoryInfo.mutexAll_); size_t size = elementSize; for(unsigned i = 0; i < ndims; ++i) size *= dims[i]; if(size == 0) return nullptr; size = roundSize(size); - const bool isSmallAlloc = (size <= kSmallSize); + bool const isSmallAlloc = (size <= kSmallSize); CachingMemoryManager::Block searchKey(size); CachingMemoryManager::BlockSet& pool = isSmallAlloc ? memoryInfo.smallBlocks_ : memoryInfo.largeBlocks_; @@ -167,7 +168,8 @@ void* CachingMemoryManager::alloc( block = *it; pool.erase(it); memoryInfo.stats_.cachedBytes_ -= block->size_; - } else { + } + else { void* ptr = nullptr; size_t allocSize = getAllocationSize(size); mallocWithRetry(allocSize, &ptr); // could throw @@ -184,7 +186,7 @@ void* CachingMemoryManager::alloc( if( (diff >= (isSmallAlloc ? 
kMinBlockSize : kSmallSize)) && (block->size_ < splitSizeLimit_) // possibly dont split large buffers to - // minimize risk of fragmentation + // minimize risk of fragmentation ) { remaining = block; block = new Block(size, block->ptr_); @@ -210,7 +212,7 @@ size_t CachingMemoryManager::allocated(void* ptr) { if(!ptr) return 0; auto& memoryInfo = getDeviceMemoryInfo(); - std::lock_guard lock(memoryInfo.mutexAll_); + std::scoped_lock _(memoryInfo.mutexAll_); auto it = memoryInfo.allocatedBlocks_.find(ptr); if(it == memoryInfo.allocatedBlocks_.end()) return 0; @@ -221,7 +223,7 @@ void CachingMemoryManager::unlock(void* ptr, bool userUnlock) { if(!ptr) return; auto& memoryInfo = getDeviceMemoryInfo(); - std::lock_guard lock(memoryInfo.mutexAll_); + std::scoped_lock lock(memoryInfo.mutexAll_); auto it = memoryInfo.allocatedBlocks_.find(ptr); if(it == memoryInfo.allocatedBlocks_.end()) { // Probably came from user, just free it @@ -247,9 +249,9 @@ void CachingMemoryManager::freeBlock(CachingMemoryManager::Block* block) { if(block->inUse()) throw std::runtime_error("trying to free a block which is in use"); auto& memoryInfo = getDeviceMemoryInfo(); - std::lock_guard lock(memoryInfo.mutexAll_); + std::scoped_lock lock(memoryInfo.mutexAll_); - const bool isSmallAlloc = (block->size_ <= kSmallSize); + bool const isSmallAlloc = (block->size_ <= kSmallSize); CachingMemoryManager::BlockSet& pool = isSmallAlloc ? 
memoryInfo.smallBlocks_ : memoryInfo.largeBlocks_; tryMergeBlocks(block, block->prev_, pool); @@ -272,7 +274,8 @@ void CachingMemoryManager::tryMergeBlocks( dst->prev_ = src->prev_; if(dst->prev_) dst->prev_->next_ = dst; - } else { + } + else { dst->next_ = src->next_; if(dst->next_) dst->next_->prev_ = dst; @@ -290,24 +293,26 @@ void CachingMemoryManager::mallocWithRetry(size_t size, void** ptr) { try { ++memInfo.stats_.totalNativeMallocs_; *ptr = this->deviceInterface->nativeAlloc(size); - } catch(std::exception&) { + } + catch(std::exception&) { try { signalMemoryCleanup(); ++memInfo.stats_.totalNativeMallocs_; *ptr = this->deviceInterface->nativeAlloc(size); - } catch(std::exception& ex) { + } + catch(std::exception& ex) { // note: af exception inherits from std exception std::cerr << "Failed to allocate memory of size " << formatMemory(size) - << " (Device: " << memInfo.deviceId_ << ", Capacity: " - << formatMemory( - this->deviceInterface->getMaxMemorySize( - memInfo.deviceId_ + << " (Device: " << memInfo.deviceId_ << ", Capacity: " + << formatMemory( + this->deviceInterface->getMaxMemorySize( + memInfo.deviceId_ + ) ) - ) - << ", Allocated: " - << formatMemory(memInfo.stats_.allocatedBytes_) - << ", Cached: " << formatMemory(memInfo.stats_.cachedBytes_) - << ") with error '" << ex.what() << "'" << std::endl; + << ", Allocated: " + << formatMemory(memInfo.stats_.allocatedBytes_) + << ", Cached: " << formatMemory(memInfo.stats_.cachedBytes_) + << ") with error '" << ex.what() << "'" << std::endl; // note: converting here an af exception to std exception prevents to // catch the af error code at the user level. Rethrowing. 
throw; @@ -333,7 +338,8 @@ void CachingMemoryManager::freeBlocks( ++it; blocks.erase(cur); delete block; - } else + } + else ++it; } } @@ -341,7 +347,7 @@ void CachingMemoryManager::freeBlocks( void CachingMemoryManager::signalMemoryCleanup() { // Free all non-split cached blocks on device auto& memoryInfo = getDeviceMemoryInfo(); - std::lock_guard lock(memoryInfo.mutexAll_); + std::scoped_lock lock(memoryInfo.mutexAll_); freeBlocks( memoryInfo.largeBlocks_, @@ -365,32 +371,32 @@ bool CachingMemoryManager::jitTreeExceedsMemoryPressure(size_t /* unused */) { } void CachingMemoryManager::printInfo( - const char* msg, - const int /* unused */, + char const* msg, + int const /* unused */, std::ostream* _ostream ) { std::ostream& ostream = *_ostream; auto& memInfo = getDeviceMemoryInfo(); - std::lock_guard lock(memInfo.mutexAll_); + std::scoped_lock lock(memInfo.mutexAll_); ostream << msg << "\nType: CachingMemoryManager" << std::endl - << "\nDevice: " << memInfo.deviceId_ << ", Capacity: " - << formatMemory( - this->deviceInterface->getMaxMemorySize(memInfo.deviceId_) + << "\nDevice: " << memInfo.deviceId_ << ", Capacity: " + << formatMemory( + this->deviceInterface->getMaxMemorySize(memInfo.deviceId_) ) - << ", Allocated: " << formatMemory(memInfo.stats_.allocatedBytes_) - << ", Cached: " << formatMemory(memInfo.stats_.cachedBytes_) - << std::endl - << "\nTotal native calls: " << memInfo.stats_.totalNativeMallocs_ - << "(mallocs), " << memInfo.stats_.totalNativeFrees_ << "(frees)" - << std::endl; + << ", Allocated: " << formatMemory(memInfo.stats_.allocatedBytes_) + << ", Cached: " << formatMemory(memInfo.stats_.cachedBytes_) + << std::endl + << "\nTotal native calls: " << memInfo.stats_.totalNativeMallocs_ + << "(mallocs), " << memInfo.stats_.totalNativeFrees_ << "(frees)" + << std::endl; } -void CachingMemoryManager::userLock(const void* ptr) { +void CachingMemoryManager::userLock(void const* ptr) { if(!ptr) return; auto& memoryInfo = getDeviceMemoryInfo(); - 
std::lock_guard lock(memoryInfo.mutexAll_); + std::scoped_lock lock(memoryInfo.mutexAll_); auto it = memoryInfo.allocatedBlocks_.find(const_cast(ptr)); if(it == memoryInfo.allocatedBlocks_.end()) { @@ -399,17 +405,18 @@ void CachingMemoryManager::userLock(const void* ptr) { block->managerLock_ = false; block->userLock_ = true; memoryInfo.allocatedBlocks_[block->ptr_] = block; - } else + } + else it->second->userLock_ = true; } -void CachingMemoryManager::userUnlock(const void* ptr) { this->unlock(const_cast(ptr), true); } +void CachingMemoryManager::userUnlock(void const* ptr) { this->unlock(const_cast(ptr), true); } -bool CachingMemoryManager::isUserLocked(const void* ptr) { +bool CachingMemoryManager::isUserLocked(void const* ptr) { if(!ptr) return false; auto& memoryInfo = getDeviceMemoryInfo(); - std::lock_guard lock(memoryInfo.mutexAll_); + std::scoped_lock lock(memoryInfo.mutexAll_); auto it = memoryInfo.allocatedBlocks_.find(const_cast(ptr)); if(it == memoryInfo.allocatedBlocks_.end()) return false; diff --git a/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.h b/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.h index 57ea899..5a34ec7 100644 --- a/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.h +++ b/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.h @@ -39,20 +39,20 @@ class CachingMemoryManager : public MemoryManagerAdapter { void shutdown() override; void* alloc( bool userLock, - const unsigned ndims, - dim_t* dims, - const unsigned elSize + unsigned ndims, + ::dim_t* dims, + unsigned elSize ) override; size_t allocated(void* ptr) override; void unlock(void* ptr, bool userLock) override; void printInfo( - const char* msg, - const int device, + char const* msg, + int device, std::ostream* ostream = & std::cout ) override; - void userLock(const void* ptr) override; - void userUnlock(const void* ptr) override; - bool isUserLocked(const void* ptr) override; + void userLock(void const* ptr) override; + void userUnlock(void 
const* ptr) override; + bool isUserLocked(void const* ptr) override; void signalMemoryCleanup() override; float getMemoryPressure() override; bool jitTreeExceedsMemoryPressure(size_t bytes) override; @@ -89,7 +89,7 @@ class CachingMemoryManager : public MemoryManagerAdapter { next_(nullptr) {} }; - typedef bool (*Comparison)(const Block*, const Block*); + typedef bool (*Comparison)(Block const*, Block const*); typedef std::set BlockSet; // A structure to store allocation stats per device. @@ -131,9 +131,9 @@ class CachingMemoryManager : public MemoryManagerAdapter { protected: std::unordered_map> deviceMemInfos_; - CachingMemoryManager(const CachingMemoryManager& other) = delete; + CachingMemoryManager(CachingMemoryManager const& other) = delete; CachingMemoryManager(CachingMemoryManager&& other) = delete; - CachingMemoryManager& operator=(const CachingMemoryManager& other) = delete; + CachingMemoryManager& operator=(CachingMemoryManager const& other) = delete; CachingMemoryManager& operator=(CachingMemoryManager&& other) = delete; // Returns the memory info of the caching allocator for the given device. 
diff --git a/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp b/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp index 3a7497c..9bc1fe5 100644 --- a/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp +++ b/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp @@ -41,7 +41,7 @@ void DefaultMemoryManager::cleanDeviceMemoryManager(int device) { size_t bytesFreed = 0; MemoryInfo& current = memory[device]; { - std::lock_guard lock(this->memoryMutex); + std::scoped_lock lock(this->memoryMutex); // Return if all buffers are locked if(current.totalBuffers == current.lockBuffers) return; @@ -65,7 +65,7 @@ void DefaultMemoryManager::cleanDeviceMemoryManager(int device) { std::stringstream ss; ss << "GC: Clearing " << freePtrs.size() << " buffers |" - << std::to_string(bytesFreed) << " bytes"; + << std::to_string(bytesFreed) << " bytes"; this->log(ss.str()); // Free memory outside of the lock @@ -85,13 +85,13 @@ DefaultMemoryManager::DefaultMemoryManager( memory(numDevices) { // Check for environment variables // Debug mode - if(const char* c = std::getenv("AF_MEM_DEBUG")) + if(char const* c = std::getenv("AF_MEM_DEBUG")) this->debugMode = (std::string(c) != "0"); if(this->debugMode) memStepSize = 1; // Max Buffer count - if(const char* c = std::getenv("AF_MAX_BUFFERS")) + if(char const* c = std::getenv("AF_MAX_BUFFERS")) this->maxBuffers = std::max(1, std::stoi(std::string(c))); } @@ -127,16 +127,16 @@ void DefaultMemoryManager::setMaxMemorySize() { // memsize returned 0, then use 1GB size_t memsize = this->deviceInterface->getMaxMemorySize(n); memory[n].maxBytes = memsize == 0 - ? ONE_GB - : std::max(memsize * 0.75, static_cast(memsize - ONE_GB)); + ? 
ONE_GB + : std::max(memsize * 0.75, static_cast(memsize - ONE_GB)); } } void* DefaultMemoryManager::alloc( bool userLock, - const unsigned ndims, - dim_t* dims, - const unsigned elementSize + unsigned const ndims, + ::dim_t* dims, + unsigned const elementSize ) { size_t bytes = elementSize; for(unsigned i = 0; i < ndims; ++i) @@ -160,7 +160,7 @@ void* DefaultMemoryManager::alloc( ) this->signalMemoryCleanup(); - std::lock_guard lock(this->memoryMutex); + std::scoped_lock lock(this->memoryMutex); free_iter iter = current.freeMap.find(allocBytes); if(iter != current.freeMap.end() && !iter->second.empty()) { @@ -176,9 +176,8 @@ void* DefaultMemoryManager::alloc( // Only comes here if buffer size not found or in debug mode if(ptr == nullptr) { // Perform garbage collection if memory can not be allocated - try { - ptr = this->deviceInterface->nativeAlloc(allocBytes); - } catch(std::exception&) { + try { ptr = this->deviceInterface->nativeAlloc(allocBytes); } + catch(std::exception&) { // FIXME: assume that the exception is due to out of memory, and don't // continue propagating it // If out of memory, run garbage collect and try again @@ -188,7 +187,7 @@ void* DefaultMemoryManager::alloc( this->signalMemoryCleanup(); ptr = this->deviceInterface->nativeAlloc(allocBytes); } - std::lock_guard lock(this->memoryMutex); + std::scoped_lock lock(this->memoryMutex); // Increment these two only when it succeeds to come here. current.totalBytes += allocBytes; current.totalBuffers += 1; @@ -204,7 +203,7 @@ size_t DefaultMemoryManager::allocated(void* ptr) { if(!ptr) return 0; MemoryInfo& current = this->getCurrentMemoryInfo(); - locked_iter iter = current.lockedMap.find((void*) ptr); + auto iter = current.lockedMap.find((void*) ptr); if(iter == current.lockedMap.end()) return 0; return (iter->second).bytes; @@ -217,12 +216,14 @@ void DefaultMemoryManager::unlock(void* ptr, bool userUnlock) { // Frees the pointer outside the lock. 
uptr_t freedPtr( - nullptr, [this](void* p) { this->deviceInterface->nativeFree(p); }); + nullptr, + [this](void* p) { this->deviceInterface->nativeFree(p); } + ); { - std::lock_guard lock(this->memoryMutex); + std::scoped_lock lock(this->memoryMutex); MemoryInfo& current = this->getCurrentMemoryInfo(); - locked_iter iter = current.lockedMap.find((void*) ptr); + auto iter = current.lockedMap.find((void*) ptr); // Pointer not found in locked map if(iter == current.lockedMap.end()) { @@ -251,7 +252,8 @@ void DefaultMemoryManager::unlock(void* ptr, bool userUnlock) { current.totalBuffers--; current.totalBytes -= iter->second.bytes; } - } else + } + else current.freeMap.at(bytes).emplace_back(ptr); current.lockedMap.erase(iter); } @@ -262,8 +264,8 @@ void DefaultMemoryManager::signalMemoryCleanup() { } float DefaultMemoryManager::getMemoryPressure() { - std::lock_guard lock(this->memoryMutex); - MemoryInfo& current = this->getCurrentMemoryInfo(); + std::scoped_lock lock(this->memoryMutex); + MemoryInfo const& current = this->getCurrentMemoryInfo(); if( current.lockBytes > current.maxBytes || current.lockBuffers > maxBuffers @@ -274,34 +276,34 @@ float DefaultMemoryManager::getMemoryPressure() { } bool DefaultMemoryManager::jitTreeExceedsMemoryPressure(size_t bytes) { - std::lock_guard lock(this->memoryMutex); + std::scoped_lock lock(this->memoryMutex); MemoryInfo& current = this->getCurrentMemoryInfo(); return 2 * bytes > current.lockBytes; } void DefaultMemoryManager::printInfo( - const char* msg, - const int /* device */, + char const* msg, + int const /* device */, std::ostream* _ostream ) { std::ostream& ostream = *_ostream; - const MemoryInfo& current = this->getCurrentMemoryInfo(); + MemoryInfo const& current = this->getCurrentMemoryInfo(); ostream << msg << std::endl - << "---------------------------------------------------------\n" - << "| POINTER | SIZE | AF LOCK | USER LOCK |\n" - << "---------------------------------------------------------\n"; + << 
"---------------------------------------------------------\n" + << "| POINTER | SIZE | AF LOCK | USER LOCK |\n" + << "---------------------------------------------------------\n"; - std::lock_guard lock(this->memoryMutex); + std::scoped_lock lock(this->memoryMutex); for(auto& kv : current.lockedMap) { - const char* statusMngr = "Yes"; - const char* statusUser = "Unknown"; + char const* statusMngr = "Yes"; + char const* statusUser = "Unknown"; if(kv.second.userLock) statusUser = "Yes"; else statusUser = " No"; - const char* unit = "KB"; + char const* unit = "KB"; double size = static_cast(kv.second.bytes) / 1024; if(size >= 1024) { size = size / 1024; @@ -309,14 +311,14 @@ void DefaultMemoryManager::printInfo( } ostream << "| " << kv.first << " | " << size << " " << unit << " | " - << statusMngr << " | " << statusUser << " |\n"; + << statusMngr << " | " << statusUser << " |\n"; } for(auto& kv : current.freeMap) { - const char* statusMngr = "No"; - const char* statusUser = "No"; + char const* statusMngr = "No"; + char const* statusUser = "No"; - const char* unit = "KB"; + char const* unit = "KB"; double size = static_cast(kv.first) / 1024; if(size >= 1024) { size = size / 1024; @@ -325,16 +327,16 @@ void DefaultMemoryManager::printInfo( for(auto& ptr : kv.second) ostream << "| " << ptr << " | " << size << " " << unit << " | " - << statusMngr << " | " << statusUser << " |\n"; + << statusMngr << " | " << statusUser << " |\n"; } ostream << "---------------------------------------------------------\n"; } -void DefaultMemoryManager::userLock(const void* ptr) { +void DefaultMemoryManager::userLock(void const* ptr) { MemoryInfo& current = this->getCurrentMemoryInfo(); - std::lock_guard lock(this->memoryMutex); + std::scoped_lock lock(this->memoryMutex); locked_iter iter = current.lockedMap.find(const_cast(ptr)); if(iter != current.lockedMap.end()) @@ -346,11 +348,11 @@ void DefaultMemoryManager::userLock(const void* ptr) { } } -void DefaultMemoryManager::userUnlock(const 
void* ptr) { this->unlock(const_cast(ptr), true); } +void DefaultMemoryManager::userUnlock(void const* ptr) { this->unlock(const_cast(ptr), true); } -bool DefaultMemoryManager::isUserLocked(const void* ptr) { +bool DefaultMemoryManager::isUserLocked(void const* ptr) { MemoryInfo& current = this->getCurrentMemoryInfo(); - std::lock_guard lock(this->memoryMutex); + std::scoped_lock lock(this->memoryMutex); locked_iter iter = current.lockedMap.find(const_cast(ptr)); if(iter != current.lockedMap.end()) return iter->second.userLock; @@ -359,26 +361,26 @@ bool DefaultMemoryManager::isUserLocked(const void* ptr) { } size_t DefaultMemoryManager::getMemStepSize() { - std::lock_guard lock(this->memoryMutex); + std::scoped_lock lock(this->memoryMutex); return this->memStepSize; } -void DefaultMemoryManager::setMemStepSize(size_t new_step_size) { - std::lock_guard lock(this->memoryMutex); - this->memStepSize = new_step_size; +void DefaultMemoryManager::setMemStepSize(size_t newStepSize) { + std::scoped_lock lock(this->memoryMutex); + this->memStepSize = newStepSize; } size_t DefaultMemoryManager::getMaxBytes() { - std::lock_guard lock(this->memoryMutex); + std::scoped_lock lock(this->memoryMutex); return this->getCurrentMemoryInfo().maxBytes; } unsigned DefaultMemoryManager::getMaxBuffers() { return this->maxBuffers; } bool DefaultMemoryManager::checkMemoryLimit() { - const MemoryInfo& current = this->getCurrentMemoryInfo(); + MemoryInfo const& current = this->getCurrentMemoryInfo(); return current.lockBytes >= current.maxBytes - || current.totalBuffers >= this->maxBuffers; + || current.totalBuffers >= this->maxBuffers; } } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.h b/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.h index 7fd2134..6ef3bc4 100644 --- a/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.h +++ b/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.h @@ -88,20 +88,20 @@ class DefaultMemoryManager : 
public MemoryManagerAdapter { void shutdown() override; void* alloc( bool userLock, - const unsigned ndims, - dim_t* dims, - const unsigned elSize + unsigned ndims, + ::dim_t* dims, + unsigned elSize ) override; size_t allocated(void* ptr) override; void unlock(void* ptr, bool userLock) override; void printInfo( - const char* msg, - const int device, + char const* msg, + int device, std::ostream* ostream = & std::cout ) override; - void userLock(const void* ptr) override; - void userUnlock(const void* ptr) override; - bool isUserLocked(const void* ptr) override; + void userLock(void const* ptr) override; + void userUnlock(void const* ptr) override; + bool isUserLocked(void const* ptr) override; void signalMemoryCleanup() override; float getMemoryPressure() override; bool jitTreeExceedsMemoryPressure(size_t bytes) override; @@ -110,16 +110,16 @@ class DefaultMemoryManager : public MemoryManagerAdapter { // Implementation-specific functions void setMaxMemorySize(); size_t getMemStepSize() override; - void setMemStepSize(size_t size) override; + void setMemStepSize(size_t newStepSize) override; size_t getMaxBytes(); unsigned getMaxBuffers(); bool checkMemoryLimit(); protected: - DefaultMemoryManager(const DefaultMemoryManager& other) = delete; - DefaultMemoryManager(const DefaultMemoryManager&& other) = delete; - DefaultMemoryManager& operator=(const DefaultMemoryManager& other) = delete; - DefaultMemoryManager& operator=(const DefaultMemoryManager&& other) = delete; + DefaultMemoryManager(DefaultMemoryManager const& other) = delete; + DefaultMemoryManager(DefaultMemoryManager const&& other) = delete; + DefaultMemoryManager& operator=(DefaultMemoryManager const& other) = delete; + DefaultMemoryManager& operator=(DefaultMemoryManager const&& other) = delete; std::mutex memoryMutex; // backend-specific diff --git a/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.h b/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.h index aa31e09..2e3ce5c 100644 --- 
a/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.h +++ b/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.h @@ -23,7 +23,7 @@ namespace fl { namespace { - const size_t kDefaultLogFlushInterval = 50; + size_t const kDefaultLogFlushInterval = 50; } // namespace @@ -74,21 +74,21 @@ class MemoryManagerAdapter { virtual void shutdown() = 0; virtual void* alloc( bool userLock, - const unsigned ndims, - dim_t* dims, - const unsigned elSize + unsigned ndims, + ::dim_t* dims, + unsigned elSize ) = 0; virtual size_t allocated(void* ptr) = 0; virtual void unlock(void* ptr, bool userLock) = 0; virtual void signalMemoryCleanup() = 0; virtual void printInfo( - const char* msg, - const int device, - std::ostream* ostream = & std::cout + char const* msg, + int device, + std::ostream* ostream = &std::cout ) = 0; - virtual void userLock(const void* ptr) = 0; - virtual void userUnlock(const void* ptr) = 0; - virtual bool isUserLocked(const void* ptr) = 0; + virtual void userLock(void const* ptr) = 0; + virtual void userUnlock(void const* ptr) = 0; + virtual bool isUserLocked(void const* ptr) = 0; virtual float getMemoryPressure() = 0; virtual bool jitTreeExceedsMemoryPressure(size_t bytes) = 0; virtual void addMemoryManagement(int device) = 0; @@ -148,7 +148,7 @@ class MemoryManagerAdapter { af_memory_manager getHandle() const; // Native and device memory management functions - const std::shared_ptr deviceInterface; + std::shared_ptr const deviceInterface; protected: // AF memory manager entity containing relevant function pointers diff --git a/flashlight/fl/tensor/backend/af/mem/MemoryManagerInstaller.cpp b/flashlight/fl/tensor/backend/af/mem/MemoryManagerInstaller.cpp index b34aa55..0279566 100644 --- a/flashlight/fl/tensor/backend/af/mem/MemoryManagerInstaller.cpp +++ b/flashlight/fl/tensor/backend/af/mem/MemoryManagerInstaller.cpp @@ -18,6 +18,8 @@ namespace fl { +using af_dim_t = ::dim_t; + // Statics from MemoryManagerInstaller std::shared_ptr 
MemoryManagerInstaller::currentlyInstalledMemoryManager_; @@ -27,7 +29,7 @@ MemoryManagerAdapter* MemoryManagerInstaller::getImpl( ) { void* ptr; AF_CHECK(af_memory_manager_get_payload(manager, &ptr)); - return (MemoryManagerAdapter*) ptr; + return static_cast(ptr); } MemoryManagerInstaller::MemoryManagerInstaller( @@ -48,70 +50,75 @@ MemoryManagerInstaller::MemoryManagerInstaller( // Set appropriate function pointers for each class method auto initializeFn = [](af_memory_manager manager) { - MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); - m->log("initialize"); - m->initialize(); - return AF_SUCCESS; - }; + MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); + m->log("initialize"); + m->initialize(); + return AF_SUCCESS; + }; AF_CHECK(af_memory_manager_set_initialize_fn(itf, initializeFn)); auto shutdownFn = [](af_memory_manager manager) { - MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); - m->log("shutdown"); - m->shutdown(); - return AF_SUCCESS; - }; + MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); + m->log("shutdown"); + m->shutdown(); + return AF_SUCCESS; + }; AF_CHECK(af_memory_manager_set_shutdown_fn(itf, shutdownFn)); // ArrayFire expects the memory managers alloc fn to return an af_err, not to // throw, if a problem with allocation occurred - auto allocFn = [](af_memory_manager manager, + auto allocFn = []( + af_memory_manager manager, void** ptr, - /* bool */ int userLock, - const unsigned ndims, - dim_t* dims, - const unsigned elSize) { - MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); - try { - *ptr = m->alloc(userLock, ndims, dims, elSize); - } catch(af::exception& ex) { - m->log( - "allocFn: alloc failed with af exception " - + std::to_string(ex.err()) - ); - return ex.err(); // AF_ERR_NO_MEM, ... - } catch(...) 
{ - m->log("allocFn: alloc failed with unspecified exception"); - return af_err(AF_ERR_UNKNOWN); - } - // Log + /* bool */ + int userLock, + unsigned const ndims, + af_dim_t* dims, + unsigned const elSize + ) { + MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); + try { *ptr = m->alloc(userLock, ndims, dims, elSize); } + catch(af::exception& ex) { m->log( - "alloc", - /* size */ dims[0], // HACK: dims[0] until af::memAlloc is size-aware - userLock, - (std::uintptr_t) *ptr + "allocFn: alloc failed with af exception " + + std::to_string(ex.err()) ); - return AF_SUCCESS; - }; + return ex.err(); // AF_ERR_NO_MEM, ... + } + catch(...) { + m->log("allocFn: alloc failed with unspecified exception"); + return static_cast(AF_ERR_UNKNOWN); + } + // Log + m->log( + "alloc", + /* size */ + dims[0], + // HACK: dims[0] until af::memAlloc is size-aware + userLock, + reinterpret_cast(*ptr) + ); + return AF_SUCCESS; + }; AF_CHECK(af_memory_manager_set_alloc_fn(itf, allocFn)); auto allocatedFn = [](af_memory_manager manager, size_t* size, void* ptr) { - MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); - m->log("allocated", (std::uintptr_t) ptr); - *size = m->allocated(ptr); - return AF_SUCCESS; - }; + MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); + m->log("allocated", reinterpret_cast(ptr)); + *size = m->allocated(ptr); + return AF_SUCCESS; + }; AF_CHECK(af_memory_manager_set_allocated_fn(itf, allocatedFn)); auto unlockFn = [](af_memory_manager manager, void* ptr, int userLock) { - MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); - m->log("unlock", (std::uintptr_t) ptr, userLock); - m->unlock(ptr, (bool) userLock); - return AF_SUCCESS; - }; + MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); + m->log("unlock", reinterpret_cast(ptr), userLock); + m->unlock(ptr, static_cast(userLock)); + return AF_SUCCESS; + }; AF_CHECK(af_memory_manager_set_unlock_fn(itf, unlockFn)); auto 
signalMemoryCleanupFn = [](af_memory_manager manager) { - MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); - m->log("signalMemoryCleanup"); - m->signalMemoryCleanup(); - return AF_SUCCESS; - }; + MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); + m->log("signalMemoryCleanup"); + m->signalMemoryCleanup(); + return AF_SUCCESS; + }; AF_CHECK( af_memory_manager_set_signal_memory_cleanup_fn( itf, @@ -119,46 +126,46 @@ MemoryManagerInstaller::MemoryManagerInstaller( ) ); auto printInfoFn = [](af_memory_manager manager, char* msg, int device) { - // no log - auto* adapter = MemoryManagerInstaller::getImpl(manager); - adapter->printInfo(msg, device, adapter->getLogStream()); - return AF_SUCCESS; - }; + // no log + auto* adapter = MemoryManagerInstaller::getImpl(manager); + adapter->printInfo(msg, device, adapter->getLogStream()); + return AF_SUCCESS; + }; AF_CHECK(af_memory_manager_set_print_info_fn(itf, printInfoFn)); auto userLockFn = [](af_memory_manager manager, void* ptr) { - MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); - m->log("userLock", (std::uintptr_t) ptr); - m->userLock(ptr); - return AF_SUCCESS; - }; + MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); + m->log("userLock", reinterpret_cast(ptr)); + m->userLock(ptr); + return AF_SUCCESS; + }; AF_CHECK(af_memory_manager_set_user_lock_fn(itf, userLockFn)); auto userUnlockFn = [](af_memory_manager manager, void* ptr) { - MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); - m->log("userUnlock", (std::uintptr_t) ptr); - MemoryManagerInstaller::getImpl(manager)->userUnlock(ptr); - return AF_SUCCESS; - }; + MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); + m->log("userUnlock", reinterpret_cast(ptr)); + MemoryManagerInstaller::getImpl(manager)->userUnlock(ptr); + return AF_SUCCESS; + }; AF_CHECK(af_memory_manager_set_user_unlock_fn(itf, userUnlockFn)); auto isUserLockedFn = [](af_memory_manager manager, 
int* out, void* ptr) { - MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); - m->log("isUserLocked", (std::uintptr_t) ptr); - *out = static_cast(m->isUserLocked(ptr)); - return AF_SUCCESS; - }; + MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); + m->log("isUserLocked", reinterpret_cast(ptr)); + *out = static_cast(m->isUserLocked(ptr)); + return AF_SUCCESS; + }; AF_CHECK(af_memory_manager_set_is_user_locked_fn(itf, isUserLockedFn)); auto getMemoryPressureFn = [](af_memory_manager manager, float* pressure) { - *pressure = MemoryManagerInstaller::getImpl(manager)->getMemoryPressure(); - return AF_SUCCESS; - }; + *pressure = MemoryManagerInstaller::getImpl(manager)->getMemoryPressure(); + return AF_SUCCESS; + }; AF_CHECK( af_memory_manager_set_get_memory_pressure_fn(itf, getMemoryPressureFn) ); auto jitTreeExceedsMemoryPressureFn = [](af_memory_manager manager, int* out, size_t bytes) { - *out = static_cast(MemoryManagerInstaller::getImpl(manager) - ->jitTreeExceedsMemoryPressure(bytes)); - return AF_SUCCESS; - }; + *out = static_cast(MemoryManagerInstaller::getImpl(manager) + ->jitTreeExceedsMemoryPressure(bytes)); + return AF_SUCCESS; + }; AF_CHECK( af_memory_manager_set_jit_tree_exceeds_memory_pressure_fn( itf, @@ -166,8 +173,8 @@ MemoryManagerInstaller::MemoryManagerInstaller( ) ); auto addMemoryManagementFn = [](af_memory_manager manager, int device) { - MemoryManagerInstaller::getImpl(manager)->addMemoryManagement(device); - }; + MemoryManagerInstaller::getImpl(manager)->addMemoryManagement(device); + }; AF_CHECK( af_memory_manager_set_add_memory_management_fn( itf, @@ -175,8 +182,8 @@ MemoryManagerInstaller::MemoryManagerInstaller( ) ); auto removeMemoryManagementFn = [](af_memory_manager manager, int device) { - MemoryManagerInstaller::getImpl(manager)->removeMemoryManagement(device); - }; + MemoryManagerInstaller::getImpl(manager)->removeMemoryManagement(device); + }; AF_CHECK( 
af_memory_manager_set_remove_memory_management_fn( itf, @@ -186,47 +193,47 @@ MemoryManagerInstaller::MemoryManagerInstaller( // Native and device memory manager functions auto getActiveDeviceIdFn = [itf]() { - int id; - AF_CHECK(af_memory_manager_get_active_device_id(itf, &id)); - return id; - }; + int id; + AF_CHECK(af_memory_manager_get_active_device_id(itf, &id)); + return id; + }; impl_->deviceInterface->getActiveDeviceId = std::move(getActiveDeviceIdFn); auto getMaxMemorySizeFn = [itf](int id) { - size_t out; - AF_CHECK(af_memory_manager_get_max_memory_size(itf, &out, id)); - return out; - }; + size_t out; + AF_CHECK(af_memory_manager_get_max_memory_size(itf, &out, id)); + return out; + }; impl_->deviceInterface->getMaxMemorySize = std::move(getMaxMemorySizeFn); // nativeAlloc could throw via AF_CHECK: - auto nativeAllocFn = [itf](const size_t bytes) { - void* ptr; - AF_CHECK(af_memory_manager_native_alloc(itf, &ptr, bytes)); - MemoryManagerInstaller::getImpl(itf)->log( - "nativeAlloc", - bytes, - (std::uintptr_t) ptr - ); - return ptr; - }; + auto nativeAllocFn = [itf](size_t const bytes) { + void* ptr; + AF_CHECK(af_memory_manager_native_alloc(itf, &ptr, bytes)); + MemoryManagerInstaller::getImpl(itf)->log( + "nativeAlloc", + bytes, + reinterpret_cast(ptr) + ); + return ptr; + }; impl_->deviceInterface->nativeAlloc = std::move(nativeAllocFn); auto nativeFreeFn = [itf](void* ptr) { - MemoryManagerInstaller::getImpl(itf)->log( - "nativeFree", - (std::uintptr_t) ptr - ); - AF_CHECK(af_memory_manager_native_free(itf, ptr)); - }; + MemoryManagerInstaller::getImpl(itf)->log( + "nativeFree", + reinterpret_cast(ptr) + ); + AF_CHECK(af_memory_manager_native_free(itf, ptr)); + }; impl_->deviceInterface->nativeFree = std::move(nativeFreeFn); auto getMemoryPressureThresholdFn = [itf]() { - float pressure; - AF_CHECK(af_memory_manager_get_memory_pressure_threshold(itf, &pressure)); - return pressure; - }; + float pressure; + 
AF_CHECK(af_memory_manager_get_memory_pressure_threshold(itf, &pressure)); + return pressure; + }; impl_->deviceInterface->getMemoryPressureThreshold = std::move(getMemoryPressureThresholdFn); auto setMemoryPressureThresholdFn = [itf](float pressure) { - AF_CHECK(af_memory_manager_set_memory_pressure_threshold(itf, pressure)); - }; + AF_CHECK(af_memory_manager_set_memory_pressure_threshold(itf, pressure)); + }; impl_->deviceInterface->setMemoryPressureThreshold = std::move(setMemoryPressureThresholdFn); } diff --git a/flashlight/fl/tensor/backend/stub/StubBackend.cpp b/flashlight/fl/tensor/backend/stub/StubBackend.cpp index 11403b6..12e1b8d 100644 --- a/flashlight/fl/tensor/backend/stub/StubBackend.cpp +++ b/flashlight/fl/tensor/backend/stub/StubBackend.cpp @@ -148,7 +148,7 @@ Tensor StubBackend::nonzero(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPL Tensor StubBackend::pad( const Tensor& /* input */, - const std::vector>& /* padWidths */, + const std::vector>& /* padWidths */, const PadType /* type */ ) { FL_STUB_BACKEND_UNIMPLEMENTED; } diff --git a/flashlight/fl/tensor/backend/stub/StubBackend.h b/flashlight/fl/tensor/backend/stub/StubBackend.h index 062d2d3..c5168e9 100644 --- a/flashlight/fl/tensor/backend/stub/StubBackend.h +++ b/flashlight/fl/tensor/backend/stub/StubBackend.h @@ -27,24 +27,24 @@ class StubBackend : public TensorBackend { // No copy or move construction or assignment StubBackend(StubBackend&&) = delete; - StubBackend(const StubBackend&) = delete; + StubBackend(StubBackend const&) = delete; StubBackend& operator=(StubBackend&&) = delete; - StubBackend& operator=(const StubBackend&) = delete; + StubBackend& operator=(StubBackend const&) = delete; /* -------------------------- Compute Functions -------------------------- */ - void eval(const Tensor& tensor) override; - bool supportsDataType(const fl::dtype& dtype) const override; + void eval(Tensor const& tensor) override; + bool supportsDataType(fl::dtype const& dtype) const override; // 
Memory management - void getMemMgrInfo(const char* msg, const int deviceId, std::ostream* ostream) + void getMemMgrInfo(char const* msg, int const deviceId, std::ostream* ostream) override; void setMemMgrLogStream(std::ostream* stream) override; - void setMemMgrLoggingEnabled(const bool enabled) override; - void setMemMgrFlushInterval(const size_t interval) override; + void setMemMgrLoggingEnabled(bool const enabled) override; + void setMemMgrFlushInterval(size_t const interval) override; /* -------------------------- Rand Functions -------------------------- */ - void setSeed(const int seed) override; - Tensor randn(const Shape& shape, dtype type) override; - Tensor rand(const Shape& shape, dtype type) override; + void setSeed(int const seed) override; + Tensor randn(Shape const& shape, dtype type) override; + Tensor rand(Shape const& shape, dtype type) override; /* --------------------------- Tensor Operators --------------------------- */ /******************** Tensor Creation Functions ********************/ @@ -66,71 +66,71 @@ class StubBackend : public TensorBackend { FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned short&); #undef FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL - Tensor identity(const Dim dim, const dtype type) override; - Tensor arange(const Shape& shape, const Dim seqDim, const dtype type) + Tensor identity(Dim const dim, dtype const type) override; + Tensor arange(Shape const& shape, Dim const seqDim, dtype const type) override; - Tensor iota(const Shape& dims, const Shape& tileDims, const dtype type) + Tensor iota(Shape const& dims, Shape const& tileDims, dtype const type) override; /************************ Shaping and Indexing *************************/ - Tensor reshape(const Tensor& tensor, const Shape& shape) override; - Tensor transpose(const Tensor& tensor, const Shape& axes /* = {} */) override; - Tensor tile(const Tensor& tensor, const Shape& shape) override; - Tensor concatenate(const std::vector& tensors, const unsigned axis) + 
Tensor reshape(Tensor const& tensor, Shape const& shape) override; + Tensor transpose(Tensor const& tensor, Shape const& axes /* = {} */) override; + Tensor tile(Tensor const& tensor, Shape const& shape) override; + Tensor concatenate(std::vector const& tensors, unsigned const axis) override; - Tensor nonzero(const Tensor& tensor) override; + Tensor nonzero(Tensor const& tensor) override; Tensor pad( - const Tensor& input, - const std::vector>& padWidths, - const PadType type + Tensor const& input, + std::vector> const& padWidths, + PadType const type ) override; /************************** Unary Operators ***************************/ - Tensor exp(const Tensor& tensor) override; - Tensor log(const Tensor& tensor) override; - Tensor negative(const Tensor& tensor) override; - Tensor logicalNot(const Tensor& tensor) override; - Tensor log1p(const Tensor& tensor) override; - Tensor sin(const Tensor& tensor) override; - Tensor cos(const Tensor& tensor) override; - Tensor sqrt(const Tensor& tensor) override; - Tensor tanh(const Tensor& tensor) override; - Tensor floor(const Tensor& tensor) override; - Tensor ceil(const Tensor& tensor) override; - Tensor rint(const Tensor& tensor) override; - Tensor absolute(const Tensor& tensor) override; - Tensor sigmoid(const Tensor& tensor) override; - Tensor erf(const Tensor& tensor) override; - Tensor flip(const Tensor& tensor, const unsigned dim) override; - Tensor clip(const Tensor& tensor, const Tensor& low, const Tensor& high) + Tensor exp(Tensor const& tensor) override; + Tensor log(Tensor const& tensor) override; + Tensor negative(Tensor const& tensor) override; + Tensor logicalNot(Tensor const& tensor) override; + Tensor log1p(Tensor const& tensor) override; + Tensor sin(Tensor const& tensor) override; + Tensor cos(Tensor const& tensor) override; + Tensor sqrt(Tensor const& tensor) override; + Tensor tanh(Tensor const& tensor) override; + Tensor floor(Tensor const& tensor) override; + Tensor ceil(Tensor const& tensor) 
override; + Tensor rint(Tensor const& tensor) override; + Tensor absolute(Tensor const& tensor) override; + Tensor sigmoid(Tensor const& tensor) override; + Tensor erf(Tensor const& tensor) override; + Tensor flip(Tensor const& tensor, unsigned const dim) override; + Tensor clip(Tensor const& tensor, Tensor const& low, Tensor const& high) override; - Tensor roll(const Tensor& tensor, const int shift, const unsigned axis) + Tensor roll(Tensor const& tensor, int const shift, unsigned const axis) override; - Tensor isnan(const Tensor& tensor) override; - Tensor isinf(const Tensor& tensor) override; - Tensor sign(const Tensor& tensor) override; - Tensor tril(const Tensor& tensor) override; - Tensor triu(const Tensor& tensor) override; - Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y) + Tensor isnan(Tensor const& tensor) override; + Tensor isinf(Tensor const& tensor) override; + Tensor sign(Tensor const& tensor) override; + Tensor tril(Tensor const& tensor) override; + Tensor triu(Tensor const& tensor) override; + Tensor where(Tensor const& condition, Tensor const& x, Tensor const& y) override; void topk( Tensor& values, Tensor& indices, - const Tensor& input, - const unsigned k, - const Dim axis, - const SortMode sortMode + Tensor const& input, + unsigned const k, + Dim const axis, + SortMode const sortMode ) override; - Tensor sort(const Tensor& input, const Dim axis, const SortMode sortMode) + Tensor sort(Tensor const& input, Dim const axis, SortMode const sortMode) override; void sort( Tensor& values, Tensor& indices, - const Tensor& input, - const Dim axis, - const SortMode sortMode + Tensor const& input, + Dim const axis, + SortMode const sortMode ) override; - Tensor argsort(const Tensor& input, const Dim axis, const SortMode sortMode) + Tensor argsort(Tensor const& input, Dim const axis, SortMode const sortMode) override; /************************** Binary Operators ***************************/ @@ -179,98 +179,98 @@ class StubBackend : 
public TensorBackend { #undef FL_STUB_BACKEND_BINARY_OP_TYPE_DECL #undef FL_STUB_BACKEND_BINARY_OP_LITERALS_DECL - Tensor minimum(const Tensor& lhs, const Tensor& rhs) override; - Tensor maximum(const Tensor& lhs, const Tensor& rhs) override; - Tensor power(const Tensor& lhs, const Tensor& rhs) override; + Tensor minimum(Tensor const& lhs, Tensor const& rhs) override; + Tensor maximum(Tensor const& lhs, Tensor const& rhs) override; + Tensor power(Tensor const& lhs, Tensor const& rhs) override; /******************************* BLAS ********************************/ Tensor matmul( - const Tensor& lhs, - const Tensor& rhs, + Tensor const& lhs, + Tensor const& rhs, MatrixProperty lhsProp, MatrixProperty rhsProp ) override; /************************** Reductions ***************************/ Tensor amin( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const keepDims ) override; Tensor amax( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const keepDims ) override; void min( Tensor& values, Tensor& indices, - const Tensor& input, - const unsigned axis, - const bool keepDims + Tensor const& input, + unsigned const axis, + bool const keepDims ) override; void max( Tensor& values, Tensor& indices, - const Tensor& input, - const unsigned axis, - const bool keepDims + Tensor const& input, + unsigned const axis, + bool const keepDims ) override; Tensor sum( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const keepDims ) override; - Tensor cumsum(const Tensor& input, const unsigned axis) override; - Tensor argmax(const Tensor& input, const unsigned axis, const bool keepDims) + Tensor cumsum(Tensor const& input, unsigned const axis) override; + Tensor argmax(Tensor const& input, unsigned const axis, bool const keepDims) override; - Tensor 
argmin(const Tensor& input, const unsigned axis, const bool keepDims) + Tensor argmin(Tensor const& input, unsigned const axis, bool const keepDims) override; Tensor mean( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const keepDims ) override; Tensor median( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const keepDims ) override; Tensor var( - const Tensor& input, - const std::vector& axes, - const bool bias, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const bias, + bool const keepDims ) override; Tensor std( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const keepDims ) override; Tensor norm( - const Tensor& input, - const std::vector& axes, + Tensor const& input, + std::vector const& axes, double p, - const bool keepDims + bool const keepDims ) override; Tensor countNonzero( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const keepDims ) override; Tensor any( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const keepDims ) override; Tensor all( - const Tensor& input, - const std::vector& axes, - const bool keepDims + Tensor const& input, + std::vector const& axes, + bool const keepDims ) override; /************************** Utils ***************************/ - void print(const Tensor& tensor) override; + void print(Tensor const& tensor) override; }; } // namespace fl diff --git a/flashlight/fl/tensor/backend/stub/StubTensor.cpp b/flashlight/fl/tensor/backend/stub/StubTensor.cpp index d218e74..b875413 100644 --- a/flashlight/fl/tensor/backend/stub/StubTensor.cpp +++ b/flashlight/fl/tensor/backend/stub/StubTensor.cpp @@ 
-74,7 +74,7 @@ const Stream& StubTensor::stream() const { FL_STUB_TENSOR_UNIMPLEMENTED; } -Tensor StubTensor::astype(const dtype /* type */) { FL_STUB_TENSOR_UNIMPLEMENTED; } +Tensor StubTensor::asType(const dtype /* type */) { FL_STUB_TENSOR_UNIMPLEMENTED; } Tensor StubTensor::index(const std::vector& /* indices */) { FL_STUB_TENSOR_UNIMPLEMENTED; } diff --git a/flashlight/fl/tensor/backend/stub/StubTensor.h b/flashlight/fl/tensor/backend/stub/StubTensor.h index b5329fd..14b5ca6 100644 --- a/flashlight/fl/tensor/backend/stub/StubTensor.h +++ b/flashlight/fl/tensor/backend/stub/StubTensor.h @@ -67,7 +67,7 @@ class StubTensor : public TensorAdapterBase { bool isContiguous() override; Shape strides() override; const Stream& stream() const override; - Tensor astype(const dtype type) override; + Tensor asType(const dtype type) override; Tensor index(const std::vector& indices) override; Tensor flatten() const override; Tensor flat(const Index& idx) const override; diff --git a/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp b/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp index 844d41f..014ea4d 100644 --- a/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp +++ b/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp @@ -71,13 +71,13 @@ TEST(AutogradBinaryOpsTest, BinaryCrossEntropy) { auto loss = binaryCrossEntropy(x, y); // bce loss should be positive - ASSERT_TRUE(fl::all(loss.tensor() > 0).scalar()); + ASSERT_TRUE(fl::all_of(loss.tensor() > 0).scalar()); } TEST(AutogradBinaryOpsTest, CrossEntropy) { auto x = Variable(fl::rand({7, 10, 4}, fl::dtype::f64), true); auto y = Variable( - (fl::rand({10, 4}, fl::dtype::u32) % 7).astype(fl::dtype::s32), + (fl::rand({10, 4}, fl::dtype::u32) % 7).asType(fl::dtype::s32), false ); auto ignoreIdx = y(0, 0).scalar(); @@ -165,7 +165,7 @@ TEST_F(AutogradTestF16, LinearF16) { TEST(AutogradBinaryOpsTest, Multiply) { auto x = Variable(fl::rand({5}), true); auto y = x * x; - auto dy = Variable(fl::full({5}, 1.0), false); 
+ auto dy = Variable(fl::full({5}, 1.f), false); y.backward(dy); auto dx = x.grad(); ASSERT_TRUE(allClose(dx.tensor(), 2 * x.tensor())); @@ -175,7 +175,7 @@ TEST(AutogradBinaryOpsTest, MultiplyAdd) { auto x = Variable(fl::rand({5}), true); auto y = Variable(fl::rand({5}), true); auto z = x * x + x * y + y * y; - auto dz = Variable(fl::full({5}, 1.0), false); + auto dz = Variable(fl::full({5}, 1.f), false); z.backward(dz); auto dx = x.grad(); auto dy = y.grad(); @@ -187,19 +187,19 @@ TEST(AutogradBinaryOpsTest, MultiplyAddScalar) { auto x = Variable(fl::rand({5}), true); auto y = Variable(fl::rand({5}), true); auto z = 2 * x + x * y + y; - auto dz = Variable(fl::full({5}, 1.0), false); + auto dz = Variable(fl::full({5}, 1.f), false); z.backward(dz); auto dx = x.grad(); auto dy = y.grad(); - ASSERT_TRUE(allClose(dx.tensor(), (2.0 + y.tensor()))); - ASSERT_TRUE(allClose(dy.tensor(), (1.0 + x.tensor()))); + ASSERT_TRUE(allClose(dx.tensor(), (2.f + y.tensor()))); + ASSERT_TRUE(allClose(dy.tensor(), (1.f + x.tensor()))); } TEST(AutogradBinaryOpsTest, MultiplySub) { auto x = Variable(fl::rand({5}), true); auto y = Variable(fl::rand({5}), true); auto z = x * x - x * y; - auto dz = Variable(fl::full({5}, 1.0), false); + auto dz = Variable(fl::full({5}, 1.f), false); z.backward(dz); auto dx = x.grad(); auto dy = y.grad(); @@ -211,14 +211,14 @@ TEST(AutogradBinaryOpsTest, DivideAdd) { auto x = Variable(fl::rand({5}, fl::dtype::f64), true); auto y = Variable(fl::rand({5}, fl::dtype::f64), true); auto z = x + x / y + y; - auto dz = Variable(fl::full({5}, 1.0, fl::dtype::f64), false); + auto dz = Variable(fl::full({5}, 1.f, fl::dtype::f64), false); z.backward(dz); auto dx = x.grad(); auto dy = y.grad(); ASSERT_EQ(z.type(), fl::dtype::f64); - ASSERT_TRUE(allClose(dx.tensor(), (1.0 + 1.0 / y.tensor()))); + ASSERT_TRUE(allClose(dx.tensor(), (1.f + 1.f / y.tensor()))); ASSERT_TRUE( - allClose(dy.tensor(), (1.0 - x.tensor() / (y.tensor() * y.tensor()))) + allClose(dy.tensor(), (1.f - 
x.tensor() / (y.tensor() * y.tensor()))) ); } diff --git a/flashlight/fl/test/autograd/AutogradReductionTest.cpp b/flashlight/fl/test/autograd/AutogradReductionTest.cpp index c9425ed..d9b9595 100644 --- a/flashlight/fl/test/autograd/AutogradReductionTest.cpp +++ b/flashlight/fl/test/autograd/AutogradReductionTest.cpp @@ -18,7 +18,7 @@ using namespace fl; using fl::detail::AutogradTestF16; TEST(AutogradReductionTest, Sum) { - for(const bool keepDims : {false, true}) { + for(bool keepDims : {false, true}) { Shape s = {6}; if(keepDims) s = {6, 1}; @@ -27,7 +27,7 @@ TEST(AutogradReductionTest, Sum) { auto y = Variable(fl::rand({6, 3}), true); auto z = x * sum(y, {1}, keepDims); - auto dz = Variable(fl::full(s, 1.0), false); + auto dz = Variable(fl::full(s, 1.f), false); z.backward(dz); auto dy = y.grad(); @@ -36,13 +36,11 @@ TEST(AutogradReductionTest, Sum) { ASSERT_TRUE(allClose(dx.tensor(), fl::sum(y.tensor(), {1}, keepDims))); // Reduce over 1-dim input - auto funcMean_0 = [keepDims](const Variable& in) { - return sum(in, {0}, keepDims); - }; + auto funcMean_0 = [keepDims](const Variable& in) { return sum(in, {0}, keepDims); }; auto in = Variable(fl::rand({6}), true); ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMean_0, in, 5E-3)); // Reduce over scalar input - auto inScalar = Variable(fl::fromScalar(3.14), true); + auto inScalar = Variable(fl::fromScalar(3.14f), true); ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMean_0, inScalar, 5E-3)); } @@ -56,7 +54,7 @@ TEST(AutogradReductionTest, SumAs) { auto x = Variable(fl::rand({5}), true); auto y = Variable(fl::rand({5, 2}), true); auto z = x * sumAs(y, x); - auto dz = Variable(fl::full({5}, 1.0), false); + auto dz = Variable(fl::full({5}, 1.f), false); z.backward(dz); auto dy = y.grad(); auto dx = x.grad(); @@ -67,20 +65,20 @@ TEST(AutogradReductionTest, SumAs) { TEST(AutogradReductionTest, SumAs2) { auto y = Variable(fl::rand({5, 2}), true); auto z = sumAs(y, {5}); - auto dz = Variable(fl::full({5}, 1.0), false); + 
auto dz = Variable(fl::full({5}, 1.f), false); z.backward(dz); auto dy = y.grad(); - ASSERT_TRUE(allClose(dy.tensor(), fl::full({5, 2}, 1.0))); + ASSERT_TRUE(allClose(dy.tensor(), fl::full({5, 2}, 1.f))); } TEST(AutogradReductionTest, Mean) { - for(const bool keepDims : {false, true}) { + for(bool keepDims : {false, true}) { Shape xShape = keepDims ? Shape({5, 1, 1}) : Shape({5}); auto x = Variable(fl::rand(xShape), true); auto y = Variable(fl::rand({5, 3, 2}), true); auto varOut = mean(y, {1, 2}, keepDims); auto z = x * mean(y, {1, 2}, keepDims); - auto dz = Variable(fl::full(x.shape(), 1.0), false); + auto dz = Variable(fl::full(x.shape(), 1.f), false); z.backward(dz); auto dy = y.grad(); auto dx = x.grad(); @@ -88,9 +86,7 @@ TEST(AutogradReductionTest, Mean) { ASSERT_TRUE(allClose(dx.tensor(), fl::mean(y.tensor(), {1, 2}, keepDims))); auto a = Variable(fl::rand({5, 3, 2}, fl::dtype::f64), true); - auto funcMean = [keepDims](Variable& in) { - return mean(in, {1, 2}, keepDims); - }; + auto funcMean = [keepDims](Variable& in) { return mean(in, {1, 2}, keepDims); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMean, a, 1E-4)); auto q = Variable(fl::rand({5, 6, 7, 8}), false); @@ -98,14 +94,12 @@ TEST(AutogradReductionTest, Mean) { auto qOutTensor = fl::mean(q.tensor(), {1, 2}, keepDims); ASSERT_TRUE(allClose(qOut.tensor(), qOutTensor)); - auto funcMean_0 = [keepDims](Variable& in) { - return mean(in, {0}, keepDims); - }; + auto funcMean_0 = [keepDims](Variable& in) { return mean(in, {0}, keepDims); }; // Reduce over 1-dim input auto in = Variable(fl::rand({6}), true); ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMean_0, in, 5E-3)); // Reduce over scalar input - auto inScalar = Variable(fl::fromScalar(3.14), true); + auto inScalar = Variable(fl::fromScalar(3.14f), true); ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMean_0, inScalar, 5E-3)); } } @@ -113,8 +107,8 @@ TEST(AutogradReductionTest, Mean) { TEST(AutogradReductionTest, Variance) { std::vector biased = {true, 
false}; for(auto b : biased) - for(const bool keepDims : {false, true}) { - auto x = Variable(fl::rand({5, 6, 7, 8}, fl::dtype::f64), true); + for(bool keepDims : {false, true}) { + auto x = Variable(fl::rand({5, 6, 7, 8}, fl::dtype::f32), true); // TODO:{fl::Tensor} -- enforce AF versioning and remediate // Behavior of the bias parameter in af::var was changed in @@ -128,9 +122,7 @@ TEST(AutogradReductionTest, Variance) { auto calculatedVar = var(x, {1}, b, keepDims); ASSERT_TRUE(allClose(calculatedVar.tensor(), expectedVar)); - auto funcVar = [b, keepDims](Variable& in) { - return var(in, {1, 2}, b, keepDims); - }; + auto funcVar = [b, keepDims](Variable& in) { return var(in, {1, 2}, b, keepDims); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcVar, x, 1E-5, 1E-5)); } } @@ -138,17 +130,11 @@ TEST(AutogradReductionTest, Variance) { TEST(AutogradReductionTest, Norm) { auto x = Variable(fl::rand({5, 3}, fl::dtype::f64), true); for(const bool keepDims : {false, true}) { - auto funcNorm2 = [keepDims](Variable& in) { - return norm(in, {1}, 2, keepDims); - }; + auto funcNorm2 = [keepDims](Variable& in) { return norm(in, {1}, 2, keepDims); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcNorm2, x, 1E-4)); - auto funcNorm1 = [keepDims](Variable& in) { - return norm(in, {1}, 1, keepDims); - }; + auto funcNorm1 = [keepDims](Variable& in) { return norm(in, {1}, 1, keepDims); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcNorm1, x, 1E-4)); - auto funcNorm3 = [keepDims](Variable& in) { - return norm(in, {1}, 3, keepDims); - }; + auto funcNorm3 = [keepDims](Variable& in) { return norm(in, {1}, 3, keepDims); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcNorm3, x, 1E-4)); } } diff --git a/flashlight/fl/test/autograd/AutogradRnnTest.cpp b/flashlight/fl/test/autograd/AutogradRnnTest.cpp index 1195583..8e0cf23 100644 --- a/flashlight/fl/test/autograd/AutogradRnnTest.cpp +++ b/flashlight/fl/test/autograd/AutogradRnnTest.cpp @@ -50,14 +50,14 @@ void testRnnImpl(RnnMode mode, 
fl::dtype precision = fl::dtype::f64) { } auto w = - Variable(fl::rand({static_cast(nParams)}, precision), true); + Variable(fl::rand({static_cast(nParams)}, precision), true); auto funcRnnIn = [&](Variable& input) -> Variable { return std::get<0>( rnn( input, - Variable().astype(precision), - Variable().astype(precision), + Variable().asType(precision), + Variable().asType(precision), w, hiddenSize, numLayers, @@ -73,8 +73,8 @@ void testRnnImpl(RnnMode mode, fl::dtype precision = fl::dtype::f64) { return std::get<0>( rnn( in, - Variable().astype(precision), - Variable().astype(precision), + Variable().asType(precision), + Variable().asType(precision), weights, hiddenSize, numLayers, @@ -98,8 +98,8 @@ void testRnnImpl(RnnMode mode, fl::dtype precision = fl::dtype::f64) { return std::get<0>( rnn( in, - hiddenState.astype(precision), - Variable().astype(precision), + hiddenState.asType(precision), + Variable().asType(precision), w, hiddenSize, numLayers, @@ -116,8 +116,8 @@ void testRnnImpl(RnnMode mode, fl::dtype precision = fl::dtype::f64) { return std::get<1>( rnn( input, - Variable().astype(precision), - Variable().astype(precision), + Variable().asType(precision), + Variable().asType(precision), w, hiddenSize, numLayers, @@ -144,8 +144,8 @@ void testRnnImpl(RnnMode mode, fl::dtype precision = fl::dtype::f64) { return std::get<0>( rnn( in, - Variable().astype(precision), - cellState.astype(precision), + Variable().asType(precision), + cellState.asType(precision), w, hiddenSize, numLayers, @@ -164,8 +164,8 @@ void testRnnImpl(RnnMode mode, fl::dtype precision = fl::dtype::f64) { return std::get<2>( rnn( input, - Variable().astype(precision), - Variable().astype(precision), + Variable().asType(precision), + Variable().asType(precision), w, hiddenSize, numLayers, diff --git a/flashlight/fl/test/autograd/AutogradTest.cpp b/flashlight/fl/test/autograd/AutogradTest.cpp index 3d53c77..9e430ab 100644 --- a/flashlight/fl/test/autograd/AutogradTest.cpp +++ 
b/flashlight/fl/test/autograd/AutogradTest.cpp @@ -25,7 +25,7 @@ using namespace fl; using fl::detail::AutogradTestF16; TEST(AutogradTest, OperatorParenthesis) { - auto x = Variable(fl::rand({1, 3, 3}, fl::dtype::f64), true); + auto x = Variable{fl::rand({1, 3, 3}, fl::dtype::f64), true}; auto y = x(0, 0) + x(0, 1); auto funcOperatorParen = [](Variable& in) { return in(0, 0) + in(0, 1); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcOperatorParen, x)); @@ -35,171 +35,154 @@ TEST(AutogradTest, AutogradOperatorTypeCompatibility) { if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - auto f16 = Variable(fl::rand({2, 2}, fl::dtype::f16), true); - auto f32 = Variable(fl::rand({2, 2}, fl::dtype::f32), true); + auto f16 = Variable{fl::rand({2, 2}, fl::dtype::f16), true}; + auto f32 = Variable{fl::rand({2, 2}, fl::dtype::f32), true}; // Binary operators EXPECT_THROW( - { - auto res = f16 + f32; - }, + {auto res = f16 + f32;}, std::invalid_argument ); // + EXPECT_THROW( - { - auto res = f16 - f32; - }, + {auto res = f16 - f32;}, std::invalid_argument ); // - EXPECT_THROW( - { - auto res = f16 * f32; - }, + {auto res = f16 * f32;}, std::invalid_argument ); // * EXPECT_THROW( - { - auto res = f16 / f32; - }, + {auto res = f16 / f32;}, std::invalid_argument ); /// EXPECT_THROW( - { - auto res = f16 > f32; - }, + {auto res = f16 > f32;}, std::invalid_argument ); // > EXPECT_THROW( - { - auto res = f16 < f32; - }, + {auto res = f16 < f32;}, std::invalid_argument ); // < + EXPECT_THROW({auto res = f16 >= f32;}, std::invalid_argument); // >= EXPECT_THROW( { - auto res = f16 >= f32; - }, - std::invalid_argument - ); // >= - EXPECT_THROW( - { - auto res = f16 <= f32; + auto res = f16 <= f32; }, std::invalid_argument ); // <= EXPECT_THROW( { - auto res = f16 && f32; + auto res = f16 && f32; }, std::invalid_argument ); // && EXPECT_THROW( { - max(f16, f32); + max(f16, f32); }, std::invalid_argument ); // max EXPECT_THROW( { - min(f16, f32); + 
min(f16, f32); }, std::invalid_argument ); // min EXPECT_THROW( { - matmul(f16, f32); + matmul(f16, f32); }, std::invalid_argument ); // matmul EXPECT_THROW( { - matmulTN(f16, f32); + matmulTN(f16, f32); }, std::invalid_argument ); // matmulTN EXPECT_THROW( { - matmulNT(f16, f32); + matmulNT(f16, f32); }, std::invalid_argument ); // matmulNT EXPECT_NO_THROW( { - binaryCrossEntropy(f16, f32); + binaryCrossEntropy(f16, f32); } ); EXPECT_NO_THROW( { - categoricalCrossEntropy( - Variable(fl::rand({7, 10, 4}, fl::dtype::f16), true), - Variable( - (fl::rand({10, 4}, fl::dtype::u32) % 7).astype(fl::dtype::s32), - false - ) - ); + categoricalCrossEntropy( + Variable{fl::rand({7, 10, 4}, fl::dtype::f16), true}, + Variable{ + (fl::rand({10, 4}, fl::dtype::u32) % 7).asType(fl::dtype::s32), + false + } + ); } ); EXPECT_NO_THROW( { - pool2d(f16, 1, 1, 1, 1, 1, 1); + pool2d(f16, 1, 1, 1, 1, 1, 1); } ); EXPECT_NO_THROW( { - embedding(f16, f32); + embedding(f16, f32); } ); // lookup is of a different type // Ternary operators - auto f32_2 = Variable(fl::rand({2, 2}, fl::dtype::f32), true); - auto f16_2 = Variable(fl::rand({2, 2}, fl::dtype::f16), true); + auto f32_2 = Variable{fl::rand({2, 2}, fl::dtype::f32), true}; + auto f16_2 = Variable{fl::rand({2, 2}, fl::dtype::f16), true}; EXPECT_THROW( { - linear(f16, f32, f16_2); + linear(f16, f32, f16_2); }, std::invalid_argument ); // linear EXPECT_THROW( { - linear(f16, f32, f32_2); + linear(f16, f32, f32_2); }, std::invalid_argument ); // linear - auto w = Variable(fl::rand({1}, fl::dtype::f32), true); - auto b = Variable(fl::rand({1}, fl::dtype::f32), true); + auto w = Variable{fl::rand({1}, fl::dtype::f32), true}; + auto b = Variable{fl::rand({1}, fl::dtype::f32), true}; EXPECT_THROW( { - batchnorm(f16, f32, f32_2, w, b, {1}, true, 0.01, 0.01); + batchnorm(f16, f32, f32_2, w, b, {1}, true, 0.01, 0.01); }, std::invalid_argument ); EXPECT_THROW( { - batchnorm(f16, f32, f16_2, w, b, {1}, true, 0.01, 0.01); + batchnorm(f16, f32, f16_2, 
w, b, {1}, true, 0.01, 0.01); }, std::invalid_argument ); EXPECT_THROW( { - conv2d(f16, f32, f16_2, 1, 1, 0, 0, 1, 1); + conv2d(f16, f32, f16_2, 1, 1, 0, 0, 1, 1); }, std::invalid_argument ); // Quaternary - auto f16_3 = Variable(fl::rand({2, 2, 3}, fl::dtype::f16), false); - auto f16_4 = Variable(fl::rand({50}, fl::dtype::f16), false); + auto f16_3 = Variable{fl::rand({2, 2, 3}, fl::dtype::f16), false}; + auto f16_4 = Variable{fl::rand({50}, fl::dtype::f16), false}; EXPECT_THROW( { - rnn( - f16_3, - Variable(Tensor(fl::dtype::f32), false), - Variable(Tensor(fl::dtype::f32), false), - f16_4, - 2, - 2, - RnnMode::LSTM, - true, - 0.0 - ); + rnn( + f16_3, + Variable{Tensor{fl::dtype::f32}, false}, + Variable{Tensor{fl::dtype::f32}, false}, + f16_4, + 2, + 2, + RnnMode::LSTM, + true, + 0.0 + ); }, std::invalid_argument ); @@ -207,7 +190,7 @@ TEST(AutogradTest, AutogradOperatorTypeCompatibility) { std::vector concatInputs = {f16, f32, f16_2, f32_2}; EXPECT_THROW( { - concatenate(concatInputs, 0); + concatenate(concatInputs, 0); }, std::invalid_argument ); @@ -217,12 +200,12 @@ TEST(AutogradTest, CastingAsDifferentGradTypes) { if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - auto f32 = Variable(fl::rand({5, 5}), true); - auto f16 = Variable(fl::rand({5, 5}, fl::dtype::f16), true); + auto f32 = Variable{fl::rand({5, 5}), true}; + auto f16 = Variable{fl::rand({5, 5}, fl::dtype::f16), true}; // Computing gradients with mixed types fails when the op is applied ASSERT_THROW( { - f32 + f16; + f32 + f16; }, std::invalid_argument ); @@ -232,24 +215,24 @@ TEST(AutogradTest, CastingAs) { if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - auto var = Variable(fl::rand({5, 5}), true); - auto varF16 = var.astype(fl::dtype::f16); + auto var = Variable{fl::rand({5, 5}), true}; + auto varF16 = var.asType(fl::dtype::f16); ASSERT_EQ(var.type(), fl::dtype::f32); ASSERT_EQ(varF16.type(), fl::dtype::f16); - 
ASSERT_TRUE(allClose(varF16.tensor(), var.astype(fl::dtype::f16).tensor())); + ASSERT_TRUE(allClose(varF16.tensor(), var.asType(fl::dtype::f16).tensor())); } TEST(AutogradTest, CastingAsBackward) { if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - auto a = Variable(fl::rand({4, 4}, fl::dtype::f16), true); - auto b = Variable(fl::rand({4, 4}, fl::dtype::f16), false); + auto a = Variable{fl::rand({4, 4}, fl::dtype::f16), true}; + auto b = Variable{fl::rand({4, 4}, fl::dtype::f16), false}; auto c = b + a; c.backward(); ASSERT_EQ(a.grad().type(), fl::dtype::f16); ASSERT_EQ(a.grad().type(), fl::dtype::f16); - a = a.astype(fl::dtype::f32); + a = a.asType(fl::dtype::f32); ASSERT_FALSE(a.isGradAvailable()); } @@ -258,21 +241,21 @@ TEST(AutogradTest, CastingAsGrad) { GTEST_SKIP() << "Half-precision not supported on this device"; // compare to f32 case - auto x = Variable(fl::full({5}, 2.0), true); - auto y = Variable(fl::full({5}, 3.0), true); + auto x = Variable{fl::full({5}, 2.f), true}; + auto y = Variable{fl::full({5}, 3.f), true}; auto z = x * x + x * y + y * y; - auto dz = Variable(fl::full({5}, 1.0), false); + auto dz = Variable{fl::full({5}, 1.f), false}; z.backward(dz); auto dx = x.grad(); auto dy = y.grad(); // f16 -- cast gradients in both directions - auto x32 = Variable(fl::full({5}, 2.0), true); - auto y32 = Variable(fl::full({5}, 3.0), true); - auto xf16 = x32.astype(fl::dtype::f16); - auto yf16 = y32.astype(fl::dtype::f16); + auto x32 = Variable{fl::full({5}, 2.f), true}; + auto y32 = Variable{fl::full({5}, 3.f), true}; + auto xf16 = x32.asType(fl::dtype::f16); + auto yf16 = y32.asType(fl::dtype::f16); auto zf16 = xf16 * xf16 + xf16 * yf16 + yf16 * yf16; - auto zf32 = zf16.astype(fl::dtype::f32); + auto zf32 = zf16.asType(fl::dtype::f32); zf32.backward(dz); ASSERT_EQ(xf16.grad().type(), fl::dtype::f16); @@ -281,20 +264,20 @@ TEST(AutogradTest, CastingAsGrad) { ASSERT_EQ(x32.grad().type(), fl::dtype::f32); 
ASSERT_EQ(y32.grad().type(), fl::dtype::f32); ASSERT_TRUE( - allClose(dx.tensor(), xf16.grad().tensor().astype(fl::dtype::f32)) + allClose(dx.tensor(), xf16.grad().tensor().asType(fl::dtype::f32)) ); ASSERT_TRUE( - allClose(dy.tensor(), y32.grad().tensor().astype(fl::dtype::f32)) + allClose(dy.tensor(), y32.grad().tensor().asType(fl::dtype::f32)) ); ASSERT_TRUE(allClose(dx.tensor(), x32.grad().tensor())); ASSERT_TRUE(allClose(dy.tensor(), y32.grad().tensor())); } TEST(AutogradTest, NoCalcGrad) { - auto x = Variable(fl::rand({5}), false); - auto y = Variable(fl::rand({5}), true); + auto x = Variable{fl::rand({5}), false}; + auto y = Variable{fl::rand({5}), true}; auto z = x * x + x * y + y * y; - auto dz = Variable(fl::full({5}, 1.0), false); + auto dz = Variable{fl::full({5}, 1.f), false}; z.backward(dz); auto dy = y.grad(); ASSERT_TRUE(allClose(dy.tensor(), 2 * y.tensor() + x.tensor())); @@ -302,36 +285,32 @@ TEST(AutogradTest, NoCalcGrad) { } TEST(AutogradTest, Concatenate) { - auto x1 = Variable(fl::rand({2, 3, 1, 2}, fl::dtype::f64), true); - auto x2 = Variable(fl::rand({2, 3, 3, 2}, fl::dtype::f64), true); - auto x3 = Variable(fl::rand({2, 3, 1, 2}, fl::dtype::f64), true); - auto x4 = Variable(fl::rand({2, 3, 7, 2}, fl::dtype::f64), true); + auto x1 = Variable{fl::rand({2, 3, 1, 2}, fl::dtype::f64), true}; + auto x2 = Variable{fl::rand({2, 3, 3, 2}, fl::dtype::f64), true}; + auto x3 = Variable{fl::rand({2, 3, 1, 2}, fl::dtype::f64), true}; + auto x4 = Variable{fl::rand({2, 3, 7, 2}, fl::dtype::f64), true}; std::vector inputs = {x1, x2, x3, x4}; auto output = concatenate(inputs, 2); ASSERT_EQ(output.shape(), Shape({2, 3, 12, 2})); - auto funcConcatenateT1 = [x2, x3, x4](Variable& in) { - return concatenate({in, x2, x3, x4}, 2); - }; + auto funcConcatenateT1 = [x2, x3, x4](Variable& in) { return concatenate({in, x2, x3, x4}, 2); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConcatenateT1, x1, 1E-5, 1E-4, {&x2, &x3, &x4})); - auto funcConcatenateT2 = [x1, x2, 
x4](Variable& in) { - return concatenate({x1, x2, in, x4}, 2); - }; + auto funcConcatenateT2 = [x1, x2, x4](Variable& in) { return concatenate({x1, x2, in, x4}, 2); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConcatenateT2, x3, 1E-5, 1E-4, {&x1, &x2, &x4})); } TEST(AutogradTest, Split) { // check output - auto x = Variable(fl::arange({7, 2}), true); + auto x = Variable{fl::arange({7, 2}), true}; auto yVec = split(x, 1, 0); ASSERT_EQ(yVec.size(), 7); ASSERT_EQ(yVec[0].shape(), Shape({1, 2})); ASSERT_EQ(yVec[2].shape(), Shape({1, 2})); ASSERT_TRUE(fl::all(yVec[6].tensor() == 6).scalar()); - auto a = Variable(fl::arange({5, 3}, 1), true); + auto a = Variable{fl::arange({5, 3}, 1), true}; auto bVec = split(a, {2, 1}, 1); ASSERT_EQ(bVec.size(), 2); ASSERT_EQ(bVec[0].shape(), Shape({5, 2})); @@ -346,15 +325,15 @@ TEST(AutogradTest, Split) { // check gradient auto gradFunc = [](Variable& in) { return split(in, 2, 1)[0]; }; - auto input = Variable(fl::rand({2, 3}, fl::dtype::f64), true); + auto input = Variable{fl::rand({2, 3}, fl::dtype::f64), true}; ASSERT_TRUE(fl::detail::jacobianTestImpl(gradFunc, input)); } TEST(AutogradTest, Tile) { - auto x = Variable(fl::rand({6}), true); - auto y = Variable(fl::rand({6, 3}), true); + auto x = Variable{fl::rand({6}), true}; + auto y = Variable{fl::rand({6, 3}), true}; auto z = y * tile(x, {1, 3}); - auto dz = Variable(fl::full({6, 3}, 1.0), false); + auto dz = Variable{fl::full({6, 3}, 1.f), false}; z.backward(dz); auto dy = y.grad(); auto dx = x.grad(); @@ -362,16 +341,16 @@ TEST(AutogradTest, Tile) { ASSERT_TRUE(allClose(dx.tensor(), fl::sum(y.tensor(), {1}))); // Jacobian - auto input = Variable(fl::rand({10, 1, 5}), true); + auto input = Variable{fl::rand({10, 1, 5}), true}; auto funcTile = [](Variable& in) { return tile(in, {1, 2}); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcTile, input, 1E-4, 1E-3)); } TEST(AutogradTest, TileAs) { - auto x = Variable(fl::rand({5}), true); - auto y = Variable(fl::rand({5, 2}), true); 
+ auto x = Variable{fl::rand({5}), true}; + auto y = Variable{fl::rand({5, 2}), true}; auto z = y * tileAs(x, y); - auto dz = Variable(fl::full({5, 2}, 1.0), false); + auto dz = Variable{fl::full({5, 2}, 1.f), false}; z.backward(dz); auto dy = y.grad(); auto dx = x.grad(); @@ -383,37 +362,37 @@ TEST_F(AutogradTestF16, TileAsF16) { if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - auto x = Variable(fl::rand({5}, fl::dtype::f16), true); - auto y = Variable(fl::rand({5, 2}, fl::dtype::f16), true); + auto x = Variable{fl::rand({5}, fl::dtype::f16), true}; + auto y = Variable{fl::rand({5, 2}, fl::dtype::f16), true}; auto z = y * tileAs(x, y); ASSERT_EQ(x.type(), z.type()); - auto dz = Variable(fl::full({5, 2}, 1.0, fl::dtype::f16), false); + auto dz = Variable{fl::full({5, 2}, 1.f, fl::dtype::f16), false}; z.backward(dz); auto dy = y.grad(); auto dx = x.grad(); ASSERT_TRUE( allClose( dy.tensor(), - fl::tile(x.tensor(), {1, 2}).astype(dx.type()), + fl::tile(x.tensor(), {1, 2}).asType(dx.type()), 1e-2 ) ); ASSERT_TRUE( - allClose(dx.tensor(), fl::sum(y.tensor(), {1}).astype(dx.type()), 1e-2) + allClose(dx.tensor(), fl::sum(y.tensor(), {1}).asType(dx.type()), 1e-2) ); } TEST(AutogradTest, TileAs2) { - auto x = Variable(fl::rand({10}), true); + auto x = Variable{fl::rand({10}), true}; auto z = tileAs(x, Shape({10, 3})); - auto dz = Variable(fl::full({10, 3}, 1.0), false); + auto dz = Variable{fl::full({10, 3}, 1.f), false}; z.backward(dz); auto dx = x.grad(); - ASSERT_TRUE(allClose(dx.tensor(), fl::full(x.shape(), 3.0))); + ASSERT_TRUE(allClose(dx.tensor(), fl::full(x.shape(), 3.f))); } TEST(AutogradTest, Indexing) { - auto x = Variable(fl::rand({5, 6, 7, 4}, fl::dtype::f64), true); + auto x = Variable{fl::rand({5, 6, 7, 4}, fl::dtype::f64), true}; auto funcCol = [](Variable& input) { return input(fl::span, 4); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcCol, x)); @@ -421,39 +400,29 @@ TEST(AutogradTest, Indexing) { auto funcRow = 
[](Variable& input) { return input(4); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcRow, x)); - auto funcSlice = [](Variable& input) { - return input(fl::span, fl::span, 4); - }; + auto funcSlice = [](Variable& input) { return input(fl::span, fl::span, 4); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcSlice, x)); - auto funcCols = [](Variable& input) { - return input(fl::span, fl::range(2, 5)); - }; + auto funcCols = [](Variable& input) { return input(fl::span, fl::range(2, 5)); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcCols, x)); auto funcRows = [](Variable& input) { return input(fl::range(2, 5)); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcRows, x)); - auto funcSlices = [](Variable& input) { - return input(fl::span, fl::span, fl::range(2, 5)); - }; + auto funcSlices = [](Variable& input) { return input(fl::span, fl::span, fl::range(2, 5)); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcSlices, x)); - auto funcFlat = [](Variable& input) { - return input.flat(fl::range(4, 100)); - }; + auto funcFlat = [](Variable& input) { return input.flat(fl::range(4, 100)); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcFlat, x)); } TEST(AutogradTest, Padding) { - auto in = Variable(fl::rand({3, 3}, fl::dtype::f32), true); - auto funcPad = [&](Variable& input) { - return padding(input, {{1, 2}, {0, 1}}, -1); - }; + auto in = Variable{fl::rand({3, 3}, fl::dtype::f32), true}; + auto funcPad = [&](Variable& input) { return padding(input, {{1, 2}, {0, 1}}, -1); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcPad, in, 1E-3)); } TEST(AutogradTest, Pooling) { - auto in = Variable(fl::rand({3, 3, 1, 1}, fl::dtype::f32), true); + auto in = Variable{fl::rand({3, 3, 1, 1}, fl::dtype::f32), true}; auto funcPool = [&](Variable& input) { return pool2d(input, 2, 2, 1, 1); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcPool, in, 1E-3)); } @@ -462,27 +431,26 @@ TEST_F(AutogradTestF16, PoolingF16) { if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this 
device"; - const float inputScale = 2.0; // scale the input to prevent grad underflow - auto in = Variable(inputScale * fl::rand({3, 3, 1, 1}, fl::dtype::f16), true); + float const inputScale = 2.f; // scale the input to prevent grad underflow + auto in = Variable{inputScale * fl::rand({3, 3, 1, 1}, fl::dtype::f16), true}; auto funcPool = [&](Variable& input) { return pool2d(input, 2, 2, 1, 1); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcPool, in, 1e1, 1e-1)); // TODO: investigate + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcPool, in, 1e-2, 1e-1)); // TODO: investigate } TEST(AutogradTest, Reorder) { - auto in = Variable(fl::rand({3, 1, 4, 1}, fl::dtype::f32) * 2, true); - auto funcReorder = [&](Variable& input) { - return reorder(input, {2, 0, 3, 1}); - }; + auto in = Variable{fl::rand({3, 1, 4, 1}, fl::dtype::f32) * 2, true}; + auto funcReorder = [&](Variable& input) { return reorder(input, {2, 0, 3, 1}); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcReorder, in, 1E-3)); } TEST(AutogradTest, Embedding) { int nWords = 10; - auto input = - Variable((fl::rand({4, 2}) * nWords).astype(fl::dtype::f32), false); - auto weights = Variable(fl::randn({4, nWords}, fl::dtype::f64), true); - auto funcEmbed = [&](Variable& w) { return embedding(input, w); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcEmbed, weights, 1E-5)); + auto input = Variable{(fl::rand({4, 2}) * nWords).asType(fl::dtype::s32), false}; + + auto weights = Variable{fl::randn({4, nWords}, fl::dtype::f64), true}; + auto func_embed = [&](Variable& w) { return embedding(input, w); }; + + ASSERT_TRUE(fl::detail::jacobianTestImpl(func_embed, weights, 1E-5)); } TEST(AutogradTest, GetAdvancedIndex) { @@ -491,18 +459,21 @@ TEST(AutogradTest, GetAdvancedIndex) { GTEST_SKIP() << "Advanced indexing operator unsupported for non-CUDA backends"; std::vector validIndexTypes = { - fl::dtype::s32, fl::dtype::s64, fl::dtype::u32, fl::dtype::u64 + fl::dtype::s32, + fl::dtype::s64, + fl::dtype::u32, + 
fl::dtype::u64 }; for(const auto& dtype : validIndexTypes) { - auto x = Variable(fl::rand({20, 50, 40, 30}, fl::dtype::f32), true); - Tensor a({6}, dtype); + auto x = Variable{fl::rand({20, 50, 40, 30}, fl::dtype::f32), true}; + Tensor a{{6}, dtype}; a(0) = 0; a(1) = 15; a(2) = 6; a(3) = 1; a(4) = 10; a(5) = 6; - Tensor b({3}, dtype); + Tensor b{{3}, dtype}; b(0) = 5; b(1) = 11; b(2) = 19; @@ -523,18 +494,21 @@ TEST(AutogradTest, GetAdvancedIndexF16) { if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; std::vector validIndexTypes = { - fl::dtype::s32, fl::dtype::s64, fl::dtype::u32, fl::dtype::u64 + fl::dtype::s32, + fl::dtype::s64, + fl::dtype::u32, + fl::dtype::u64 }; for(const auto& dtype : validIndexTypes) { - auto x = Variable(fl::rand({20, 50, 40, 30}, fl::dtype::f16), true); - Tensor a({6}, dtype); + auto x = Variable{fl::rand({20, 50, 40, 30}, fl::dtype::f16), true}; + Tensor a{{6}, dtype}; a(0) = 0; a(1) = 15; a(2) = 6; a(3) = 1; a(4) = 10; a(5) = 6; - Tensor b({3}, dtype); + Tensor b{{3}, dtype}; b(0) = 5; b(1) = 11; b(2) = 19; @@ -549,6 +523,218 @@ TEST(AutogradTest, GetAdvancedIndexF16) { } } + + +namespace fl { + +void print_tensor(Tensor const& toPrint, std::string_view name) { + auto const dims = toPrint.ndim(); + + if(dims == 0) + std::cout << "[]\n"; + if(dims > 3) + std::cout << std::format("can't print tensor [{}], has more than 3 dimensions\n", name); + + auto adaptive_tensor_print = [&]() { + auto host = toPrint.toHostVector(); + + std::cout << std::format("{}:\n", name); + + auto const& shape = fl::max({1, 1, 1}, toPrint.shape()); + + auto const height = shape[0]; + auto const width = shape[1]; + auto const depth = shape[2]; + + std::vector blocks(depth); + + for(size_t z = 0; z < depth; z++) { + std::vector rows(height); + + auto const offset = z * height * width; + + for(size_t x = 0; x < width; x++) { + for(size_t y = 0; y < height; y++) { + auto& row = rows[y]; + + size_t index = offset + x * height + y; 
+ + if(x == width - 1) + row += std::format("{}", host[index]); + else + row += std::format("{}, ", host[index]); + } + size_t max = 0; + for(auto& row : rows) + max = std::max(row.size(), max); + + for(auto& row : rows) + row.append(std::string(max - row.size(), ' ')); + } + + + for(auto& row : rows) + blocks[z].append(std::format("[{}]\n", row)); + blocks[z] += '\n'; + } + + for(auto& block : blocks) + std::cout << block; + + std::cout << '\n'; + }; + + fl::dispatch_dtype(toPrint.type(), adaptive_tensor_print); +}; + +#define PRINT_TENSOR(tensor) print_tensor(tensor, #tensor) + + +Variable embedding2(Variable const& input, Variable const& embeddings) { + // TODO{fl::Tensor}{4-dims} - relax this + if(input.ndim() >= 4) + throw std::invalid_argument{"embedding input must have 3 or fewer dims"}; + + auto const idxs = input.tensor().flatten(); + auto inDims = input.shape(); + std::vector rDims(input.ndim() + 1); + rDims[0] = embeddings.dim(0); + for(Dim i = 1; i < input.ndim() + 1; i++) + rDims[i] = inDims[i - 1]; + + Shape const resultDims{rDims}; + auto const result = fl::reshape(embeddings.tensor()(fl::span, idxs), resultDims); + + auto grad_func = []( + std::vector& inputs, + Variable const& gradOutput + ) { + auto& w = inputs[1]; + if(!w.isCalcGrad()) + return; + + auto const ip = inputs[0].tensor().flatten(); + auto size = static_cast(ip.elements()); + auto const deltas = fl::reshape(gradOutput.tensor(), {w.dim(0), size}); + + auto const e = fl::full({size}, 1.0, deltas.type()); + auto const iota = fl::arange({size + 1}, 0.0, fl::dtype::s32); + + PRINT_TENSOR(ip); + std::cout << std::format("size: {}\n", size); + PRINT_TENSOR(deltas); + + PRINT_TENSOR(e); + PRINT_TENSOR(iota); + + // Sparse Tensor + auto sp = Tensor{ + static_cast(ip.elements()), + w.dim(1), + e, + iota, + ip.asType(fl::dtype::s32), + fl::StorageType::CSR + }; + //double* x = sp.host(); + //fl::eval(sp); + + auto deltasT = transpose(deltas); + PRINT_TENSOR(deltasT); + auto grad = transpose( + 
fl::matmul( + sp, + deltasT, + /* lhsProp = */ + MatrixProperty::Transpose + ) + ); + fl::eval(grad); + PRINT_TENSOR(grad); + w.addGrad(Variable{grad, false}); + + PRINT_TENSOR(w.tensor()); + }; + + print_tensor(result, "embedding result"); + + return Variable{result, {input, embeddings}, grad_func}; +} + +namespace detail { + + using JacobianFunc = std::function; + inline bool jacobianTestImpl2( + JacobianFunc const& func, + Variable& input, + double precision = 1E-5, + float perturbation = 1E-4, + std::vector const& zeroGradientVariables = {} + ) { + auto const outBase = func(input); + auto const outElements = outBase.elements(); + auto const inElements = input.elements(); + + auto const fwdJacobian = Tensor({outElements, inElements}, input.type()); + + for(int i = 0; i < inElements; ++i) { + auto orig = input.tensor().flatten()(i); + + input.tensor().flat(i) = orig - perturbation; + auto outA = func(input).tensor(); + + input.tensor().flat(i) = orig + perturbation; + auto outB = func(input).tensor(); + + input.tensor().flat(i) = orig; + + fwdJacobian(fl::span, i) = fl::reshape((outB - outA), {static_cast(outA.elements())}) + * 0.5 + / perturbation; + } + + auto const bwdJacobian = Tensor({outElements, inElements}, input.type()); + auto const outD = Variable(fl::full(outBase.shape(), 0, outBase.type()), false); + + for(int i = 0; i < outD.elements(); ++i) { + outD.tensor().flat(i) = 1; // element in 1D view + input.zeroGrad(); + for(auto* var : zeroGradientVariables) + var->zeroGrad(); + + auto out = func(input); + out.backward(outD); + + bwdJacobian(i) = fl::reshape(input.grad().tensor(), {inElements}); + outD.tensor().flat(i) = 0; + } + + PRINT_TENSOR(fwdJacobian); + PRINT_TENSOR(bwdJacobian); + + return allClose(fwdJacobian, bwdJacobian, precision); + } + +} + +TEST(AutogradTest, Embedding2) { + int nWords = 10; + auto input = Variable{(fl::rand({4, 2}) * nWords).asType(fl::dtype::s32), false}; + PRINT_TENSOR(input.tensor()); + + + auto weights = 
Variable{fl::randn({4, nWords}, fl::dtype::f64), true}; + PRINT_TENSOR(weights.tensor()); + + auto func_embed = [&](Variable& w) { return embedding2(input, w); }; + + ASSERT_TRUE(detail::jacobianTestImpl2(func_embed, weights, 1E-5)); +} + +} + + + int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); fl::init(); diff --git a/flashlight/fl/test/autograd/AutogradTestUtils.h b/flashlight/fl/test/autograd/AutogradTestUtils.h index 6c4cf55..194866f 100644 --- a/flashlight/fl/test/autograd/AutogradTestUtils.h +++ b/flashlight/fl/test/autograd/AutogradTestUtils.h @@ -10,10 +10,11 @@ #include "gtest/gtest.h" #include "flashlight/fl/autograd/Functions.h" +#include "flashlight/fl/autograd/Variable.h" +#include "flashlight/fl/tensor/Compute.h" #include "flashlight/fl/tensor/Index.h" #include "flashlight/fl/tensor/Init.h" #include "flashlight/fl/tensor/Random.h" -#include "flashlight/fl/autograd/Variable.h" namespace fl { namespace detail { @@ -24,51 +25,54 @@ namespace detail { OptimMode::get().setOptimLevel(OptimLevel::O3); } - void TearDown() override { - OptimMode::get().setOptimLevel(OptimLevel::DEFAULT); - } + void TearDown() override { OptimMode::get().setOptimLevel(OptimLevel::DEFAULT); } }; using JacobianFunc = std::function; inline bool jacobianTestImpl( - const JacobianFunc& func, + JacobianFunc const& func, Variable& input, - float precision = 1E-5, + double precision = 1E-5, float perturbation = 1E-4, - const std::vector& zeroGradientVariables = {}) { - auto fwdJacobian = - Tensor({func(input).elements(), input.elements()}, fl::dtype::f32); + std::vector const& zeroGradientVariables = {} + ) { + auto const outBase = func(input); + auto const outElements = outBase.elements(); + auto const inElements = input.elements(); + + auto const fwdJacobian = Tensor({outElements, inElements}, input.type()); - for(int i = 0; i < input.elements(); ++i) { - Tensor orig = input.tensor().flatten()(i); + for(int i = 0; i < inElements; ++i) { + auto orig = 
input.tensor().flatten()(i); input.tensor().flat(i) = orig - perturbation; - auto outa = func(input).tensor(); + auto outA = func(input).tensor(); input.tensor().flat(i) = orig + perturbation; - auto outb = func(input).tensor(); + auto outB = func(input).tensor(); + input.tensor().flat(i) = orig; - fwdJacobian(fl::span, i) = - fl::reshape((outb - outa), {static_cast(outa.elements())}) * 0.5 - / perturbation; + + fwdJacobian(fl::span, i) = fl::reshape((outB - outA), {static_cast(outA.elements())}) * 0.5 / + perturbation; } - auto bwdJacobian = - Tensor({func(input).elements(), input.elements()}, fl::dtype::f32); - auto dout = - Variable(fl::full(func(input).shape(), 0, func(input).type()), false); + auto const bwdJacobian = Tensor({outElements, inElements}, input.type()); + auto const outD = Variable(fl::full(outBase.shape(), 0, outBase.type()), false); - for(int i = 0; i < dout.elements(); ++i) { - dout.tensor().flat(i) = 1; // element in 1D view + for(int i = 0; i < outD.elements(); ++i) { + outD.tensor().flat(i) = 1; // element in 1D view input.zeroGrad(); for(auto* var : zeroGradientVariables) var->zeroGrad(); auto out = func(input); - out.backward(dout); + out.backward(outD); - bwdJacobian(i) = fl::reshape(input.grad().tensor(), {input.elements()}); - dout.tensor().flat(i) = 0; + bwdJacobian(i) = fl::reshape(input.grad().tensor(), {inElements}); + outD.tensor().flat(i) = 0; } + + return allClose(fwdJacobian, bwdJacobian, precision); } diff --git a/flashlight/fl/test/autograd/AutogradUnaryOpsTest.cpp b/flashlight/fl/test/autograd/AutogradUnaryOpsTest.cpp index dddba75..9907d7e 100644 --- a/flashlight/fl/test/autograd/AutogradUnaryOpsTest.cpp +++ b/flashlight/fl/test/autograd/AutogradUnaryOpsTest.cpp @@ -42,7 +42,7 @@ TEST(AutogradUnaryOpsTest, Glu) { TEST(AutogradUnaryOpsTest, Sigmoid) { auto x = Variable(fl::rand({5}), true); auto y = sigmoid(x); - auto dy = Variable(fl::full({5}, 1.0), false); + auto dy = Variable(fl::full({5}, 1.f), false); y.backward(dy); 
auto dx = x.grad(); ASSERT_TRUE(allClose(dx.tensor(), (y.tensor() * (1 - y.tensor())))); @@ -59,7 +59,7 @@ TEST(AutogradUnaryOpsTest, Erf) { auto y = erf(x); ASSERT_TRUE(allClose(fl::erf(x.tensor()), y.tensor())); - auto dy = Variable(fl::full({5}, 1.0), false); + auto dy = Variable(fl::full({5}, 1.f), false); y.backward(dy); auto targetGrads = 2 / std::sqrt(M_PI) * exp(negate(x * x)); auto dx = x.grad(); @@ -72,7 +72,7 @@ TEST(AutogradUnaryOpsTest, Erf) { TEST(AutogradUnaryOpsTest, Tanh) { auto x = Variable(fl::rand({5}), true); auto y = tanh(x); - auto dy = Variable(fl::full({5}, 1.0), false); + auto dy = Variable(fl::full({5}, 1.f), false); y.backward(dy); auto dx = x.grad(); ASSERT_TRUE(allClose(dx.tensor(), (1 - y.tensor() * y.tensor()))); @@ -105,7 +105,7 @@ TEST(AutogradUnaryOpsTest, Transpose) { TEST(AutogradUnaryOpsTest, Exp) { auto x = Variable(fl::rand({5}), true); auto y = exp(x); - auto dy = Variable(fl::full({5}, 1.0), false); + auto dy = Variable(fl::full({5}, 1.f), false); y.backward(dy); auto dx = x.grad(); ASSERT_TRUE(allClose(dx.tensor(), (fl::exp(x.tensor())))); @@ -164,7 +164,7 @@ TEST(AutogradUnaryOpsTest, Pow) { { auto x = Variable(fl::rand({5}), true); auto y = pow(x, 2); - auto dy = Variable(fl::full({5}, 2.0), false); + auto dy = Variable(fl::full({5}, 2.f), false); y.backward(dy); auto dx = x.grad(); ASSERT_TRUE(allClose(dx.tensor(), (2 * 2 * x.tensor()))); @@ -172,7 +172,7 @@ TEST(AutogradUnaryOpsTest, Pow) { { auto x = Variable(fl::rand({5}), true); auto y = pow(x, 3); - auto dy = Variable(fl::full({5}, 1.0), false); + auto dy = Variable(fl::full({5}, 1.f), false); y.backward(dy); auto dx = x.grad(); ASSERT_TRUE(allClose(dx.tensor(), (3 * fl::power(x.tensor(), 2)))); diff --git a/flashlight/fl/test/contrib/modules/ContribModuleTest.cpp b/flashlight/fl/test/contrib/modules/ContribModuleTest.cpp index 8b55692..e5a743f 100644 --- a/flashlight/fl/test/contrib/modules/ContribModuleTest.cpp +++ 
b/flashlight/fl/test/contrib/modules/ContribModuleTest.cpp @@ -25,9 +25,7 @@ class ContribModuleTestF16 : public ::testing::Test { OptimMode::get().setOptimLevel(OptimLevel::O3); } - void TearDown() override { - OptimMode::get().setOptimLevel(OptimLevel::DEFAULT); - } + void TearDown() override { OptimMode::get().setOptimLevel(OptimLevel::DEFAULT); } }; } // namespace @@ -140,14 +138,15 @@ void transformerPadMaskFwd(bool isfp16) { int nheads = 2; auto dtype = isfp16 ? fl::dtype::f16 : fl::dtype::f32; - auto tr = - Transformer(c, c / nheads, c, nheads, timesteps, 0, 0, false, false); + auto tr = Transformer(c, c / nheads, c, nheads, timesteps, 0, 0, false, false); auto input1 = Variable(fl::rand({c, timesteps, /* B = */ 1}, dtype), false); auto input1NoPad = input1(fl::span, fl::range(0, timesteps / 2)); auto input2 = Variable(fl::rand({c, timesteps, /* B = */ 1}, dtype), false); auto input = fl::concatenate({input1, input2}, 2); auto padMask = fl::full({timesteps, 2}, 1); + padMask(fl::iota({timesteps / 2}) + timesteps / 2, 0) = 0; + auto noPadMask = fl::full({timesteps, 2}, 1); auto output = tr.forward({input, Variable(padMask, false)}).front(); @@ -159,43 +158,44 @@ void transformerPadMaskFwd(bool isfp16) { if(OptimMode::get().getOptimLevel() == OptimLevel::O3) ASSERT_EQ(outputNoPad.type(), input.type()); - else { + else ASSERT_EQ(outputNoPad.type(), fl::dtype::f32); // result is upcast - } auto output1 = tr.forward( - { - input1NoPad, - Variable( - padMask(fl::range(0, timesteps / 2))( - fl::span, - fl::range(0, 1) - ), - false - ) - } - ) - .front(); - auto output2 = - tr.forward({input2, Variable(padMask(fl::span, fl::range(1, 2)), false)}) - .front(); - ASSERT_TRUE(allClose( - output.tensor()(fl::span, fl::span, fl::range(1, 2)), output2.tensor())); - ASSERT_TRUE(allClose( - outputNoPad.tensor()(fl::span, fl::span, fl::range(1, 2)), - output2.tensor())); - ASSERT_TRUE(allClose( - output.tensor()(fl::span, fl::iota({timesteps / 2}), fl::range(0, 1)), - 
output1.tensor())); - ASSERT_FALSE(allClose( - outputNoPad.tensor()( - fl::span, fl::iota({timesteps / 2}), fl::range(0, 1)), - output1.tensor())); + { + input1NoPad, + Variable( + padMask(fl::range(0, timesteps / 2))( + fl::span, + fl::range(0, 1) + ), + false + ) + } + ) + .front(); + + auto output2 = tr.forward({input2, Variable(padMask(fl::span, fl::range(1, 2)), false)}) + .front(); + ASSERT_TRUE( + allClose(output.tensor()(fl::span, fl::span, fl::range(1, 2)), output2.tensor()) + ); + ASSERT_TRUE( + allClose(outputNoPad.tensor()(fl::span, fl::span, fl::range(1, 2)), output2.tensor()) + ); + ASSERT_TRUE( + allClose(output.tensor()(fl::span, fl::iota({timesteps / 2}), fl::range(0, 1)), output1.tensor()) + ); + ASSERT_FALSE( + allClose( + outputNoPad.tensor()( + fl::span, fl::iota({timesteps / 2}), fl::range(0, 1)), + output1.tensor() + ) + ); } -TEST(ContribModuleTest, TransformerPadMaskFwd) { - transformerPadMaskFwd(false); -} +TEST(ContribModuleTest, TransformerPadMaskFwd) { transformerPadMaskFwd(false); } TEST_F(ContribModuleTestF16, TransformerPadMaskFwd16) { if(!fl::f16Supported()) @@ -233,9 +233,7 @@ void transformerFwd(bool isfp16) { ASSERT_TRUE(allClose(output1, output2, 1E-7)); } -TEST(ContribModuleTest, TransformerFwd) { - transformerFwd(false); -} +TEST(ContribModuleTest, TransformerFwd) { transformerFwd(false); } TEST_F(ContribModuleTestF16, TransformerFwdF16) { if(!fl::f16Supported()) @@ -265,9 +263,7 @@ void conformerFwd(bool isfp16) { ASSERT_EQ(output[0].dim(2), batchsize); } -TEST(ContribModuleTest, ConformerFwd) { - conformerFwd(false); -} +TEST(ContribModuleTest, ConformerFwd) { conformerFwd(false); } TEST_F(ContribModuleTestF16, ConformerFwdF16) { if(!fl::f16Supported()) @@ -293,9 +289,7 @@ void positionEmbeddingFwd(bool isfp16) { ASSERT_FALSE(allClose(output[0], input)); } -TEST(ContribModuleTest, PositionEmbeddingFwd) { - positionEmbeddingFwd(false); -} +TEST(ContribModuleTest, PositionEmbeddingFwd) { positionEmbeddingFwd(false); } 
TEST_F(ContribModuleTestF16, PositionEmbeddingFwdF16) { if(!fl::f16Supported()) @@ -320,14 +314,12 @@ void sinusoidalPositionEmbeddingFwd(bool isfp16) { ASSERT_EQ(output[0].dim(2), batchsize); auto castOutput = output[0].tensor(); if(isfp16) - castOutput = output[0].astype(fl::dtype::f32).tensor(); + castOutput = output[0].asType(fl::dtype::f32).tensor(); ASSERT_TRUE((fl::amax(castOutput, {0})).scalar() <= 2); ASSERT_TRUE((fl::amin(castOutput, {0})).scalar() >= -2); } -TEST(ContribModuleTest, SinusoidalPositionEmbeddingFwd) { - sinusoidalPositionEmbeddingFwd(false); -} +TEST(ContribModuleTest, SinusoidalPositionEmbeddingFwd) { sinusoidalPositionEmbeddingFwd(false); } TEST_F(ContribModuleTestF16, SinusoidalPositionEmbeddingFwdF16) { if(!fl::f16Supported()) @@ -366,9 +358,7 @@ void tdsFwd(bool isfp16) { ASSERT_EQ(output.type(), input.type()); } -TEST(ContribModuleTest, TDSFwd) { - tdsFwd(false); -} +TEST(ContribModuleTest, TDSFwd) { tdsFwd(false); } TEST_F(ContribModuleTestF16, TDSFwdF16) { if(!fl::f16Supported()) @@ -397,9 +387,7 @@ void streamingTDSFwd(bool isfp16) { ASSERT_EQ(output.type(), input.type()); } -TEST(ContribModuleTest, StreamingTDSFwd) { - streamingTDSFwd(false); -} +TEST(ContribModuleTest, StreamingTDSFwd) { streamingTDSFwd(false); } TEST_F(ContribModuleTestF16, StreamingTDSFwdF16) { if(!fl::f16Supported()) @@ -431,7 +419,7 @@ TEST(ContribModuleTest, SpecAugmentFwd) { int tZeros = 0; for(int t = 0; t < T; ++t) { auto curOutSlice = output.tensor()(t); - tZeros = fl::all(curOutSlice == 0).asScalar() ? tZeros + 1 : tZeros; + tZeros = fl::all_of(curOutSlice == 0).asScalar() ? tZeros + 1 : tZeros; } ASSERT_GT(tZeros, 0); @@ -439,7 +427,7 @@ TEST(ContribModuleTest, SpecAugmentFwd) { int fZeros = 0; for(int f = 0; f < F; ++f) { auto curOutSlice = output.tensor()(fl::span, f); - fZeros = fl::all(curOutSlice == 0).asScalar() ? fZeros + 1 : fZeros; + fZeros = fl::all_of(curOutSlice == 0).asScalar() ? 
fZeros + 1 : fZeros; } ASSERT_GT(fZeros, 0); } @@ -448,7 +436,18 @@ void computeRawWavSpecAug(bool isfp16, float epsilon) { // no time, only freq masking for(int nmask = 1; nmask < 3; nmask++) { RawWavSpecAugment specAug( - 0, 1, nmask, 0, 0, 0, 1, 2000, 6000, 16000, 20000); + 0, + 1, + nmask, + 0, + 0, + 0, + 1, + 2000, + 6000, + 16000, + 20000 + ); specAug.train(); int T = 300, C = 3, B = 4; @@ -460,8 +459,8 @@ void computeRawWavSpecAug(bool isfp16, float epsilon) { inputWav = fl::tile(inputWav, {1, C, B}); finalWav = fl::tile(finalWav, {1, C, B}); if(isfp16) { - inputWav = inputWav.astype(fl::dtype::f16); - finalWav = finalWav.astype(fl::dtype::f16); + inputWav = inputWav.asType(fl::dtype::f16); + finalWav = finalWav.asType(fl::dtype::f16); } auto filteredWav = specAug(fl::Variable(inputWav, false)); @@ -480,9 +479,7 @@ void computeRawWavSpecAug(bool isfp16, float epsilon) { } } -TEST(ContribModuleTest, RawWavSpecAugmentFwd) { - computeRawWavSpecAug(false, 1e-3); -} +TEST(ContribModuleTest, RawWavSpecAugmentFwd) { computeRawWavSpecAug(false, 1e-3); } TEST_F(ContribModuleTestF16, RawWavSpecAugmentFwdF16) { if(!fl::f16Supported()) diff --git a/flashlight/fl/test/contrib/modules/ContribSerializationTest.cpp b/flashlight/fl/test/contrib/modules/ContribSerializationTest.cpp index 0217878..4d7eaa3 100644 --- a/flashlight/fl/test/contrib/modules/ContribSerializationTest.cpp +++ b/flashlight/fl/test/contrib/modules/ContribSerializationTest.cpp @@ -43,7 +43,7 @@ TEST(SerializationTest, Residual) { TEST(SerializationTest, AsymmetricConv1D) { int c = 32; - auto model = std::make_shared(c, c, 5, 1, -1, 0, 1); + auto model = std::make_shared(c, c, 5, 1, -1, 0.f, 1); const fs::path path = fs::temp_directory_path() / "AsymmetricConv1D.mdl"; save(path, model); @@ -71,8 +71,8 @@ TEST(SerializationTest, Transformer) { c, nheads, timesteps, - 0.2, - 0.1, + 0.2f, + 0.1f, false, false ); @@ -107,8 +107,8 @@ TEST(SerializationTest, ConformerSerialization) { nheads, timesteps, 33, - 
0.2, - 0.1 + 0.2f, + 0.1f ); model->eval(); @@ -193,7 +193,7 @@ TEST(SerializationTest, RawWavSpecAugment) { 1, 1, 0, - 0, + 0.f, 0, 1, 2000, diff --git a/flashlight/fl/test/nn/ModuleTest.cpp b/flashlight/fl/test/nn/ModuleTest.cpp index cb87cb9..5cbf3b9 100644 --- a/flashlight/fl/test/nn/ModuleTest.cpp +++ b/flashlight/fl/test/nn/ModuleTest.cpp @@ -157,7 +157,7 @@ TEST_F(ModuleTestF16, LinearFwdF16) { {n_in, x, batchsize}, {6, 2, 1, 4, 8, 2, 7, 1, 10, 7, 3, 7, 5, 9, 2, 4} ) - .astype(fl::dtype::f16) + .asType(fl::dtype::f16) ); auto expected_outVar = Variable( @@ -168,7 +168,7 @@ TEST_F(ModuleTestF16, LinearFwdF16) { 150, 55, 41, 94, 41, 27, 130, 55, 37, 56, 24, 16 } ) - .astype(fl::dtype::f16), + .asType(fl::dtype::f16), true ); @@ -187,7 +187,7 @@ TEST_F(ModuleTestF16, LinearFwdF16) { 151, 57, 44, 95, 43, 30, 131, 57, 40, 57, 26, 19 } ) - .astype(inVar.type()), + .asType(inVar.type()), true ); @@ -197,7 +197,7 @@ TEST_F(ModuleTestF16, LinearFwdF16) { ASSERT_TRUE(allClose(resultBias, expected_outVar, 1E-3)); // OptimLevel::O3 is active with this fixture - ASSERT_EQ(linBias.forward(inVar.astype(fl::dtype::f32)).type(), fl::dtype::f16); + ASSERT_EQ(linBias.forward(inVar.asType(fl::dtype::f32)).type(), fl::dtype::f16); } TEST(ModuleTest, ConvPadding) { @@ -265,13 +265,13 @@ TEST_F(ModuleTestF16, GLUFwdF16) { auto inVar = Variable( Tensor::fromVector({3, 2}, {0.8, 0.2, 0.2, 0.1, 0.5, 0.3}) - .astype(fl::dtype::f16), + .asType(fl::dtype::f16), true ); auto expected_outVar = Variable( Tensor::fromVector({3, 1}, {0.419983, 0.124492, 0.114888}) - .astype(fl::dtype::f16), + .asType(fl::dtype::f16), true ); @@ -282,7 +282,7 @@ TEST_F(ModuleTestF16, GLUFwdF16) { // test batching int batchsize = 5; - inVar = Variable(fl::rand({10, 7, batchsize}).astype(fl::dtype::f16), true); + inVar = Variable(fl::rand({10, 7, batchsize}).asType(fl::dtype::f16), true); glu = GatedLinearUnit(0); auto batchOutVar = glu(inVar); @@ -346,7 +346,7 @@ TEST_F(ModuleTestF16, LogSoftmaxFwdF16) { auto 
inVar = Variable( Tensor::fromVector({3, 2}, {0.8, 0.2, 0.2, 0.1, 0.5, 0.3}) - .astype(fl::dtype::f16), + .asType(fl::dtype::f16), true ); @@ -671,7 +671,7 @@ TEST_F(ModuleTestF16, RNNFwdF16) { ), true ); - ASSERT_TRUE(allClose(out, expected_outVar.astype(in.type()), 5E-2)); + ASSERT_TRUE(allClose(out, expected_outVar.asType(in.type()), 5E-2)); } TEST(ModuleTest, ViewFwd) { @@ -708,7 +708,7 @@ TEST_F(ModuleTestF16, DropoutFwdF16) { if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - auto module = Dropout(0.5); + auto module = Dropout(0.5f); // Train Mode module.train(); auto in = Variable(fl::rand({1000, 1000}, fl::dtype::f16), true); @@ -723,7 +723,7 @@ TEST_F(ModuleTestF16, DropoutFwdF16) { ASSERT_GT( fl::amax(out.tensor()).asScalar(), - 1.5 + 1.5f ); // Check input is scaled // Eval Mode @@ -806,8 +806,8 @@ TEST_F(ModuleTestF16, LayerNormFwdF16) { auto sample_mean = mean(input, {3}); auto sample_var = var(input, {3}, true); - auto true_out = (input - tileAs(sample_mean, input).astype(input.type())) - / tileAs(fl::sqrt(sample_var + eps), input).astype(input.type()); + auto true_out = (input - tileAs(sample_mean, input).asType(input.type())) + / tileAs(fl::sqrt(sample_var + eps), input).asType(input.type()); // no affine transform auto module1 = LayerNorm(feat_axes, eps, false); @@ -815,13 +815,13 @@ TEST_F(ModuleTestF16, LayerNormFwdF16) { module1.train(); auto out = module1.forward(input); - ASSERT_TRUE(allClose(out, true_out.astype(out.type()), eps)); + ASSERT_TRUE(allClose(out, true_out.asType(out.type()), eps)); module1.eval(); out = module1.forward(input); ASSERT_TRUE( - allClose(out.tensor(), true_out.tensor().astype(out.type()), eps) + allClose(out.tensor(), true_out.tensor().asType(out.type()), eps) ); // with affine transform @@ -863,12 +863,12 @@ TEST(ModuleTest, PrecisionCastFwd) { if(!fl::f16Supported()) GTEST_SKIP() << "Half precision not available on this device"; - auto in = Variable(fl::full({3, 3}, 1.0), 
true); + auto in = Variable(fl::full({3, 3}, 1.f), true); auto precisionCast = PrecisionCast(fl::dtype::f16); auto out = precisionCast.forward(in); ASSERT_EQ(out.type(), fl::dtype::f16); - ASSERT_TRUE(allClose(in.tensor(), out.astype(fl::dtype::f32).tensor())); + ASSERT_TRUE(allClose(in.tensor(), out.asType(fl::dtype::f32).tensor())); } TEST(ModuleTest, ContainerReplaceParam) { @@ -912,7 +912,7 @@ TEST(ModuleTest, AdaptiveSoftMaxPredict) { auto x = input(fl::rand({N, T, B}, fl::dtype::f32)); auto y = Variable( - (fl::rand({T, B}, fl::dtype::u32) % C).astype(fl::dtype::s32), + (fl::rand({T, B}, fl::dtype::u32) % C).asType(fl::dtype::s32), false ); @@ -935,7 +935,7 @@ TEST(ModuleTest, AdaptiveSoftMaxLossBatchFwd) { auto x = input(fl::rand({N, T, B}, fl::dtype::f32)); auto y = Variable( - (fl::rand({T, B}, fl::dtype::u32) % C).astype(fl::dtype::s32), + (fl::rand({T, B}, fl::dtype::u32) % C).asType(fl::dtype::s32), false ); @@ -966,7 +966,7 @@ TEST(ModuleTest, AdaptiveSoftMaxLossIgnoreIndex) { auto x = input(fl::rand({N, T, B}, fl::dtype::f32)); auto y = Variable( - (fl::rand({T, B}, fl::dtype::u32) % C).astype(fl::dtype::s32), + (fl::rand({T, B}, fl::dtype::u32) % C).asType(fl::dtype::s32), false ); auto ignoreIdx = y(0, 0).scalar(); diff --git a/flashlight/fl/test/nn/NNSerializationTest.cpp b/flashlight/fl/test/nn/NNSerializationTest.cpp index fb86f5a..88016e4 100644 --- a/flashlight/fl/test/nn/NNSerializationTest.cpp +++ b/flashlight/fl/test/nn/NNSerializationTest.cpp @@ -238,7 +238,7 @@ TEST(NNSerializationTest, PrettyString) { "(0): Conv2D (3->64, 5x5, 1, 1, 0, 0, 1, 1) (with bias)" "(1): Pool2D-max (3x3, 2,2, 1,1)" "(2): ReLU" - "(3): Dropout (0.400000)" + "(3): Dropout (0.4)" "(4): Linear (5->10) (without bias)" "(5): Tanh" "(6): LeakyReLU (0.200000)"; diff --git a/flashlight/fl/test/nn/NNUtilsTest.cpp b/flashlight/fl/test/nn/NNUtilsTest.cpp index a07f731..a009881 100644 --- a/flashlight/fl/test/nn/NNUtilsTest.cpp +++ b/flashlight/fl/test/nn/NNUtilsTest.cpp @@ 
-23,7 +23,7 @@ TEST(UtilsTest, Join) { // Single array auto i = fl::rand({50, 60, 70, 1}); auto o = join({i}, -1, 3); - ASSERT_TRUE(fl::all(o == i).asScalar()); + ASSERT_TRUE(fl::all_of(o == i).asScalar()); // no dim for batching adds singleton dims ASSERT_EQ( @@ -40,14 +40,14 @@ TEST(UtilsTest, Join) { auto o1 = join({a, b, c}); ASSERT_EQ(o1.shape(), Shape({30, 1, 300, 3})); ASSERT_TRUE( - fl::all( + fl::all_of( o1(fl::range(25), fl::range(0, 1), fl::range(300), fl::range(0, 1)) == a ) .asScalar() ); ASSERT_TRUE( - fl::all( + fl::all_of( o1( fl::range(25, 29), fl::range(0, 1), @@ -58,14 +58,14 @@ TEST(UtilsTest, Join) { .asScalar() ); ASSERT_TRUE( - fl::all( + fl::all_of( o1(fl::range(20), fl::range(0, 1), fl::range(300), fl::range(1, 2)) == b ) .asScalar() ); ASSERT_TRUE( - fl::all( + fl::all_of( o1( fl::range(20, 29), fl::range(0, 1), @@ -76,7 +76,7 @@ TEST(UtilsTest, Join) { .asScalar() ); ASSERT_TRUE( - fl::all( + fl::all_of( o1(fl::range(30), fl::range(0, 1), fl::range(300), fl::range(2, 3)) == c ) @@ -86,14 +86,14 @@ TEST(UtilsTest, Join) { auto o2 = join({a, b, c}, -1); ASSERT_EQ(o2.shape(), Shape({30, 1, 300, 3})); ASSERT_TRUE( - fl::all( + fl::all_of( o2(fl::range(25), fl::range(0, 1), fl::range(300), fl::range(0, 1)) == a ) .asScalar() ); ASSERT_TRUE( - fl::all( + fl::all_of( o2( fl::range(25, 29), fl::range(0, 1), @@ -104,14 +104,14 @@ TEST(UtilsTest, Join) { .asScalar() ); ASSERT_TRUE( - fl::all( + fl::all_of( o2(fl::range(20), fl::range(0, 1), fl::range(300), fl::range(1, 2)) == b ) .asScalar() ); ASSERT_TRUE( - fl::all( + fl::all_of( o2( fl::range(20, 29), fl::range(0, 1), @@ -122,7 +122,7 @@ TEST(UtilsTest, Join) { .asScalar() ); ASSERT_TRUE( - fl::all( + fl::all_of( o2(fl::range(30), fl::range(0, 1), fl::range(300), fl::range(2, 3)) == c ) @@ -132,23 +132,23 @@ TEST(UtilsTest, Join) { auto o3 = join({a, b, c}, -1, 1); ASSERT_EQ(o3.shape(), Shape({30, 3, 300, 1})); ASSERT_TRUE( - fl::all(o3(fl::range(25), fl::range(0, 1), fl::range(300)) == a) + 
fl::all_of(o3(fl::range(25), fl::range(0, 1), fl::range(300)) == a) .asScalar() ); ASSERT_TRUE( - fl::all(o3(fl::range(25, 29), fl::range(0, 1), fl::range(300)) == -1) + fl::all_of(o3(fl::range(25, 29), fl::range(0, 1), fl::range(300)) == -1) .asScalar() ); ASSERT_TRUE( - fl::all(o3(fl::range(20), fl::range(1, 2), fl::range(300)) == b) + fl::all_of(o3(fl::range(20), fl::range(1, 2), fl::range(300)) == b) .asScalar() ); ASSERT_TRUE( - fl::all(o3(fl::range(20, 29), fl::range(1, 2), fl::range(300)) == -1) + fl::all_of(o3(fl::range(20, 29), fl::range(1, 2), fl::range(300)) == -1) .asScalar() ); ASSERT_TRUE( - fl::all(o3(fl::range(30), fl::range(2, 3), fl::range(300)) == c) + fl::all_of(o3(fl::range(30), fl::range(2, 3), fl::range(300)) == c) .asScalar() ); } diff --git a/flashlight/fl/test/optim/OptimTest.cpp b/flashlight/fl/test/optim/OptimTest.cpp index d7c8e1c..c7bfeca 100644 --- a/flashlight/fl/test/optim/OptimTest.cpp +++ b/flashlight/fl/test/optim/OptimTest.cpp @@ -19,7 +19,7 @@ TEST(OptimTest, GradNorm) { std::vector parameters; for(int i = 0; i < 5; i++) { auto v = Variable(fl::randn({10, 10, 10}), true); - v = v.astype(fl::dtype::f64); + v = v.asType(fl::dtype::f64); v.addGrad(Variable(fl::randn({10, 10, 10}, fl::dtype::f64), false)); parameters.push_back(v); } @@ -42,7 +42,7 @@ TEST(OptimTest, GradNormF16) { std::vector parameters; for(int i = 0; i < 5; i++) { auto v = Variable(fl::randn({10, 10, 10}), true); - v = v.astype(fl::dtype::f16); + v = v.asType(fl::dtype::f16); v.addGrad(Variable(fl::randn({10, 10, 10}, fl::dtype::f16), false)); parameters.push_back(v); } diff --git a/flashlight/fl/test/tensor/IndexTest.cpp b/flashlight/fl/test/tensor/IndexTest.cpp index 920de06..b7d86b6 100644 --- a/flashlight/fl/test/tensor/IndexTest.cpp +++ b/flashlight/fl/test/tensor/IndexTest.cpp @@ -91,27 +91,27 @@ TEST(IndexTest, IndexAssignment) { t /= 7; ASSERT_TRUE(allClose(t, fl::full({4, 4}, 1))); - auto a = fl::full({6, 6}, 0.); + auto a = fl::full({6, 6}, 0.f); a(3, 
4) = 4.; - ASSERT_TRUE(allClose(a(3, 4), fl::full({1}, 4.))); + ASSERT_TRUE(allClose(a(3, 4), fl::full({1}, 4.f))); a(2) = fl::full({6}, 8.); - ASSERT_TRUE(allClose(a(2), fl::full({6}, 8.))); + ASSERT_TRUE(allClose(a(2), fl::full({6}, 8.f))); - auto b = fl::full({3, 3}, 1.); + auto b = fl::full({3, 3}, 1.f); auto c = b; b += 1; - ASSERT_TRUE(allClose(b, fl::full({3, 3}, 2.))); - ASSERT_TRUE(allClose(c, fl::full({3, 3}, 1.))); + ASSERT_TRUE(allClose(b, fl::full({3, 3}, 2.f))); + ASSERT_TRUE(allClose(c, fl::full({3, 3}, 1.f))); - auto q = fl::full({4, 4}, 2.); - auto r = fl::full({4}, 3.); + auto q = fl::full({4, 4}, 2.f); + auto r = fl::full({4}, 3.f); q(0) = r; ASSERT_TRUE(allClose(q(0), r)); - ASSERT_TRUE(allClose(q(fl::range(1, fl::end)), fl::full({3, 4}, 2.))); + ASSERT_TRUE(allClose(q(fl::range(1, fl::end)), fl::full({3, 4}, 2.f))); auto k = fl::rand({100, 200}); - k(3) = fl::full({200}, 0.); - ASSERT_TRUE(allClose(k(3), fl::full({200}, 0.))); + k(3) = fl::full({200}, 0.f); + ASSERT_TRUE(allClose(k(3), fl::full({200}, 0.f))); // Weak ref auto g = fl::rand({3, 4, 5}); @@ -122,29 +122,29 @@ TEST(IndexTest, IndexAssignment) { ASSERT_TRUE(allClose(gC(fl::span, fl::range(0, 3)), gI)); auto x = fl::rand({5, 6, 7, 8}); - x(3) = fl::full({6, 7, 8}, 0.); - ASSERT_TRUE(allClose(x(3), fl::full({6, 7, 8}, 0.))); - x(fl::span, fl::span, 2) = fl::full({5, 6, 8}, 3.); - ASSERT_TRUE(allClose(x(fl::span, fl::span, 2), fl::full({5, 6, 8}, 3.))); + x(3) = fl::full({6, 7, 8}, 0.f); + ASSERT_TRUE(allClose(x(3), fl::full({6, 7, 8}, 0.f))); + x(fl::span, fl::span, 2) = fl::full({5, 6, 8}, 3.f); + ASSERT_TRUE(allClose(x(fl::span, fl::span, 2), fl::full({5, 6, 8}, 3.f))); ASSERT_THROW( x(fl::span, fl::span, 4) -= fl::rand({5, 6, 1, 8}), std::invalid_argument ); - x(fl::span, fl::range(1, 3), fl::span) = fl::full({5, 2, 7, 8}, 2.); + x(fl::span, fl::range(1, 3), fl::span) = fl::full({5, 2, 7, 8}, 2.f); ASSERT_TRUE( allClose( x(fl::span, fl::range(1, 3), fl::span), - fl::full({5, 2, 7, 
8}, 2.) + fl::full({5, 2, 7, 8}, 2.f) ) ); x(fl::span, fl::arange({5}), fl::span, fl::arange({5})) = - fl::full({5, 5, 7, 5}, 2.); + fl::full({5, 5, 7, 5}, 2.f); ASSERT_TRUE( allClose( x(fl::span, fl::range(1, 3), fl::span), - fl::full({5, 2, 7, 8}, 2.) + fl::full({5, 2, 7, 8}, 2.f) ) ); } @@ -178,7 +178,7 @@ TEST(IndexTest, flat) { for(unsigned i = 0; i < n.elements(); ++i) ASSERT_TRUE(allClose(n.flat(i), n(i % 4, (i / 4) % 6, (i / (4 * 6)) % 8))); - auto a = fl::full({5, 6, 7, 8}, 9.); + auto a = fl::full({5, 6, 7, 8}, 9.f); std::vector testIndices = {0, 1, 4, 11, 62, 104, 288}; for(const int i : testIndices) ASSERT_EQ(a.flat(i).scalar(), 9.); @@ -197,10 +197,10 @@ TEST(IndexTest, flat) { // Tensor assignment a.flat(32) = fl::full({1}, 7.4); - ASSERT_TRUE(allClose(a.flatten()(32), fl::full({1}, 7.4))); + ASSERT_TRUE(allClose(a.flatten()(32), fl::full({1}, 7.4f))); // In-place a.flat(100) += 33; - ASSERT_TRUE(allClose(a.flatten()(100), fl::full({1}, 33 + 9.))); + ASSERT_TRUE(allClose(a.flatten()(100), fl::full({1}, 33 + 9.f))); // Tensor indexing auto indexer = Tensor::fromVector(testIndices); @@ -223,8 +223,8 @@ TEST(IndexTest, flat) { // With leading singleton dims auto b = fl::rand({1, 1, 10}); ASSERT_EQ(b.flat(fl::range(3)).shape(), Shape({3})); - b.flat(fl::range(3)) = fl::full({3}, 6.); - ASSERT_TRUE(allClose(b.flatten()(fl::range(3)), fl::full({3}, 6.))); + b.flat(fl::range(3)) = fl::full({3}, 6.f); + ASSERT_TRUE(allClose(b.flatten()(fl::range(3)), fl::full({3}, 6.f))); } TEST(IndexTest, TensorIndex) { @@ -239,7 +239,7 @@ TEST(IndexTest, TensorIndex) { ASSERT_TRUE(allClose(indexed(i), a(idxs[i]))); a(indices) = 5.; - ASSERT_TRUE(allClose(a(indices), fl::full({size}, 5.))); + ASSERT_TRUE(allClose(a(indices), fl::full({size}, 5.f))); // Out of range indices auto i = fl::arange({10}, 0, fl::dtype::u32); @@ -251,7 +251,7 @@ TEST(IndexTest, TensorIndex) { b(i) += 3.; ASSERT_TRUE(allClose(b(i), b(fl::range(10)))); ASSERT_TRUE(allClose(b(i), (ref + 3)(i))); - b(i) 
+= fl::full({(Dim) i.elements(), b.dim(1)}, 10.); + b(i) += fl::full({(Dim) i.elements(), b.dim(1)}, 10.f); ASSERT_EQ(b(i).shape(), (ref + 13)(i).shape()); ASSERT_TRUE(allClose(b(i), (ref + 13)(i))); diff --git a/flashlight/fl/test/tensor/TensorBLASTest.cpp b/flashlight/fl/test/tensor/TensorBLASTest.cpp index df14ba8..e30729b 100644 --- a/flashlight/fl/test/tensor/TensorBLASTest.cpp +++ b/flashlight/fl/test/tensor/TensorBLASTest.cpp @@ -18,17 +18,17 @@ TEST(TensorBLASTest, matmul) { // TODO: test tensors with order > 2 // Reference impl - auto matmulRef = [](const Tensor& lhs, const Tensor& rhs) { + auto matmulRef = [](Tensor const& lhs, Tensor const& rhs) { // (M x N) x (N x K) --> (M x K) int M = lhs.dim(0); int N = lhs.dim(1); int K = rhs.dim(1); - auto out = fl::full({M, K}, 0.); + auto out = fl::full({M, K}, 0.f); - for(unsigned i = 0; i < M; ++i) - for(unsigned j = 0; j < K; ++j) - for(unsigned k = 0; k < N; ++k) + for(int i = 0; i < M; ++i) + for(int j = 0; j < K; ++j) + for(int k = 0; k < N; ++k) out(i, j) += lhs(i, k) * rhs(k, j); return out; }; diff --git a/flashlight/fl/test/tensor/TensorBaseTest.cpp b/flashlight/fl/test/tensor/TensorBaseTest.cpp index 2a7e1b1..bd86937 100644 --- a/flashlight/fl/test/tensor/TensorBaseTest.cpp +++ b/flashlight/fl/test/tensor/TensorBaseTest.cpp @@ -16,117 +16,283 @@ #include "flashlight/fl/tensor/Random.h" #include "flashlight/fl/tensor/TensorBase.h" +#include + using namespace ::testing; using namespace fl; + +TEST(TensorBaseTest, FullTypeMismatch) { + Shape const shape{2, 2}; + + // Case where everything matches + auto const x0 = fl::full(shape, 1.0, fl::dtype::f64); + ASSERT_EQ(x0.type(), fl::dtype::f64); + ASSERT_EQ(x0.shape(), shape); + + std::span x0Span(x0.host(), x0.elements()); + for(double val : x0Span) + ASSERT_EQ(val, 1.0); + + auto const x1 = fl::full(shape, 0, fl::dtype::f64); + ASSERT_EQ(x1.type(), fl::dtype::f64); + ASSERT_EQ(x1.shape(), shape); + + std::span x1Span(x1.host(), x1.elements()); + for(double 
val : x1Span) + ASSERT_EQ(val, 0.0); + + auto const x2 = fl::full(shape, 1.0, fl::dtype::s32); + ASSERT_EQ(x2.type(), fl::dtype::s32); + ASSERT_EQ(x2.shape(), shape); + + std::span const x2Span(x2.host(), x2.elements()); + for(int val : x2Span) + ASSERT_EQ(val, 1); +} +TEST(TensorBaseTest, ArangeTypeMismatch) { + // Case where everything matches + auto const y0 = fl::arange(0.0, 4.0, 1.0, fl::dtype::f64); + ASSERT_EQ(y0.type(), fl::dtype::f64); + ASSERT_EQ(y0.shape(), Shape({4})); + ASSERT_EQ(y0.scalar(), 0.0); + + // Emitting int literals while requesting f64 tensor creation. + auto const y1 = fl::arange(0, 4, 1, fl::dtype::f64); + ASSERT_EQ(y1.type(), fl::dtype::f64); + ASSERT_EQ(y1.shape(), Shape({4})); + ASSERT_EQ(y1.scalar(), 0.0); + + // Emitting double literals while requesting s32 tensor creation. + auto const y2 = fl::arange(0.0, 4.0, 1.0, fl::dtype::s32); + ASSERT_EQ(y2.type(), fl::dtype::s32); + ASSERT_EQ(y2.shape(), Shape({4})); + ASSERT_EQ(y2.scalar(), 0); +} + +TEST(TensorBaseTest, Concatenate) { + auto a = fl::full({3, 3}, 1.f); + auto b = fl::full({3, 3}, 2.f); + auto c = fl::full({3, 3}, 3.f); + ASSERT_TRUE( + allClose(fl::concatenate(0, a, b, c), fl::concatenate({a, b, c})) + ); + auto const out = fl::concatenate(0, a, b, c); + ASSERT_EQ(out.shape(), (Shape{9, 3})); + + // Empty tensors + ASSERT_EQ(fl::concatenate(0, Tensor{}, Tensor{}).shape(), Shape{0}); + ASSERT_EQ(fl::concatenate(2, Tensor{}, Tensor{}).shape(), (Shape{0, 1, 1})); + ASSERT_EQ( + fl::concatenate(1, fl::rand({5, 5}), Tensor{}).shape(), + (Shape{5, 5}) + ); +} + +TEST(TensorBaseTest, ConcatenateMany) { + for(int n = 1; n <= 30; ++n) { + std::vector tensors{}; + std::vector expectedData{}; + long long totalSize = 0; + + for(size_t i = 0; i < n; ++i) { + // Variable width: i + 1 elements + auto const width = i + 1; + + // Variable content: start at i * 10 + auto const startVal = i * 10; + auto t = fl::arange(startVal, startVal + width, 1, fl::dtype::s32); + tensors.push_back(t); + +
for(size_t j = 0; j < width; ++j) + expectedData.push_back(startVal + j); + totalSize += width; + } + + auto result = fl::concatenate(tensors, /* axis = */ 0); + auto expectedTensor = Tensor::fromVector({totalSize}, expectedData); + + ASSERT_EQ(result.shape(), Shape({totalSize})); + ASSERT_TRUE(allClose(result, expectedTensor)); + } +} + + +TEST(TensorBaseTest, ConcatenateDuplicateTensors) { + auto t1 = fl::full({2, 2}, 1.0f, fl::dtype::f32); + auto t2 = fl::full({2, 2}, 2.0f, fl::dtype::f32); + + auto result = fl::concatenate({t1, t2, t1, t2}, /* axis = */ 1); + auto expected = fl::concatenate({t1.copy(), t2.copy(), t1.copy(), t2.copy()}, /* axis = */ 1); + + ASSERT_TRUE(allClose(result, expected)); +} + +TEST(TensorBaseTest, ConcatenateViews) { + std::vector const data{ + 0.1f, + 0.2f, + 0.3f, + 0.4f, + 0.5f, + 1.1f, + 1.2f, + 1.3f, + 1.4f, + 1.5f + }; + auto const t = fl::Tensor::fromVector({5, 2}, data); + + auto const vertTiled = fl::concatenate( + 0, + fl::reshape(t(0, fl::span), {1, 2}), + t, + fl::reshape(t(t.dim(0) - 1, fl::span), {1, 2}) + ); + auto vTiled0 = vertTiled(fl::span, 0); + auto vTiled1 = vertTiled(fl::span, 1); + + auto const result = fl::concatenate({vTiled1, vTiled0, vTiled0, vTiled1, vTiled1, vTiled0}, 1); + + auto const expectedSymmetricPad = fl::concatenate( + {vTiled1.copy(), vTiled0.copy(), vTiled0.copy(), vTiled1.copy(), vTiled1.copy(), vTiled0.copy()}, + 1 + ); + + ASSERT_TRUE(allClose(result, expectedSymmetricPad)); +} + + + TEST(TensorBaseTest, DefaultConstruction) { - Tensor t; + Tensor const t{}; ASSERT_EQ(t.shape(), Shape({0})); ASSERT_EQ(t.type(), fl::dtype::f32); - Tensor u({1, 2, 3}); + Tensor const u{{1, 2, 3}}; ASSERT_EQ(u.shape(), Shape({1, 2, 3})); ASSERT_EQ(u.type(), fl::dtype::f32); - Tensor x({0, 3}); + Tensor const x({0, 3}); ASSERT_EQ(x.shape(), Shape({0, 3})); - Tensor q(fl::dtype::f64); + Tensor const q(fl::dtype::f64); ASSERT_EQ(q.shape(), Shape({0})); ASSERT_EQ(q.type(), fl::dtype::f64); - Tensor v({4, 5, 6}, 
fl::dtype::u64); + Tensor const v({4, 5, 6}, fl::dtype::u64); ASSERT_EQ(v.shape(), Shape({4, 5, 6})); ASSERT_EQ(v.type(), fl::dtype::u64); } TEST(TensorBaseTest, CopyConstruction) { - Shape shape{2, 2}; - auto x = fl::full(shape, 0); - auto y = x; // actual copy (implementation may be CoW) + Shape const shape{2, 2}; + constexpr auto initialValue = 0; + constexpr auto afterIncrement = 23; - ASSERT_TRUE(allClose(x, fl::full(shape, 0))); - ASSERT_TRUE(allClose(y, fl::full(shape, 0))); - x += 23; // affects both tensors - ASSERT_TRUE(allClose(x, fl::full(shape, 23))); - ASSERT_TRUE(allClose(y, fl::full(shape, 0))); + auto x = fl::full(shape, initialValue); + auto const y = x; // actual copy (implementation may be CoW) + + ASSERT_TRUE(allClose(x, fl::full(shape, initialValue))); + ASSERT_TRUE(allClose(y, fl::full(shape, initialValue))); + x += afterIncrement; // affects both tensors + ASSERT_TRUE(allClose(x, fl::full(shape, afterIncrement))); + ASSERT_TRUE(allClose(y, fl::full(shape, initialValue))); } TEST(TensorBaseTest, MoveConstruction) { - Shape shape{2, 2}; - auto x = fl::full(shape, 0); - auto y = x(span, span); // view of x + Shape const shape{2, 2}; + constexpr auto initialValue = 0; + constexpr auto afterMove = 42; + + auto x = fl::full(shape, initialValue); + auto const y = x(span, span); // view of x auto z = std::move(x); // `z` takes over `x`'s data // TODO the following line (or any read to `y`, as it seems) promotes view to // copy; to avoid this, we must update impl of `assign` // ASSERT_TRUE(allClose(y, fl::full(shape, 0))); - ASSERT_TRUE(allClose(z, fl::full(shape, 0))); + ASSERT_TRUE(allClose(z, fl::full(shape, initialValue))); - z += 42; // `y` is now a view of `z`, so it's affected - ASSERT_TRUE(allClose(y, fl::full(shape, 42))); - ASSERT_TRUE(allClose(z, fl::full(shape, 42))); + z += afterMove; // `y` is now a view of `z`, so it's affected + ASSERT_TRUE(allClose(y, fl::full(shape, afterMove))); + ASSERT_TRUE(allClose(z, fl::full(shape, 
afterMove))); } TEST(TensorBaseTest, AssignmentOperatorLvalueWithRvalue) { - Shape shape{2, 2}; - auto x = fl::full({2, 2}, 0); + Shape const shape{2, 2}; + constexpr auto initialValue = 0; + constexpr auto assignedValue = 42; + constexpr auto expectedAfterIncrement = 43; + + auto const x = fl::full(shape, initialValue); auto y = x(span, span); // view as a lvalue cannot be used to update original tensor - y = fl::full({2, 2}, 42); // `x` isn't affected + y = fl::full(shape, assignedValue); // `x` isn't affected y += 1; // `x` isn't affected - ASSERT_TRUE(allClose(x, fl::full(shape, 0))); - ASSERT_TRUE(allClose(y, fl::full(shape, 43))); + ASSERT_TRUE(allClose(x, fl::full(shape, initialValue))); + ASSERT_TRUE(allClose(y, fl::full(shape, expectedAfterIncrement))); } TEST(TensorBaseTest, AssignmentOperatorLvalueWithLvalue) { - Shape shape{2, 2}; - auto x = fl::full({2, 2}, 0); + Shape const shape{2, 2}; + constexpr auto initialValue = 0; + constexpr auto value1 = 1; + constexpr auto expectedAfterAssignment = 2; + + auto const x = fl::full(shape, initialValue); auto y = x(span, span); - auto z = fl::full({2, 2}, 1); + auto const z = fl::full(shape, value1); y = z; // `x` is a copy of `z` now (impl may be CoW) - y += 1; // `z` isn't affected - ASSERT_TRUE(allClose(x, fl::full(shape, 0))); - ASSERT_TRUE(allClose(y, fl::full(shape, 2))); - ASSERT_TRUE(allClose(z, fl::full(shape, 1))); + y += value1; // `z` isn't affected + ASSERT_TRUE(allClose(x, fl::full(shape, initialValue))); + ASSERT_TRUE(allClose(y, fl::full(shape, expectedAfterAssignment))); + ASSERT_TRUE(allClose(z, fl::full(shape, value1))); } TEST(TensorBaseTest, AssignmentOperatorRvalueWithRvalue) { - Shape shape{2, 2}; - auto type = dtype::f32; - auto x = fl::full({2, 2}, 0, type); - auto y = x(span, span); + Shape const shape{2, 2}; + constexpr auto initialValue = 0; + constexpr auto assignValue = 1; + + auto const type = dtype::f32; + auto const x = fl::full(shape, initialValue, type); + auto const y = 
x(span, span); - x(0, span) = fl::full({2}, 1); // `x` is updated by copying from rhs data - auto res = fl::Tensor::fromVector(shape, {1, 0, 1, 0}, type); + x(0, span) = fl::full({2}, assignValue); // `x` is updated by copying from rhs data + auto const res = fl::Tensor::fromVector(shape, {1, 0, 1, 0}, type); ASSERT_TRUE(allClose(x, res)); ASSERT_TRUE(allClose(y, res)); } TEST(TensorBaseTest, AssignmentOperatorRvalueWithLvalue) { - Shape shape{2, 2}; - auto type = dtype::f32; - auto x = fl::full(shape, 0, type); - auto y = x(span, span); // view of `x` - auto z = fl::full({2}, 1, type); + Shape const shape{2, 2}; + constexpr auto initialValue = 0; + constexpr auto vectorValue = 1; + constexpr auto expectedAfterIncrement = 2; + + auto const type = dtype::f32; + auto x = fl::full(shape, initialValue, type); + auto const y = x(span, span); // view of `x` + auto const z = fl::full({2}, vectorValue, type); x(span, 1) = z; // `x` is updated by copying from `z`'s data - x += 1; // `z` isn't affected - auto res = fl::Tensor::fromVector(shape, {1, 1, 2, 2}, type); + x += vectorValue; // `z` isn't affected + auto const res = fl::Tensor::fromVector(shape, {1, 1, 2, 2}, type); ASSERT_TRUE(allClose(x, res)); ASSERT_TRUE(allClose(y, res)); - ASSERT_TRUE(allClose(z, fl::full({2}, 1, type))); + ASSERT_TRUE(allClose(z, fl::full({2}, vectorValue, type))); } TEST(TensorBaseTest, Metadata) { - int s = 9; - auto t = fl::rand({s, s}); - ASSERT_EQ(t.elements(), s * s); + int size = 9; + auto const t = fl::rand({size, size}); + ASSERT_EQ(t.elements(), size * size); ASSERT_FALSE(t.isEmpty()); - ASSERT_EQ(t.bytes(), s * s * sizeof(float)); + ASSERT_EQ(t.bytes(), size * size * sizeof(float)); - Tensor e; + Tensor const e; ASSERT_EQ(e.elements(), 0); ASSERT_TRUE(e.isEmpty()); ASSERT_FALSE(e.isSparse()); @@ -134,7 +300,11 @@ TEST(TensorBaseTest, Metadata) { } TEST(TensorBaseTest, fromScalar) { - Tensor a = fromScalar(3.14, fl::dtype::f32); + constexpr auto scalarValue = 3.14; + auto const type 
= fl::dtype::f32; + + Tensor const a = fromScalar(scalarValue, type); + ASSERT_EQ(a.type(), type); ASSERT_EQ(a.elements(), 1); ASSERT_EQ(a.ndim(), 0); ASSERT_FALSE(a.isEmpty()); @@ -144,8 +314,8 @@ TEST(TensorBaseTest, fromScalar) { TEST(TensorBaseTest, string) { // Different backends might print tensors differently - check for consistency // across two identical tensors - auto a = fl::full({3, 4, 5}, 6.); - auto b = fl::full({3, 4, 5}, 6.); + auto const a = fl::full({3, 4, 5}, 6.f); + auto const b = fl::full({3, 4, 5}, 6.f); ASSERT_EQ(a.toString(), b.toString()); std::stringstream ssa, ssb; @@ -155,51 +325,52 @@ TEST(TensorBaseTest, string) { } TEST(TensorBaseTest, AssignmentOperators) { - auto a = fl::full({3, 3}, 1.); + auto a = fl::full({3, 3}, 1.f); a += 2; - ASSERT_TRUE(allClose(a, fl::full({3, 3}, 3.))); + ASSERT_TRUE(allClose(a, fl::full({3, 3}, 3.f))); a -= 1; - ASSERT_TRUE(allClose(a, fl::full({3, 3}, 2.))); + ASSERT_TRUE(allClose(a, fl::full({3, 3}, 2.f))); a *= 8; - ASSERT_TRUE(allClose(a, fl::full({3, 3}, 16.))); + ASSERT_TRUE(allClose(a, fl::full({3, 3}, 16.f))); a /= 4; - ASSERT_TRUE(allClose(a, fl::full({3, 3}, 4.))); + ASSERT_TRUE(allClose(a, fl::full({3, 3}, 4.f))); - a = fl::full({4, 4}, 7.); - ASSERT_TRUE(allClose(a, fl::full({4, 4}, 7.))); - auto b = a; - ASSERT_TRUE(allClose(b, fl::full({4, 4}, 7.))); + a = fl::full({4, 4}, 7.f); + ASSERT_TRUE(allClose(a, fl::full({4, 4}, 7.f))); + auto const b = a; + ASSERT_TRUE(allClose(b, fl::full({4, 4}, 7.f))); a = 6.; - ASSERT_TRUE(allClose(a, fl::full({4, 4}, 6.))); + ASSERT_TRUE(allClose(a, fl::full({4, 4}, 6.f))); - a = fl::full({5, 6, 7}, 8.); - ASSERT_TRUE(allClose(a, fl::full({5, 6, 7}, 8.))); + a = fl::full({5, 6, 7}, 8.f); + ASSERT_TRUE(allClose(a, fl::full({5, 6, 7}, 8.f))); } TEST(TensorBaseTest, CopyOperators) { - auto a = fl::full({3, 3}, 1.); - auto b = a; + auto a = fl::full({3, 3}, 1.f); + auto const b = a; a += 1; - ASSERT_TRUE(allClose(b, fl::full({3, 3}, 1.))); - ASSERT_TRUE(allClose(a, 
fl::full({3, 3}, 2.))); + ASSERT_TRUE(allClose(b, fl::full({3, 3}, 1.f))); + ASSERT_TRUE(allClose(a, fl::full({3, 3}, 2.f))); - auto c = a.copy(); + auto const c = a.copy(); a += 1; - ASSERT_TRUE(allClose(a, fl::full({3, 3}, 3.))); - ASSERT_TRUE(allClose(c, fl::full({3, 3}, 2.))); + ASSERT_TRUE(allClose(a, fl::full({3, 3}, 3.f))); + ASSERT_TRUE(allClose(c, fl::full({3, 3}, 2.f))); } TEST(TensorBaseTest, ConstructFromData) { // Tensor::fromVector - float val = 3.; - std::vector vec(100, val); - fl::Shape s = {10, 10}; - ASSERT_TRUE(allClose(fl::Tensor::fromVector(s, vec), fl::full(s, val))); + constexpr auto vectorSize = 100; + float fillValue = 3.f; + std::vector vec(vectorSize, fillValue); + fl::Shape bigShape = {10, 10}; + ASSERT_TRUE(allClose(fl::Tensor::fromVector(bigShape, vec), fl::full(bigShape, fillValue))); ASSERT_TRUE( allClose( - fl::Tensor::fromBuffer(s, vec.data(), fl::MemoryLocation::Host), - fl::full(s, val) + fl::Tensor::fromBuffer(bigShape, vec.data(), fl::MemoryLocation::Host), + fl::full(bigShape, fillValue) ) ); @@ -240,8 +411,8 @@ TEST(TensorBaseTest, ConstructFromData) { } TEST(TensorBaseTest, reshape) { - auto a = fl::full({4, 4}, 3.); - auto b = fl::reshape(a, Shape({8, 2})); + auto const a = fl::full({4, 4}, 3.f); + auto const b = fl::reshape(a, Shape({8, 2})); ASSERT_EQ(b.shape(), Shape({8, 2})); ASSERT_TRUE(allClose(a, fl::reshape(b, {4, 4}))); @@ -251,12 +422,12 @@ TEST(TensorBaseTest, reshape) { TEST(TensorBaseTest, transpose) { // TODO: expand to check els ASSERT_TRUE( - allClose(fl::transpose(fl::full({3, 4}, 3.)), fl::full({4, 3}, 3.)) + allClose(fl::transpose(fl::full({3, 4}, 3.f)), fl::full({4, 3}, 3.f)) ); ASSERT_TRUE( allClose( - fl::transpose(fl::full({4, 5, 6, 7}, 3.), {2, 0, 1, 3}), - fl::full({6, 4, 5, 7}, 3.) 
+ fl::transpose(fl::full({4, 5, 6, 7}, 3.f), {2, 0, 1, 3}), + fl::full({6, 4, 5, 7}, 3.f) ) ); ASSERT_THROW(fl::transpose(fl::rand({3, 4, 5}), {0, 1}), std::exception); @@ -282,89 +453,72 @@ TEST(TensorBaseTest, transpose) { } TEST(TensorBaseTest, tile) { - auto a = fl::full({4, 4}, 3.); - auto tiled = fl::tile(a, {2, 2}); + auto const a = fl::full({4, 4}, 3.f); + auto const tiled = fl::tile(a, {2, 2}); ASSERT_EQ(tiled.shape(), Shape({8, 8})); - ASSERT_TRUE(allClose(tiled, fl::full({8, 8}, 3.))); + ASSERT_TRUE(allClose(tiled, fl::full({8, 8}, 3.f))); ASSERT_EQ(fl::tile(a, {}).shape(), a.shape()); - auto s = fl::fromScalar(3.14); + auto const s = fl::fromScalar(3.14f); ASSERT_EQ(fl::tile(s, {3, 3}).shape(), Shape({3, 3})); ASSERT_EQ(fl::tile(s, {}).shape(), s.shape()); } -TEST(TensorBaseTest, concatenate) { - auto a = fl::full({3, 3}, 1.); - auto b = fl::full({3, 3}, 2.); - auto c = fl::full({3, 3}, 3.); - ASSERT_TRUE( - allClose(fl::concatenate(0, a, b, c), fl::concatenate({a, b, c})) - ); - auto out = fl::concatenate(0, a, b, c); - ASSERT_EQ(out.shape(), Shape({9, 3})); - - // Empty tenors - ASSERT_EQ(fl::concatenate(0, Tensor(), Tensor()).shape(), Shape({0})); - ASSERT_EQ(fl::concatenate(2, Tensor(), Tensor()).shape(), Shape({0, 1, 1})); - ASSERT_EQ( - fl::concatenate(1, fl::rand({5, 5}), Tensor()).shape(), - Shape({5, 5}) - ); - // More tensors - // TODO{fl::Tensor}{concat} just concat everything once we enforce - // arbitrarily-many tensors - const float val = 3.; - const int axis = 0; - auto t = fl::concatenate( - axis, - fl::full({4, 2}, val), - fl::full({4, 2}, val), - fl::full({4, 2}, val), - fl::concatenate( - axis, - fl::full({4, 2}, val), - fl::full({4, 2}, val), - fl::full({4, 2}, val) - ) - ); - ASSERT_EQ(t.shape(), Shape({24, 2})); - ASSERT_TRUE(allClose(t, fl::full({24, 2}, val))); -} TEST(TensorBaseTest, nonzero) { - std::vector idxs = {0, 1, 4, 9, 11, 23, 55, 82, 91}; - auto a = fl::full({10, 10}, 1, fl::dtype::u32); + std::vector const idxs = {0, 
1, 4, 9, 11, 23, 55, 82, 91}; + auto const a = fl::full({10, 10}, 1, fl::dtype::u32); for(const auto idx : idxs) a(idx / 10, idx % 10) = 0; - auto indices = fl::nonzero(a); + auto const indices = fl::nonzero(a); int nnz = a.elements() - idxs.size(); ASSERT_EQ(indices.shape(), Shape({nnz})); ASSERT_TRUE( - allClose(a.flatten()(indices), fl::full({nnz}, 1, fl::dtype::u32))); + allClose(a.flatten()(indices), fl::full({nnz}, 1, fl::dtype::u32)) + ); } TEST(TensorBaseTest, flatten) { - unsigned s = 6; - auto a = fl::full({s, s, s}, 2.); - auto flat = a.flatten(); - ASSERT_EQ(flat.shape(), Shape({s * s * s})); - ASSERT_TRUE(allClose(flat, fl::full({s * s * s}, 2.))); + int size = 6; + auto const a = fl::full({size, size, size}, 2.f); + auto const flat = a.flatten(); + ASSERT_EQ(flat.shape(), Shape({size * size * size})); + ASSERT_TRUE(allClose(flat, fl::full({size * size * size}, 2.f))); } TEST(TensorBaseTest, pad) { - auto t = fl::rand({5, 2}); - auto zeroPadded = fl::pad(t, {{1, 2}, {3, 4}}); - auto zeroTest = fl::concatenate( + std::vector data{ + { + 0.1f, + 0.2f, + 0.3f, + 0.4f, + 0.5f, + 1.1f, + 1.2f, + 1.3f, + 1.4f, + 1.5f + } + }; + + + auto const t = fl::Tensor::fromVector({5, 2}, data); + //auto const t = fl::rand({5, 2}); + auto const actualZeroPad = fl::pad(t, {{1, 2}, {3, 4}}); + auto const expectedZeroPad = fl::concatenate( 1, - fl::full({8, 3}, 0.), - fl::concatenate(0, fl::full({1, 2}, 0.), t, fl::full({2, 2}, 0.)), - fl::full({8, 4}, 0.) 
+ fl::full({8, 3}, 0.f), + fl::concatenate(0, fl::full({1, 2}, 0.f), t, fl::full({2, 2}, 0.f)), + fl::full({8, 4}, 0.f) ); - ASSERT_TRUE(allClose(zeroPadded, zeroTest)); - auto edgePadded = fl::pad(t, {{1, 1}, {2, 2}}, PadType::Edge); - auto vertTiled = fl::concatenate( + + ASSERT_TRUE(allClose(actualZeroPad, expectedZeroPad)); + + auto const actualEdgePad = fl::pad(t, {{1, 1}, {2, 2}}, PadType::Edge); + auto const vertTiled = fl::concatenate( 0, fl::reshape(t(0, fl::span), {1, 2}), t, @@ -372,98 +526,107 @@ TEST(TensorBaseTest, pad) { ); auto vTiled0 = vertTiled(fl::span, 0); auto vTiled1 = vertTiled(fl::span, 1); - ASSERT_TRUE( - allClose( - edgePadded, - fl::concatenate( - 1, - fl::tile(vTiled0, {1, 3}), - fl::tile(vTiled1, {1, 3}) - ) - ) + + auto const expectedEdgePad = fl::concatenate( + 1, + fl::tile(vTiled0, {1, 3}), + fl::tile(vTiled1, {1, 3}) ); - auto symmetricPadded = fl::pad(t, {{1, 1}, {2, 2}}, PadType::Symmetric); - ASSERT_TRUE( - allClose( - symmetricPadded, - // TODO{fl::Tensor}{concat} just concat everything once we enforce - // arbitrarily-many tensors - fl::concatenate( - 1, - vTiled1, - vTiled0, - vTiled0, - fl::concatenate(1, vTiled1, vTiled1, vTiled0) - ) - ) + ASSERT_TRUE(allClose(actualEdgePad, expectedEdgePad)); + + + auto const actualSymmetricPad = fl::pad(t, {{1, 1}, {2, 2}}, PadType::Symmetric); + auto const expectedSymmetricPad = fl::concatenate( + {vTiled1, vTiled0, vTiled0, fl::concatenate(1, vTiled1, vTiled1, vTiled0)}, + 1 ); + + ASSERT_TRUE(allClose(actualSymmetricPad, expectedSymmetricPad)); } -TEST(TensorBaseTest, astype) { - auto a = fl::rand({3, 3}); +TEST(TensorBaseTest, asType) { + auto const a = fl::rand({3, 3}); + auto const size = 9; + ASSERT_EQ(a.type(), dtype::f32); - ASSERT_EQ(a.astype(dtype::f64).type(), dtype::f64); + + auto const aDouble = a.asType(dtype::f64); + + ASSERT_EQ(aDouble.type(), dtype::f64); + + auto const aData = a.host(); + auto const bData = aDouble.host(); + + for(size_t i = 0; i < size; i++) + 
ASSERT_NEAR(aData[i], bData[i], 1e-6); } TEST(TensorBaseTest, where) { - auto a = Tensor::fromVector({2, 5}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto out = fl::where(a < 5, a, a * 10); - a(a >= 5) *= 10; + constexpr auto threshold = 5; + + auto const a = Tensor::fromVector({2, 5}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto const out = fl::where(a < threshold, a, a * 10); + a(a >= threshold) *= 10; ASSERT_TRUE(allClose(out, a)); - auto outC = fl::where(a < 5, a, 3); - a(a >= 5) = 3; + auto const outC = fl::where(a < threshold, a, 3); + a(a >= threshold) = 3; ASSERT_TRUE(allClose(outC, a)); - auto outC2 = fl::where(a < 5, 3, a); - a(a < 5) = 3; + auto const outC2 = fl::where(a < threshold, 3, a); + a(a < threshold) = 3; ASSERT_TRUE(allClose(outC2, a)); // non b8-type vector throws EXPECT_THROW( - fl::where((a < 5).astype(fl::dtype::f32), a, a * 10), + fl::where((a < threshold).asType(fl::dtype::f32), a, a * 10), std::exception ); } TEST(TensorBaseTest, topk) { - auto a = fl::arange({10, 2}); + constexpr auto k = 3; + constexpr auto k4 = 4; + + auto const a = fl::arange({10, 2}); Tensor values; Tensor indices; - fl::topk(values, indices, a, /* k = */ 3, /* axis = */ 0); // descending sort + fl::topk(values, indices, a, /* k = */ k, /* axis = */ 0); // descending sort ASSERT_TRUE( - allClose(values, Tensor::fromVector({3, 2}, {9, 8, 7, 9, 8, 7})) + allClose(values, Tensor::fromVector({k, 2}, {9, 8, 7, 9, 8, 7})) ); fl::topk( values, indices, a, - /* k = */ 4, - /* axis = */ 0, + /* k = */ + k4, + /* axis = */ + 0, fl::SortMode::Ascending ); ASSERT_TRUE( allClose( values, - Tensor::fromVector({4, 2}, {0, 1, 2, 3, 0, 1, 2, 3}) + Tensor::fromVector({k4, 2}, {0, 1, 2, 3, 0, 1, 2, 3}) ) ); } TEST(TensorBaseTest, sort) { Shape dims({10, 2}); - auto a = fl::arange(dims); - auto sorted = fl::sort(a, /* axis = */ 0, SortMode::Descending); + auto const a = fl::arange(dims); + auto const sorted = fl::sort(a, /* axis = */ 0, SortMode::Descending); - Tensor expected({dims[0]}, 
a.type()); + Tensor const expected({dims[0]}, a.type()); for(int i = 0; i < dims[0]; ++i) expected(i) = dims[0] - i - 1; - auto tiled = fl::tile(expected, {1, 2}); + auto const tiled = fl::tile(expected, {1, 2}); ASSERT_TRUE(allClose(sorted, tiled)); ASSERT_TRUE(allClose(a, fl::sort(tiled, 0, SortMode::Ascending))); - auto b = fl::rand({10}); + auto const b = fl::rand({10}); Tensor values, indices; fl::sort(values, indices, b, /* axis = */ 0, SortMode::Descending); ASSERT_TRUE( @@ -476,13 +639,13 @@ TEST(TensorBaseTest, sort) { TEST(TensorBaseTest, argsort) { Shape dims({10, 2}); - auto a = fl::arange(dims); - auto sorted = fl::argsort(a, /* axis = */ 0, SortMode::Descending); + auto const a = fl::arange(dims); + auto const sorted = fl::argsort(a, /* axis = */ 0, SortMode::Descending); - Tensor expected({dims[0]}, fl::dtype::u32); + Tensor const expected({dims[0]}, fl::dtype::u32); for(int i = 0; i < dims[0]; ++i) expected(i) = dims[0] - i - 1; - auto tiled = fl::tile(expected, {1, 2}); + auto const tiled = fl::tile(expected, {1, 2}); ASSERT_TRUE(allClose(sorted, tiled)); ASSERT_TRUE(allClose(tiled, fl::argsort(tiled, 0, SortMode::Ascending))); @@ -512,7 +675,7 @@ void assertScalarBehavior(fl::dtype type) { << "dtype: " << type << ", ScalarArgType: " << dtype_traits::getName(); - ScalarArgType val = static_cast(rand()); + auto val = static_cast(rand()); auto a = fl::full({5, 6}, val, type); ASSERT_TRUE(allClose(fl::full({1}, a.template scalar(), type), a(0, 0))) @@ -521,7 +684,7 @@ void assertScalarBehavior(fl::dtype type) { } TEST(TensorBaseTest, scalar) { - auto types = { + auto const types = { fl::dtype::b8, fl::dtype::u8, fl::dtype::s16, @@ -534,7 +697,7 @@ TEST(TensorBaseTest, scalar) { fl::dtype::f32, fl::dtype::f64 }; - for(auto type : types) { + for(auto const type : types) { assertScalarBehavior(type); assertScalarBehavior(type); assertScalarBehavior(type); @@ -552,33 +715,33 @@ TEST(TensorBaseTest, scalar) { TEST(TensorBaseTest, isContiguous) { // 
Contiguous by default - auto a = fl::rand({10, 10}); + auto const a = fl::rand({10, 10}); ASSERT_TRUE(a.isContiguous()); } TEST(TensorBaseTest, strides) { - auto t = fl::rand({10, 10}); + auto const t = fl::rand({10, 10}); ASSERT_EQ(t.strides(), Shape({1, 10})); } TEST(TensorBaseTest, stream) { - auto t1 = fl::rand({10, 10}); - auto t2 = -t1; - auto t3 = t1 + t2; + auto const t1 = fl::rand({10, 10}); + auto const t2 = -t1; + auto const t3 = t1 + t2; ASSERT_EQ(&t1.stream(), &t2.stream()); ASSERT_EQ(&t1.stream(), &t3.stream()); } TEST(TensorBaseTest, asContiguousTensor) { - auto t = fl::rand({5, 6, 7, 8}); - auto indexed = t( + auto const t = fl::rand({5, 6, 7, 8}); + auto const indexed = t( fl::range(1, 4, 2), fl::range(0, 6, 2), - fl::range(0, 6, 3), + fl::range(0, 7, 3), fl::range(0, 5, 3) ); - auto contiguous = indexed.asContiguousTensor(); + auto const contiguous = indexed.asContiguousTensor(); std::vector strides; unsigned stride = 1; for(unsigned i = 0; i < contiguous.ndim(); ++i) { @@ -589,9 +752,9 @@ TEST(TensorBaseTest, asContiguousTensor) { } TEST(TensorBaseTest, host) { - auto a = fl::rand({10, 10}); + auto const a = fl::rand({10, 10}); - float* ptr = a.host(); + float const* ptr = a.host(); for(int i = 0; i < a.elements(); ++i) ASSERT_EQ(ptr[i], a.flatten()(i).scalar()); @@ -604,8 +767,8 @@ TEST(TensorBaseTest, host) { } TEST(TensorBaseTest, toHostVector) { - auto a = fl::rand({10, 10}); - auto vec = a.toHostVector(); + auto const a = fl::rand({10, 10}); + auto const vec = a.toHostVector(); for(int i = 0; i < a.elements(); ++i) ASSERT_EQ(vec[i], a.flatten()(i).scalar()); @@ -623,19 +786,19 @@ TEST(TensorBaseTest, arange) { ); ASSERT_TRUE( allClose( - fl::arange(0., 1.22, 0.25), - Tensor::fromVector({0., 0.25, 0.5, 0.75}) + fl::arange(0.f, 1.22f, 0.25f), + Tensor::fromVector({0.f, 0.25f, 0.5f, 0.75f}) ) ); ASSERT_TRUE( allClose( - fl::arange(0., 4.1), - Tensor::fromVector({0., 1., 2., 3.}) + fl::arange(0.f, 4.1f), + Tensor::fromVector({0.f, 1.f, 2.f, 
3.f}) ) ); // Shape overload - auto v = Tensor::fromVector({0., 1., 2., 3.}); + auto const v = Tensor::fromVector({0.f, 1.f, 2.f, 3.f}); ASSERT_TRUE(allClose(fl::arange({4}), v)); ASSERT_TRUE(allClose(fl::arange({4, 5}), fl::tile(v, {1, 5}))); @@ -644,7 +807,7 @@ TEST(TensorBaseTest, arange) { allClose( fl::arange({4, 5}, 1), fl::tile( - fl::reshape(Tensor::fromVector({0., 1., 2., 3., 4.}), {1, 5}), + fl::reshape(Tensor::fromVector({0.f, 1.f, 2.f, 3.f, 4.f}), {1, 5}), {4} ) ) diff --git a/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp b/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp index 14cfe15..1aa85af 100644 --- a/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp +++ b/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp @@ -28,7 +28,7 @@ void assertTensorScalarBinop( const Tensor& expectOut ) { auto result = op(in, scalar); - auto expect = expectOut.astype(result.type()); + auto expect = expectOut.asType(result.type()); ASSERT_TRUE(allClose(result, expect)) << "in.type(): " << in.type() << ", ScalarType: " << dtype_traits::getName(); @@ -42,7 +42,7 @@ void assertScalarTensorBinop( const Tensor& expectOut ) { auto result = op(scalar, in); - auto expect = expectOut.astype(result.type()); + auto expect = expectOut.asType(result.type()); ASSERT_TRUE(allClose(result, expect)) << "ScalarType: " << dtype_traits::getName() << ", in.type(): " << in.type(); @@ -98,12 +98,12 @@ void applyToAllDtypes(std::function func) { TEST(TensorBinaryOpsTest, ArithmeticBinaryOperators) { auto testArithmeticBinops = [](dtype type) { - auto a = Tensor::fromVector({2, 2}, {0, 1, 2, 3}).astype(type); - auto b = Tensor::fromVector({2, 2}, {1, 2, 3, 4}).astype(type); - auto c = Tensor::fromVector({2, 2}, {1, 3, 5, 7}).astype(type); - auto d = Tensor::fromVector({2, 2}, {1, 6, 15, 28}).astype(type); - auto e = Tensor::fromVector({2, 2}, {3, 2, 1, 0}).astype(type); - auto f = Tensor::fromVector({2, 2}, {2, 4, 6, 8}).astype(type); + auto a = Tensor::fromVector({2, 2}, {0, 1, 2, 
3}).asType(type); + auto b = Tensor::fromVector({2, 2}, {1, 2, 3, 4}).asType(type); + auto c = Tensor::fromVector({2, 2}, {1, 3, 5, 7}).asType(type); + auto d = Tensor::fromVector({2, 2}, {1, 6, 15, 28}).asType(type); + auto e = Tensor::fromVector({2, 2}, {3, 2, 1, 0}).asType(type); + auto f = Tensor::fromVector({2, 2}, {2, 4, 6, 8}).asType(type); auto z = fl::full({2, 2}, 0, type); assertCommutativeBinop(a, z, std::plus<>(), a); @@ -140,16 +140,16 @@ TEST(TensorBinaryOpsTest, ComparisonBinaryOperators) { auto falses = fl::full({2, 2}, 0, dtype::b8); auto trues = fl::full({2, 2}, 1, dtype::b8); auto falseTrues = - Tensor::fromVector({2, 2}, {0, 1, 0, 1}).astype(fl::dtype::b8); + Tensor::fromVector({2, 2}, {0, 1, 0, 1}).asType(fl::dtype::b8); auto trueFalses = - Tensor::fromVector({2, 2}, {1, 0, 1, 0}).astype(fl::dtype::b8); + Tensor::fromVector({2, 2}, {1, 0, 1, 0}).asType(fl::dtype::b8); auto testComparisonBinops = [&](dtype type) { - auto a = Tensor::fromVector({2, 2}, {0, 1, 2, 3}).astype(type); - auto b = Tensor::fromVector({2, 2}, {0, 0, 2, 0}).astype(type); - auto c = Tensor::fromVector({2, 2}, {2, 3, 4, 5}).astype(type); - auto d = Tensor::fromVector({2, 2}, {0, 4, 2, 6}).astype(type); - auto e = Tensor::fromVector({2, 2}, {0, 1, 0, 1}).astype(type); + auto a = Tensor::fromVector({2, 2}, {0, 1, 2, 3}).asType(type); + auto b = Tensor::fromVector({2, 2}, {0, 0, 2, 0}).asType(type); + auto c = Tensor::fromVector({2, 2}, {2, 3, 4, 5}).asType(type); + auto d = Tensor::fromVector({2, 2}, {0, 4, 2, 6}).asType(type); + auto e = Tensor::fromVector({2, 2}, {0, 1, 0, 1}).asType(type); ASSERT_TRUE(allClose((a == a), trues)) << "dtype: " << type; assertCommutativeBinop(a, b, std::equal_to<>(), trueFalses); @@ -207,11 +207,11 @@ TEST(TensorBinaryOpsTest, LogicalBinaryOperators) { auto falses = fl::full({2, 2}, 0, dtype::b8); auto trues = fl::full({2, 2}, 1, dtype::b8); auto falseTrues = - Tensor::fromVector({2, 2}, {0, 1, 0, 1}).astype(fl::dtype::b8); + 
Tensor::fromVector({2, 2}, {0, 1, 0, 1}).asType(fl::dtype::b8); auto testLogicalBinops = [&](dtype type) { - auto a = Tensor::fromVector({2, 2}, {0, 1, 0, 3}).astype(type); - auto b = Tensor::fromVector({2, 2}, {2, 3, 4, 5}).astype(type); + auto a = Tensor::fromVector({2, 2}, {0, 1, 0, 3}).asType(type); + auto b = Tensor::fromVector({2, 2}, {2, 3, 4, 5}).asType(type); auto z = fl::full({2, 2}, 0, type); ASSERT_TRUE(allClose((z || z), falses)) << "dtype: " << type; @@ -234,9 +234,9 @@ TEST(TensorBinaryOpsTest, LogicalBinaryOperators) { TEST(TensorBinaryOpsTest, ModuloBinaryOperators) { auto testModuloBinop = [](dtype type) { - auto a = Tensor::fromVector({2, 2}, {1, 2, 3, 4}).astype(type); - auto b = Tensor::fromVector({2, 2}, {2, 3, 5, 7}).astype(type); - auto c = Tensor::fromVector({2, 2}, {0, 1, 2, 3}).astype(type); + auto a = Tensor::fromVector({2, 2}, {1, 2, 3, 4}).asType(type); + auto b = Tensor::fromVector({2, 2}, {2, 3, 5, 7}).asType(type); + auto c = Tensor::fromVector({2, 2}, {0, 1, 2, 3}).asType(type); auto z = fl::full({2, 2}, 0, type); ASSERT_TRUE(allClose((z % b), z)) << "dtype: " << type; @@ -260,14 +260,14 @@ TEST(TensorBinaryOpsTest, ModuloBinaryOperators) { TEST(TensorBinaryOpsTest, BitBinaryOperators) { auto testBitBinops = [](dtype type) { - auto a = Tensor::fromVector({2, 1}, {0b0001, 0b1000}).astype(type); - auto b = Tensor::fromVector({2, 1}, {0b0010, 0b0100}).astype(type); - auto c = Tensor::fromVector({2, 1}, {0b0011, 0b1100}).astype(type); - auto d = Tensor::fromVector({2, 1}, {0b0110, 0b0110}).astype(type); - auto e = Tensor::fromVector({2, 1}, {0b1000, 0b0001}).astype(type); - auto g = Tensor::fromVector({2, 1}, {2, 1}).astype(type); - auto h = Tensor::fromVector({2, 1}, {0b1000, 0b1000}).astype(type); - auto z = Tensor::fromVector({2, 1}, {0b0000, 0b0000}).astype(type); + auto a = Tensor::fromVector({2, 1}, {0b0001, 0b1000}).asType(type); + auto b = Tensor::fromVector({2, 1}, {0b0010, 0b0100}).asType(type); + auto c = 
Tensor::fromVector({2, 1}, {0b0011, 0b1100}).asType(type); + auto d = Tensor::fromVector({2, 1}, {0b0110, 0b0110}).asType(type); + auto e = Tensor::fromVector({2, 1}, {0b1000, 0b0001}).asType(type); + auto g = Tensor::fromVector({2, 1}, {2, 1}).asType(type); + auto h = Tensor::fromVector({2, 1}, {0b1000, 0b1000}).asType(type); + auto z = Tensor::fromVector({2, 1}, {0b0000, 0b0000}).asType(type); ASSERT_TRUE(allClose((z & z), z)) << "dtype: " << type; assertCommutativeBinop(a, b, std::bit_and<>(), z); @@ -366,8 +366,8 @@ TEST(TensorBinaryOpsTest, minimum) { auto c = fl::minimum(a, b); ASSERT_EQ(a.type(), c.type()); ASSERT_TRUE(allClose(a, c)); - ASSERT_TRUE(allClose(fl::minimum(1, b).astype(a.type()), a)); - ASSERT_TRUE(allClose(fl::minimum(b, 1).astype(a.type()), a)); + ASSERT_TRUE(allClose(fl::minimum(1, b).asType(a.type()), a)); + ASSERT_TRUE(allClose(fl::minimum(b, 1).asType(a.type()), a)); } TEST(TensorBinaryOpsTest, maximum) { @@ -376,8 +376,8 @@ TEST(TensorBinaryOpsTest, maximum) { auto c = fl::maximum(a, b); ASSERT_EQ(b.type(), c.type()); ASSERT_TRUE(allClose(b, c)); - ASSERT_TRUE(allClose(fl::maximum(1, b).astype(a.type()), b)); - ASSERT_TRUE(allClose(fl::maximum(b, 1).astype(a.type()), b)); + ASSERT_TRUE(allClose(fl::maximum(1, b).asType(a.type()), b)); + ASSERT_TRUE(allClose(fl::maximum(b, 1).asType(a.type()), b)); } using binaryOpFunc_t = Tensor (*)(const Tensor& lhs, const Tensor& rhs); @@ -473,8 +473,8 @@ TEST(TensorBinaryOpsTest, broadcasting) { for(const auto& funcp : functions) { for(auto& shapeData : shapes) { - auto lhs = ((fl::rand(shapeData.lhs) + 1) * 10).astype(fl::dtype::s32); - auto rhs = ((fl::rand(shapeData.rhs) + 1) * 10).astype(fl::dtype::s32); + auto lhs = ((fl::rand(shapeData.lhs) + 1) * 10).asType(fl::dtype::s32); + auto rhs = ((fl::rand(shapeData.rhs) + 1) * 10).asType(fl::dtype::s32); auto [actualOut, expectedOut] = doBinaryOp( lhs, @@ -498,7 +498,7 @@ TEST(TensorBinaryOpsTest, broadcasting) { // Scalar broadcasting const double 
scalarVal = 4; const Shape inShape = {2, 3, 4}; - const auto lhs = fl::rand(inShape).astype(fl::dtype::s32); + const auto lhs = fl::rand(inShape).asType(fl::dtype::s32); const auto rhs = fl::fromScalar(scalarVal, fl::dtype::s32); const auto rhsTiled = fl::full(inShape, scalarVal, fl::dtype::s32); ASSERT_TRUE(allClose(funcp.first(lhs, rhs), funcp.first(lhs, rhsTiled))); @@ -517,7 +517,7 @@ TEST(TensorBinaryOpsTest, powerDouble) { auto b = fl::full({3, 3}, 2.); ASSERT_TRUE( - allClose(fl::power(3, a), fl::full(b.shape(), 3 * 3, fl::dtype::f32)) + allClose(fl::power(3, a), fl::full(b.shape(), 3 * 3, fl::dtype::f64)) ); } diff --git a/flashlight/fl/test/tensor/TensorReductionTest.cpp b/flashlight/fl/test/tensor/TensorReductionTest.cpp index 320356b..5ef4dd3 100644 --- a/flashlight/fl/test/tensor/TensorReductionTest.cpp +++ b/flashlight/fl/test/tensor/TensorReductionTest.cpp @@ -19,7 +19,7 @@ using namespace fl; TEST(TensorReductionTest, countNonzero) { std::vector idxs = {0, 3, 4, 7, 24, 78}; auto a = fl::full({10, 10}, 1, fl::dtype::u32); - for(const auto idx : idxs) + for(auto const idx : idxs) a(idx / 10, idx % 10) = 0; ASSERT_TRUE( @@ -31,8 +31,8 @@ TEST(TensorReductionTest, countNonzero) { std::vector sizes(a.shape().dim(0)); for(unsigned i = 0; i < a.shape().dim(0); ++i) - sizes[i] = - a.shape().dim(0) - fl::sum(a(fl::span, i) == 0, {0}).scalar(); + sizes[i] = a.shape().dim(0) - fl::sum(a(fl::span, i) == 0, {0}).scalar(); + ASSERT_TRUE(allClose(Tensor::fromVector(sizes), Tensor::fromVector(sizes))); auto b = fl::full({2, 2, 2}, 1, fl::dtype::u32); @@ -209,8 +209,8 @@ TEST(TensorReductionTest, cumsum) { } TEST(TensorReductionTest, sum) { - auto t = fl::full({3, 4, 5, 6}, 1.0); - ASSERT_TRUE(allClose(fl::sum(t, {0}), fl::full({4, 5, 6}, 3.0))); + auto t = fl::full({3, 4, 5, 6}, 1.f); + ASSERT_TRUE(allClose(fl::sum(t, {0}), fl::full({4, 5, 6}, 3.f))); ASSERT_TRUE( allClose(fl::sum(t, {1, 2}), fl::full({3, 6}, 4 * 5, fl::dtype::f32)) ); @@ -247,12 +247,12 @@ 
TEST(TensorReductionTest, mean) { ); auto s = fl::full({5, 6, 7}, 1); - ASSERT_TRUE(allClose(fl::mean(s, {0}), fl::full({6, 7}, 1.))); + ASSERT_TRUE(allClose(fl::mean(s, {0}), fl::full({6, 7}, 1.f))); auto a = fl::mean(fl::full({5, 5, 5, 5}, 1)); ASSERT_EQ(a.shape(), Shape({})); ASSERT_EQ(a.elements(), 1); - ASSERT_EQ(a.scalar(), 1.); + ASSERT_EQ(a.scalar(), 1.f); // TODO: fixture this const float v = 3.14; @@ -264,8 +264,8 @@ TEST(TensorReductionTest, mean) { TEST(TensorReductionTest, median) { auto a = Tensor::fromVector({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - ASSERT_EQ(fl::median(a).scalar(), 4.5); - ASSERT_TRUE(allClose(fl::median(a, {0}), fl::fromScalar(4.5))); + ASSERT_EQ(fl::median(a).scalar(), 4.5f); + ASSERT_TRUE(allClose(fl::median(a, {0}), fl::fromScalar(4.5f))); ASSERT_EQ(fl::median(fl::rand({5, 6, 7, 8}), {1, 2}).shape(), Shape({5, 8})); ASSERT_EQ( fl::median(fl::rand({5, 6, 7, 8}), {1, 2}, /* keepDims = */ true).shape(), @@ -275,7 +275,7 @@ TEST(TensorReductionTest, median) { auto b = fl::median(fl::full({5, 5, 5, 5}, 1)); ASSERT_EQ(b.shape(), Shape({})); ASSERT_EQ(b.elements(), 1); - ASSERT_EQ(b.scalar(), 1.); + ASSERT_EQ(b.scalar(), 1.f); const float v = 3.14; auto q = fl::median(fl::fromScalar(v)); @@ -297,7 +297,7 @@ TEST(TensorReductionTest, var) { ); auto s = fl::full({5, 6, 7}, 1); - ASSERT_TRUE(allClose(fl::var(s, {0}), fl::full({6, 7}, 0.))); + ASSERT_TRUE(allClose(fl::var(s, {0}), fl::full({6, 7}, 0.f))); auto a = fl::rand({5, 5}); ASSERT_TRUE(allClose(fl::var(a), fl::var(a, {0, 1}))); @@ -310,14 +310,14 @@ TEST(TensorReductionTest, var) { TEST(TensorReductionTest, std) { auto r = fl::rand({7, 8, 9}); - ASSERT_NEAR(fl::std(r).scalar(), 0.2886, 0.005); + ASSERT_NEAR(fl::std(r).scalar(), 0.2886f, 0.005f); ASSERT_EQ( fl::std(r, {0, 1}, /* keepDims = */ true).shape(), Shape({1, 1, 9}) ); auto s = fl::full({5, 6, 7}, 1); - ASSERT_TRUE(allClose(fl::std(s, {0}), fl::full({6, 7}, 0.))); + ASSERT_TRUE(allClose(fl::std(s, {0}), fl::full({6, 7}, 0.f))); 
ASSERT_TRUE(allClose(fl::std(s, {1}), fl::sqrt(fl::var(s, {1})))); const float v = 3.14; @@ -334,7 +334,7 @@ TEST(TensorReductionTest, norm) { ASSERT_EQ(normAll.shape(), Shape({})); ASSERT_EQ(normAll.elements(), 1); ASSERT_FLOAT_EQ( - fl::norm(fl::full({5, 5}, 1.)).scalar(), + fl::norm(fl::full({5, 5}, 1.f)).scalar(), std::sqrt(5 * 5) ); ASSERT_EQ( @@ -347,78 +347,78 @@ TEST(TensorReductionTest, norm) { const float v = 3.14; auto q = fl::norm(fl::fromScalar(v)); ASSERT_EQ(q.shape(), Shape()); - ASSERT_NEAR(q.scalar(), 3.14, 1e-4); + ASSERT_NEAR(q.scalar(), 3.14f, 1e-4); ASSERT_EQ(fl::norm(fl::fromScalar(v), {0}).shape(), Shape()); } TEST(TensorReductionTest, any) { using fl::dtype; auto t = Tensor::fromVector({3, 3}, {1, 0, 0, 0, 0, 0, 0, 0, 1}); - auto anyAll = fl::any(t); + auto anyAll = fl::any_of(t); ASSERT_EQ(anyAll.shape(), Shape({})); ASSERT_EQ(anyAll.elements(), 1); ASSERT_TRUE(anyAll.scalar()); ASSERT_TRUE( allClose( - fl::any(t, {0}), - Tensor::fromVector({1, 0, 1}).astype(dtype::b8) + fl::any_of(t, {0}), + Tensor::fromVector({1, 0, 1}).asType(dtype::b8) ) ); - ASSERT_TRUE(allClose(fl::any(t, {0, 1}), fl::fromScalar(true, dtype::b8))); - ASSERT_FALSE(fl::any(Tensor::fromVector({0, 0, 0})).scalar()); + ASSERT_TRUE(allClose(fl::any_of(t, {0, 1}), fl::fromScalar(true, dtype::b8))); + ASSERT_FALSE(fl::any_of(Tensor::fromVector({0, 0, 0})).scalar()); - auto keptDims = fl::any( - fl::any(t, {1}, /* keepDims = */ true), + auto keptDims = fl::any_of( + fl::any_of(t, {1}, /* keepDims = */ true), {0}, /* keepDims = */ true ); ASSERT_EQ(keptDims.shape(), Shape({1, 1})); - ASSERT_EQ(keptDims.scalar(), fl::any(t, {0, 1}).scalar()); - auto q = fl::any(fl::full({5, 5, 5, 5}, 1)); + ASSERT_EQ(keptDims.scalar(), fl::any_of(t, {0, 1}).scalar()); + auto q = fl::any_of(fl::full({5, 5, 5, 5}, 1)); ASSERT_EQ(q.shape(), Shape({})); ASSERT_EQ(q.elements(), 1); ASSERT_EQ(q.scalar(), true); const float v = 3.14; - auto r = fl::any(fl::fromScalar(v)); + auto r = 
fl::any_of(fl::fromScalar(v)); ASSERT_EQ(r.shape(), Shape()); ASSERT_TRUE(r.scalar()); - ASSERT_EQ(fl::any(fl::fromScalar(v), {0}).shape(), Shape()); + ASSERT_EQ(fl::any_of(fl::fromScalar(v), {0}).shape(), Shape()); } TEST(TensorReductionTest, all) { using fl::dtype; auto t = Tensor::fromVector({3, 3}, {1, 0, 0, 0, 0, 0, 0, 0, 1}); - auto allAll = fl::all(t); + auto allAll = fl::all_of(t); ASSERT_EQ(allAll.shape(), Shape({})); ASSERT_EQ(allAll.elements(), 1); ASSERT_FALSE(allAll.scalar()); ASSERT_TRUE( allClose( - fl::all(t, {0}), - Tensor::fromVector({0, 0, 0}).astype(dtype::b8) + fl::all_of(t, {0}), + Tensor::fromVector({0, 0, 0}).asType(dtype::b8) ) ); - ASSERT_TRUE(allClose(fl::all(t, {0, 1}), fl::fromScalar(false, dtype::b8))); - ASSERT_TRUE(fl::all(Tensor::fromVector({1, 1, 1})).scalar()); + ASSERT_TRUE(allClose(fl::all_of(t, {0, 1}), fl::fromScalar(false, dtype::b8))); + ASSERT_TRUE(fl::all_of(Tensor::fromVector({1, 1, 1})).scalar()); - auto keptDims = fl::all( - fl::all(t, {1}, /* keepDims = */ true), + auto keptDims = fl::all_of( + fl::all_of(t, {1}, /* keepDims = */ true), {0}, /* keepDims = */ true ); ASSERT_EQ(keptDims.shape(), Shape({1, 1})); - ASSERT_EQ(keptDims.scalar(), fl::all(t, {0, 1}).scalar()); - auto q = fl::all(fl::full({5, 5, 5, 5}, 1)); + ASSERT_EQ(keptDims.scalar(), fl::all_of(t, {0, 1}).scalar()); + auto q = fl::all_of(fl::full({5, 5, 5, 5}, 1)); ASSERT_EQ(q.shape(), Shape({})); ASSERT_EQ(q.elements(), 1); ASSERT_EQ(q.scalar(), true); const float v = 3.14; - auto a = fl::all(fl::fromScalar(v)); + auto a = fl::all_of(fl::fromScalar(v)); ASSERT_EQ(a.shape(), Shape()); ASSERT_TRUE(a.scalar()); - ASSERT_EQ(fl::all(fl::fromScalar(v), {0}).shape(), Shape()); + ASSERT_EQ(fl::all_of(fl::fromScalar(v), {0}).shape(), Shape()); } int main(int argc, char** argv) { diff --git a/flashlight/fl/test/tensor/TensorUnaryOpsTest.cpp b/flashlight/fl/test/tensor/TensorUnaryOpsTest.cpp index c38a171..ab5a2c9 100644 --- 
a/flashlight/fl/test/tensor/TensorUnaryOpsTest.cpp +++ b/flashlight/fl/test/tensor/TensorUnaryOpsTest.cpp @@ -27,25 +27,25 @@ TEST(TensorUnaryOpsTest, logicalNot) { ASSERT_TRUE( allClose( !fl::full({3, 3}, true), - fl::full({3, 3}, false).astype(dtype::b8) + fl::full({3, 3}, false).asType(dtype::b8) ) ); } TEST(TensorUnaryOpsTest, clip) { - float h = 3.; - float l = 2.; + float h = 3.f; + float l = 2.f; Shape s = {3, 3}; auto high = fl::full(s, h); auto low = fl::full(s, l); - ASSERT_TRUE(allClose(fl::clip(fl::full({3, 3}, 4.), low, high), high)); - ASSERT_TRUE(allClose(fl::clip(fl::full({3, 3}, 4.), l, high), high)); - ASSERT_TRUE(allClose(fl::clip(fl::full({3, 3}, 4.), low, h), high)); - ASSERT_TRUE(allClose(fl::clip(fl::full({3, 3}, 4.), l, h), high)); + ASSERT_TRUE(allClose(fl::clip(fl::full({3, 3}, 4.f), low, high), high)); + ASSERT_TRUE(allClose(fl::clip(fl::full({3, 3}, 4.f), l, high), high)); + ASSERT_TRUE(allClose(fl::clip(fl::full({3, 3}, 4.f), low, h), high)); + ASSERT_TRUE(allClose(fl::clip(fl::full({3, 3}, 4.f), l, h), high)); } TEST(TensorUnaryOpsTest, roll) { - auto t = fl::full({5, 5}, 4.); + auto t = fl::full({5, 5}, 4.f); ASSERT_TRUE(allClose(t, fl::roll(t, /* shift = */ 3, /* axis = */ 1))); Shape dims({4, 5}); @@ -65,8 +65,8 @@ TEST(TensorUnaryOpsTest, isnan) { Shape s = {3, 3}; ASSERT_TRUE( allClose( - fl::isnan(fl::full(s, 1.) / 3), - fl::full(s, false).astype(fl::dtype::b8) + fl::isnan(fl::full(s, 1.f) / 3), + fl::full(s, false).asType(fl::dtype::b8) ) ); } @@ -75,14 +75,14 @@ TEST(TensorUnaryOpsTest, isinf) { Shape s = {3, 3}; ASSERT_TRUE( allClose( - fl::isinf(fl::full(s, 1.) / 3), - fl::full(s, false).astype(fl::dtype::b8) + fl::isinf(fl::full(s, 1.f) / 3), + fl::full(s, false).asType(fl::dtype::b8) ) ); ASSERT_TRUE( allClose( - fl::isinf(fl::full(s, 1.) 
/ 0.), - fl::full(s, true).astype(fl::dtype::b8) + fl::isinf(fl::full(s, 1.f) / 0.f), + fl::full(s, true).asType(fl::dtype::b8) ) ); } @@ -102,7 +102,7 @@ TEST(TensorUnaryOpsTest, tril) { [](const Dim dim, const Tensor& res, const Tensor& in) { for(int i = 0; i < dim; ++i) for(int j = i + 1; j < dim; ++j) - ASSERT_EQ(res(i, j).scalar(), 0.); + ASSERT_EQ(res(i, j).scalar(), 0.f); for(int i = 0; i < dim; ++i) for(int j = 0; j < i; ++j) ASSERT_TRUE(allClose(res(i, j), in(i, j))); @@ -133,7 +133,7 @@ TEST(TensorUnaryOpsTest, triu) { ASSERT_TRUE(allClose(res(i, j), in(i, j))); for(unsigned i = 0; i < dim; ++i) for(unsigned j = 0; j < i; ++j) - ASSERT_EQ(res(i, j).scalar(), 0.); + ASSERT_EQ(res(i, j).scalar(), 0.f); }; int dim = 10; @@ -156,20 +156,20 @@ TEST(TensorUnaryOpsTest, triu) { TEST(TensorUnaryOpsTest, floor) { auto a = fl::rand({10, 10}) + 0.5; - ASSERT_TRUE(allClose((a >= 1.).astype(fl::dtype::f32), fl::floor(a))); + ASSERT_TRUE(allClose((a >= 1.).asType(fl::dtype::f32), fl::floor(a))); } TEST(TensorUnaryOpsTest, ceil) { auto a = fl::rand({10, 10}) + 0.5; - ASSERT_TRUE(allClose((a >= 1).astype(fl::dtype::f32), fl::ceil(a) - 1)); + ASSERT_TRUE(allClose((a >= 1).asType(fl::dtype::f32), fl::ceil(a) - 1)); } TEST(TensorUnaryOpsTest, rint) { Shape s = {10, 10}; auto a = fl::rand(s) - 0.5; - ASSERT_TRUE(allClose(fl::rint(a), fl::full(s, 0.))); + ASSERT_TRUE(allClose(fl::rint(a), fl::full(s, 0.f))); auto b = fl::rand(s) + 0.5; - ASSERT_TRUE(allClose(fl::rint(b), fl::full(s, 1.))); + ASSERT_TRUE(allClose(fl::rint(b), fl::full(s, 1.f))); } TEST(TensorUnaryOpsTest, sigmoid) { diff --git a/flashlight/fl/test/tensor/af/ArrayFireTensorBaseTest.cpp b/flashlight/fl/test/tensor/af/ArrayFireTensorBaseTest.cpp index 503d8d5..e29ed74 100644 --- a/flashlight/fl/test/tensor/af/ArrayFireTensorBaseTest.cpp +++ b/flashlight/fl/test/tensor/af/ArrayFireTensorBaseTest.cpp @@ -20,6 +20,8 @@ #include "flashlight/fl/tensor/backend/af/ArrayFireTensor.h" #include 
"flashlight/fl/tensor/backend/af/Utils.h" +#include + using namespace ::testing; using namespace fl; @@ -46,6 +48,31 @@ bool allClose( namespace fl { +TEST(ArrayFireSparse, HostExtractionCrash) { + int size = 8; + + // Create lazy nodes + af::array values = af::constant(1.0, size, f64); + af::array row_ptr = af::iota(af::dim4(size + 1), af::dim4(1), s32); + af::array col_idx = af::constant(0, size, s32); + + // Create sparse array + + af::array sp = af::sparse(size, 4, values, row_ptr, col_idx, AF_STORAGE_CSR); + + try { + std::vector data(sp.elements(), 999); + + sp.host(data.data()); + } + catch(af::exception const& e) { + SUCCEED(); + return; + } + + FAIL() << "ArrayFire fixed host call on sparse matrices, update Tensor::host accordingly"; +} + TEST(ArrayFireTensorBaseTest, ArrayFireShapeInterop) { ASSERT_EQ(detail::afToFlDims(af::dim4(), 0), Shape({})); // scalar ASSERT_EQ(detail::afToFlDims(af::dim4(0), 1), Shape({0})); @@ -54,9 +81,7 @@ TEST(ArrayFireTensorBaseTest, ArrayFireShapeInterop) { ASSERT_EQ(detail::afToFlDims(af::dim4(0, 1, 1, 1), 4), Shape({0, 1, 1, 1})); using namespace fl::detail; - auto dimsEq = [](const af::dim4& d, const Shape& s) { - return detail::afToFlDims(d, s.ndim()) == s; - }; + auto dimsEq = [](const af::dim4& d, const Shape& s) { return detail::afToFlDims(d, s.ndim()) == s; }; ASSERT_TRUE(dimsEq(af::dim4(3), {3})); // not 3, 1, 1, 1 ASSERT_TRUE(dimsEq(af::dim4(3, 2), {3, 2})); // not 3, 2, 1, 1 @@ -133,17 +158,17 @@ TEST(ArrayFireTensorBaseTest, AfRefCountModify) { TEST(ArrayFireTensorBaseTest, astypeRefcount) { auto t = fl::rand({5, 5}); ASSERT_EQ(getRefCount(toArray(t)), 1); - auto t64 = t.astype(fl::dtype::f64); + auto t64 = t.asType(fl::dtype::f64); ASSERT_EQ(getRefCount(toArray(t64)), 1); } TEST(ArrayFireTensorBaseTest, astypeInPlaceRefcount) { auto a = fl::rand({4, 4}); ASSERT_EQ(getRefCount(toArray(a)), 1); - a = a.astype(fl::dtype::f64); + a = a.asType(fl::dtype::f64); ASSERT_EQ(getRefCount(toArray(a)), 1); ASSERT_EQ(a.type(), 
fl::dtype::f64); - a = a.astype(fl::dtype::f32); + a = a.asType(fl::dtype::f32); ASSERT_EQ(getRefCount(toArray(a)), 1); } @@ -208,15 +233,25 @@ TEST(ArrayFireTensorBaseTest, BinaryOperators) { TEST(ArrayFireTensorBaseTest, full) { // TODO: expand with fixtures for each type - auto a = fl::full({3, 4}, 3.); - ASSERT_EQ(a.shape(), Shape({3, 4})); + af::dim4 afDim2{3, 4}; + Shape const dim2Shape{afDim2[0], afDim2[1]}; + constexpr float val2 = 3; + + auto a = fl::full(dim2Shape, val2); + ASSERT_EQ(a.shape(), dim2Shape); ASSERT_EQ(a.type(), fl::dtype::f32); - ASSERT_TRUE(allClose(toArray(a), af::constant(3., {3, 4}))); - auto b = fl::full({1, 1, 5, 4}, 4.5); - ASSERT_EQ(b.shape(), Shape({1, 1, 5, 4})); + ASSERT_TRUE(allClose(toArray(a), af::constant(val2, afDim2))); + + + af::dim4 afDim4{1, 1, 5, 4}; + Shape const dim4Shape{afDim4[0], afDim4[1], afDim4[2], afDim4[3]}; + constexpr float val4 = 4.5; + + auto b = fl::full(dim4Shape, val4); + ASSERT_EQ(b.shape(), dim4Shape); ASSERT_EQ(b.type(), fl::dtype::f32); - ASSERT_TRUE(allClose(toArray(b), af::constant(4.5, {1, 1, 5, 4}))); + ASSERT_TRUE(allClose(toArray(b), af::constant(val4, afDim4))); } TEST(ArrayFireTensorBaseTest, identity) { @@ -431,7 +466,7 @@ TEST(ArrayFireTensorBaseTest, tile) { } TEST(ArrayFireTensorBaseTest, nonzero) { - auto a = fl::rand({10, 10}).astype(fl::dtype::u32); + auto a = fl::rand({10, 10}).asType(fl::dtype::u32); auto nz = fl::nonzero(a); ASSERT_TRUE(allClose(toArray(nz), af::where(toArray(a)))); } @@ -450,10 +485,6 @@ TEST(ArrayFireTensorBaseTest, transpose) { ); } -TEST(ArrayFireTensorBaseTest, concatenate) { - std::vector tensors(11); - ASSERT_THROW(fl::concatenate(tensors), std::invalid_argument); -} TEST(ArrayFireTensorBaseTest, device) { auto a = fl::rand({5, 5}); diff --git a/flashlight/pkg/runtime/common/DistributedUtils.cpp b/flashlight/pkg/runtime/common/DistributedUtils.cpp index 24512ac..8fcc3e6 100644 --- a/flashlight/pkg/runtime/common/DistributedUtils.cpp +++ 
b/flashlight/pkg/runtime/common/DistributedUtils.cpp @@ -52,13 +52,13 @@ Tensor allreduceGet(fl::AverageValueMeter& mtr) { Tensor allreduceGet(fl::EditDistanceMeter& mtr) { auto mtrVal0 = mtr.value(); - std::vector mtrVal(mtrVal0.begin(), mtrVal0.end()); + std::vector mtrVal(mtrVal0.begin(), mtrVal0.end()); return Tensor::fromVector(mtrVal); } Tensor allreduceGet(fl::CountMeter& mtr) { auto mtrVal0 = mtr.value(); - std::vector mtrVal(mtrVal0.begin(), mtrVal0.end()); + std::vector mtrVal(mtrVal0.begin(), mtrVal0.end()); return Tensor::fromVector(mtrVal); } @@ -81,7 +81,7 @@ void allreduceSet(fl::AverageValueMeter& mtr, Tensor& val) { void allreduceSet(fl::EditDistanceMeter& mtr, Tensor& val) { mtr.reset(); - auto valVec = val.toHostVector(); + auto valVec = val.toHostVector(); mtr.add( static_cast(valVec[1]), static_cast(valVec[2]), @@ -92,7 +92,7 @@ void allreduceSet(fl::EditDistanceMeter& mtr, Tensor& val) { void allreduceSet(fl::CountMeter& mtr, Tensor& val) { mtr.reset(); - auto valVec = val.toHostVector(); + auto valVec = val.toHostVector(); for(size_t i = 0; i < valVec.size(); ++i) mtr.add(i, valVec[i]); } diff --git a/flashlight/pkg/speech/criterion/Seq2SeqCriterion.cpp b/flashlight/pkg/speech/criterion/Seq2SeqCriterion.cpp index 7477bbf..d9ee052 100644 --- a/flashlight/pkg/speech/criterion/Seq2SeqCriterion.cpp +++ b/flashlight/pkg/speech/criterion/Seq2SeqCriterion.cpp @@ -172,7 +172,7 @@ std::vector Seq2SeqCriterion::forward( size_t nClass = out.dim(0); auto targetTiled = fl::tile( fl::reshape(target.tensor(), {1, target.dim(0), target.dim(1)}), - {static_cast(nClass)} + {static_cast(nClass)} ); out = applySeq2SeqMask(out, targetTiled, pad_); auto smoothLoss = moddims(sum(out, {0, 1}), {-1}); @@ -209,11 +209,11 @@ std::pair Seq2SeqCriterion::vectorizedDecoder( ); else if(samplingStrategy_ == fl::pkg::speech::kRandSampling) { auto mask = Variable( - (fl::rand(y.shape()) * 100 <= pctTeacherForcing_).astype(y.type()), + (fl::rand(y.shape()) * 100 <= 
pctTeacherForcing_).asType(y.type()), false ); auto samples = Variable( - (fl::rand(y.shape()) * (nClass_ - 1)).astype(y.type()), + (fl::rand(y.shape()) * (nClass_ - 1)).asType(y.type()), false ); @@ -285,7 +285,7 @@ std::pair Seq2SeqCriterion::decoder( y = Variable(maxIdx, false); } else if(samplingStrategy_ == fl::pkg::speech::kRandSampling) y = Variable( - (fl::rand({1, target.dim(1)}) * (nClass_ - 1)).astype(fl::dtype::s32), + (fl::rand({1, target.dim(1)}) * (nClass_ - 1)).asType(fl::dtype::s32), false ); else @@ -414,7 +414,7 @@ std::vector Seq2SeqCriterion::beamSearch( ox = fl::reorder(ox, {0, 2, 1}); auto scoreArr = Tensor::fromBuffer( - {1, static_cast(beam.size()), 1}, + {1, static_cast(beam.size()), 1}, prevScoreVec.data(), MemoryLocation::Host ); diff --git a/flashlight/pkg/speech/criterion/TransformerCriterion.cpp b/flashlight/pkg/speech/criterion/TransformerCriterion.cpp index 4ed9a3d..1d431f7 100644 --- a/flashlight/pkg/speech/criterion/TransformerCriterion.cpp +++ b/flashlight/pkg/speech/criterion/TransformerCriterion.cpp @@ -90,7 +90,7 @@ std::vector TransformerCriterion::forward( {-1} ); if(train_ && labelSmooth_ > 0) { - long long nClass = out.dim(0); + auto nClass = out.dim(0); auto targetTiled = fl::tile( fl::reshape(target.tensor(), {1, target.dim(0), target.dim(1)}), {nClass} @@ -123,11 +123,11 @@ std::pair TransformerCriterion::vectorizedDecoder( if(train_) { // TODO: other sampling strategies auto mask = Variable( - (fl::rand(y.shape()) * 100 <= pctTeacherForcing_).astype(y.type()), + (fl::rand(y.shape()) * 100 <= pctTeacherForcing_).asType(y.type()), false ); auto samples = Variable( - (fl::rand(y.shape()) * (nClass_ - 1)).astype(y.type()), + (fl::rand(y.shape()) * (nClass_ - 1)).asType(y.type()), false ); diff --git a/flashlight/pkg/speech/criterion/backend/cuda/ConnectionistTemporalClassificationCriterion.cpp b/flashlight/pkg/speech/criterion/backend/cuda/ConnectionistTemporalClassificationCriterion.cpp index 3ce1bb8..2cea474 100644 --- 
a/flashlight/pkg/speech/criterion/backend/cuda/ConnectionistTemporalClassificationCriterion.cpp +++ b/flashlight/pkg/speech/criterion/backend/cuda/ConnectionistTemporalClassificationCriterion.cpp @@ -125,7 +125,7 @@ std::vector ConnectionistTemporalClassificationCriterion::forward( "Error: get_workspace_size" ); - Tensor workspace({static_cast(workspace_size)}, fl::dtype::b8); + Tensor workspace({static_cast(workspace_size)}, fl::dtype::b8); std::vector costs(B, 0.0); { diff --git a/flashlight/pkg/speech/criterion/backend/cuda/CriterionUtils.cpp b/flashlight/pkg/speech/criterion/backend/cuda/CriterionUtils.cpp index bf1dd7e..f8ab063 100644 --- a/flashlight/pkg/speech/criterion/backend/cuda/CriterionUtils.cpp +++ b/flashlight/pkg/speech/criterion/backend/cuda/CriterionUtils.cpp @@ -44,7 +44,7 @@ Tensor viterbiPath(const Tensor& input, const Tensor& trans) { Tensor path({T, B}, fl::dtype::s32); Tensor workspace( - {static_cast(ViterbiPath::getWorkspaceSize(B, T, N))}, + {static_cast(ViterbiPath::getWorkspaceSize(B, T, N))}, fl::dtype::u8); { diff --git a/flashlight/pkg/speech/criterion/backend/cuda/ForceAlignmentCriterion.cpp b/flashlight/pkg/speech/criterion/backend/cuda/ForceAlignmentCriterion.cpp index 463712b..2cdbd93 100644 --- a/flashlight/pkg/speech/criterion/backend/cuda/ForceAlignmentCriterion.cpp +++ b/flashlight/pkg/speech/criterion/backend/cuda/ForceAlignmentCriterion.cpp @@ -91,7 +91,7 @@ Variable ForceAlignmentCriterion::forward( const auto& trans = transVar.tensor(); Tensor loss({B}, fl::dtype::f32); Tensor workspace( - {static_cast(FAC::getWorkspaceSize(B, T, N, L))}, + {static_cast(FAC::getWorkspaceSize(B, T, N, L))}, fl::dtype::u8); { @@ -155,7 +155,7 @@ Tensor ForceAlignmentCriterion::viterbiPath( const auto& trans = transVar.tensor(); Tensor bestPathsVar({T, B}, fl::dtype::s32); Tensor workspace( - {static_cast(FAC::getWorkspaceSize(B, T, N, L))}, + {static_cast(FAC::getWorkspaceSize(B, T, N, L))}, fl::dtype::u8); { diff --git 
a/flashlight/pkg/speech/criterion/backend/cuda/FullConnectionCriterion.cpp b/flashlight/pkg/speech/criterion/backend/cuda/FullConnectionCriterion.cpp index 6c63e8a..ab05dd0 100644 --- a/flashlight/pkg/speech/criterion/backend/cuda/FullConnectionCriterion.cpp +++ b/flashlight/pkg/speech/criterion/backend/cuda/FullConnectionCriterion.cpp @@ -96,7 +96,7 @@ Variable FullConnectionCriterion::forward( const auto& trans = transVar.tensor(); Tensor loss({B}, fl::dtype::f32); Tensor workspace( - {static_cast(FCC::getWorkspaceSize(B, T, N))}, fl::dtype::u8); + {static_cast(FCC::getWorkspaceSize(B, T, N))}, fl::dtype::u8); { fl::DevicePtr inputRaw(input); diff --git a/flashlight/pkg/speech/data/FeatureTransforms.cpp b/flashlight/pkg/speech/data/FeatureTransforms.cpp index 01b0e7d..0274161 100644 --- a/flashlight/pkg/speech/data/FeatureTransforms.cpp +++ b/flashlight/pkg/speech/data/FeatureTransforms.cpp @@ -116,7 +116,7 @@ fl::Dataset::DataTransformFunction inputFeatures( else output = normalize(output); return Tensor::fromBuffer( - {static_cast(T), featSz, channels}, + {static_cast(T), featSz, channels}, output.data(), MemoryLocation::Host ); diff --git a/flashlight/pkg/speech/data/ListFileDataset.cpp b/flashlight/pkg/speech/data/ListFileDataset.cpp index c8c56eb..0bf4df0 100644 --- a/flashlight/pkg/speech/data/ListFileDataset.cpp +++ b/flashlight/pkg/speech/data/ListFileDataset.cpp @@ -105,12 +105,12 @@ std::vector ListFileDataset::get(const int64_t idx) const { } Tensor sampleIdx = Tensor::fromBuffer( - {static_cast(ids_[idx].length())}, + {static_cast(ids_[idx].length())}, const_cast(ids_[idx].data()), // fix me post C++-17? 
MemoryLocation::Host ); Tensor samplePath = Tensor::fromBuffer( - {static_cast(inputs_[idx].length())}, + {static_cast(inputs_[idx].length())}, inputs_[idx].data(), MemoryLocation::Host ); @@ -150,7 +150,7 @@ int64_t ListFileDataset::getTargetSize(const int64_t idx) const { std::vector curTarget(targets_[idx].begin(), targets_[idx].end()); auto tgtSize = tgtFeatFunc_( static_cast(curTarget.data()), - {static_cast(curTarget.size())}, + {static_cast(curTarget.size())}, fl::dtype::b8 ) .elements(); diff --git a/flashlight/pkg/speech/runtime/Logger.cpp b/flashlight/pkg/speech/runtime/Logger.cpp index 625100f..8c3a5d0 100644 --- a/flashlight/pkg/speech/runtime/Logger.cpp +++ b/flashlight/pkg/speech/runtime/Logger.cpp @@ -120,7 +120,7 @@ void appendToLog(std::ofstream& logfile, const std::string& logstr) { Tensor allreduceGet(SpeechStatMeter& mtr) { auto mtrValRaw = mtr.value(); - std::vector mtrVal(mtrValRaw.begin(), mtrValRaw.end()); + std::vector mtrVal(mtrValRaw.begin(), mtrValRaw.end()); // Caveat: maxInputSz_, maxTargetSz_ would be approximate mtrVal[2] *= mtrVal[4]; mtrVal[3] *= mtrVal[4]; diff --git a/flashlight/pkg/speech/test/criterion/attention/AttentionTest.cpp b/flashlight/pkg/speech/test/criterion/attention/AttentionTest.cpp index 980c8da..165fe1e 100644 --- a/flashlight/pkg/speech/test/criterion/attention/AttentionTest.cpp +++ b/flashlight/pkg/speech/test/criterion/attention/AttentionTest.cpp @@ -245,7 +245,7 @@ TEST(AttentionTest, JacobianMaskAttention) { auto in = Variable(fl::rand({10, 9, 5}, fl::dtype::f32), true); std::vector inpSzRaw = {1, 2, 4, 8, 16}; Tensor inpSz = Tensor::fromVector( - {1, static_cast(inpSzRaw.size())}, + {1, static_cast(inpSzRaw.size())}, inpSzRaw ); auto func_in = [&](Variable& input) { diff --git a/flashlight/pkg/speech/test/data/FeaturizationTest.cpp b/flashlight/pkg/speech/test/data/FeaturizationTest.cpp index 29b0bb1..4d9e015 100644 --- a/flashlight/pkg/speech/test/data/FeaturizationTest.cpp +++ 
b/flashlight/pkg/speech/test/data/FeaturizationTest.cpp @@ -430,7 +430,7 @@ TEST(FeaturizationTest, targetFeaturizer) { auto tgtArray = targetFeaturizer( targets[0].data(), - {static_cast(targets[0].size())}, + {static_cast(targets[0].size())}, fl::dtype::b8 ); int tgtLen = 5; @@ -459,7 +459,7 @@ TEST(FeaturizationTest, targetFeaturizer) { targetFeaturizer = targetFeatures(tokenDict, lexicon, targetGenConfigEos); tgtArray = targetFeaturizer( targets[1].data(), - {static_cast(targets[1].size())}, + {static_cast(targets[1].size())}, fl::dtype::b8 ); tgtLen = 5; diff --git a/flashlight/pkg/text/data/TextDataset.cpp b/flashlight/pkg/text/data/TextDataset.cpp index ee2e604..9cd03de 100644 --- a/flashlight/pkg/text/data/TextDataset.cpp +++ b/flashlight/pkg/text/data/TextDataset.cpp @@ -159,7 +159,7 @@ std::vector TextDataset::get(const int64_t idx) const { } return { Tensor::fromVector( - {maxLength, static_cast(batch.size())}, + {maxLength, static_cast(batch.size())}, buffer ) }; diff --git a/flashlight/pkg/vision/criterion/SetCriterion.cpp b/flashlight/pkg/vision/criterion/SetCriterion.cpp index abaaa7d..78d6558 100644 --- a/flashlight/pkg/vision/criterion/SetCriterion.cpp +++ b/flashlight/pkg/vision/criterion/SetCriterion.cpp @@ -97,8 +97,9 @@ Tensor ravelIndices( Tensor index(const Tensor& in, const std::vector& idxs) { auto linearIndices = ravelIndices(idxs, in.shape()); Tensor output = fl::full(linearIndices.shape(), 0., in.type()); - output.flat(fl::range(static_cast(linearIndices.elements()))) = - in.flatten()(linearIndices); + output.flat( + fl::range(static_cast(linearIndices.elements())) + ) = in.flatten()(linearIndices); return output; } @@ -116,7 +117,7 @@ fl::Variable index(const fl::Variable& in, std::vector idxs) { auto grad = fl::Variable(fl::full(idims, 0, inputs[0].type()), false); auto linearIndices = ravelIndices(idxs, idims); grad.tensor()(linearIndices) = grad_output.tensor()( - fl::range(static_cast(linearIndices.elements()))); + 
fl::range(static_cast(linearIndices.elements()))); // TODO Can parallize this if needed but does not work for duplicate keys // for(int i = 0; i < linearIndices.elements(); i++) { // Tensor index = linearIndices(i); @@ -283,8 +284,8 @@ SetCriterion::LossDict SetCriterion::lossLabels( target_classes_full(srcIdxs, i) = fl::reshape( targetClasses[i].tensor()(targetIdxs), - {static_cast(srcIdxs.elements()), 1}) - .astype(target_classes_full.type()); + {static_cast(srcIdxs.elements()), 1}) + .asType(target_classes_full.type()); i += 1; } @@ -294,7 +295,7 @@ SetCriterion::LossDict SetCriterion::lossLabels( auto weightVar = Variable(weight, false); auto lossCe = weightedCategoricalCrossEntropy( softmaxed, - fl::Variable(target_classes_full.astype(fl::dtype::f32), false), + fl::Variable(target_classes_full.asType(fl::dtype::f32), false), weightVar, -1 ); @@ -306,7 +307,7 @@ std::unordered_map SetCriterion::getWeightDict() { return we std::pair SetCriterion::getTgtPermutationIdx( const std::vector>& indices ) { - long batchSize = static_cast(indices.size()); + auto batchSize = static_cast(indices.size()); auto batchIdxs = fl::full({1, 1, 1, batchSize}, -1); auto first = indices[0].first; auto dims = first.shape(); diff --git a/flashlight/pkg/vision/dataset/Coco.cpp b/flashlight/pkg/vision/dataset/Coco.cpp index 4feb563..86ac708 100644 --- a/flashlight/pkg/vision/dataset/Coco.cpp +++ b/flashlight/pkg/vision/dataset/Coco.cpp @@ -39,13 +39,13 @@ std::pair makeImageAndMaskBatch( maxH = std::max(h, maxH); } - Shape outDims = {maxW, maxH, 3, static_cast(data.size())}; - Shape maskDims = {maxW, maxH, 1, static_cast(data.size())}; + Shape outDims = {maxW, maxH, 3, static_cast(data.size())}; + Shape maskDims = {maxW, maxH, 1, static_cast(data.size())}; auto batcharr = fl::full(outDims, 0); auto maskarr = fl::full(maskDims, 0); - for(long i = 0; i < data.size(); ++i) { + for(Dim i = 0; i < data.size(); ++i) { Tensor sample = data[i]; Shape dims = sample.shape(); int w = dims[0]; diff 
--git a/flashlight/pkg/vision/dataset/CocoTransforms.cpp b/flashlight/pkg/vision/dataset/CocoTransforms.cpp index 50095ed..219c2b0 100644 --- a/flashlight/pkg/vision/dataset/CocoTransforms.cpp +++ b/flashlight/pkg/vision/dataset/CocoTransforms.cpp @@ -157,7 +157,7 @@ std::vector randomResize(std::vector inputs, int size, int maxsi boxes = boxes * resizedArray; } - std::vector imageSizeArray = {resizedImage.dim(1), resizedImage.dim(0)}; + std::vector imageSizeArray = {resizedImage.dim(1), resizedImage.dim(0)}; Tensor sizeArray = Tensor::fromVector(imageSizeArray); return { resizedImage, @@ -190,7 +190,7 @@ TransformAllFunction Normalize( boxes = boxes / ratioArray; } // Normalize Image - Tensor image = in[ImageIdx].astype(fl::dtype::f32) / 255.f; + Tensor image = in[ImageIdx].asType(fl::dtype::f32) / 255.f; image = image - mean; image = image / std; std::vector outputs = { diff --git a/flashlight/pkg/vision/dataset/Transforms.cpp b/flashlight/pkg/vision/dataset/Transforms.cpp index 5519158..5e9099f 100644 --- a/flashlight/pkg/vision/dataset/Transforms.cpp +++ b/flashlight/pkg/vision/dataset/Transforms.cpp @@ -170,8 +170,8 @@ Tensor posterize(const Tensor& input, const int bitsToKeep) { if(bitsToKeep < 1 || bitsToKeep > 8) throw std::invalid_argument("bitsToKeep needs to be in [1, 8]"); uint8_t mask = ~((1 << (8 - bitsToKeep)) - 1); - auto res = input.astype(fl::dtype::u8) && mask; - return res.astype(input.type()); + auto res = input.asType(fl::dtype::u8) && mask; + return res.asType(input.type()); } Tensor sharpnessEnhance(const Tensor& input, const float enhance) { @@ -312,7 +312,7 @@ ImageTransform randomHorizontalFlipTransform(const float p) { return [p](const Tensor& in) { Tensor out = in; if(static_cast(std::rand()) / static_cast(RAND_MAX) > p) { - const long long w = in.dim(0); + auto const w = in.dim(0); // reverse indices - w --> 0 - TODO: use fl::flip? 
out = out(fl::range(w - 1, 1, -1)); } @@ -382,7 +382,7 @@ ImageTransform normalizeImage( const Tensor mean = Tensor::fromVector({1, 1, 3}, meanVector); const Tensor std = Tensor::fromVector({1, 1, 3}, stdVector); return [mean, std](const Tensor& in) { - Tensor out = in.astype(fl::dtype::f32) / 255.f; + Tensor out = in.asType(fl::dtype::f32) / 255.f; out = out - mean; out = out / std; return out; @@ -519,7 +519,7 @@ ImageTransform randomAugmentationDeitTransform( res = sharpnessEnhance(res, enhance); } - res = fl::clip(res, 0., 255.).astype(res.type()); + res = fl::clip(res, 0., 255.).asType(res.type()); } return res; };